@@ -407,6 +407,8 @@ EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
407407
408408static bool vsock_use_local_transport (unsigned int remote_cid )
409409{
410+ lockdep_assert_held (& vsock_register_mutex );
411+
410412 if (!transport_local )
411413 return false;
412414
@@ -464,6 +466,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
464466
465467 remote_flags = vsk -> remote_addr .svm_flags ;
466468
469+ mutex_lock (& vsock_register_mutex );
470+
467471 switch (sk -> sk_type ) {
468472 case SOCK_DGRAM :
469473 new_transport = transport_dgram ;
@@ -479,12 +483,15 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
479483 new_transport = transport_h2g ;
480484 break ;
481485 default :
482- return - ESOCKTNOSUPPORT ;
486+ ret = - ESOCKTNOSUPPORT ;
487+ goto err ;
483488 }
484489
485490 if (vsk -> transport ) {
486- if (vsk -> transport == new_transport )
487- return 0 ;
491+ if (vsk -> transport == new_transport ) {
492+ ret = 0 ;
493+ goto err ;
494+ }
488495
489496 /* transport->release() must be called with sock lock acquired.
490497 * This path can only be taken during vsock_connect(), where we
@@ -508,8 +515,16 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
508515 /* We increase the module refcnt to prevent the transport unloading
509516 * while there are open sockets assigned to it.
510517 */
511- if (!new_transport || !try_module_get (new_transport -> module ))
512- return - ENODEV ;
518+ if (!new_transport || !try_module_get (new_transport -> module )) {
519+ ret = - ENODEV ;
520+ goto err ;
521+ }
522+
523+ /* It's safe to release the mutex after a successful try_module_get().
524+ * Whichever transport `new_transport` points at, it won't go away until
525+ * the last module_put() below or in vsock_deassign_transport().
526+ */
527+ mutex_unlock (& vsock_register_mutex );
513528
514529 if (sk -> sk_type == SOCK_SEQPACKET ) {
515530 if (!new_transport -> seqpacket_allow ||
@@ -528,12 +543,31 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
528543 vsk -> transport = new_transport ;
529544
530545 return 0 ;
546+ err :
547+ mutex_unlock (& vsock_register_mutex );
548+ return ret ;
531549}
532550EXPORT_SYMBOL_GPL (vsock_assign_transport );
533551
552+ /*
553+ * Provide safe access to static transport_{h2g,g2h,dgram,local} callbacks.
554+ * Otherwise we may race with module removal. Do not use on `vsk->transport`.
555+ */
556+ static u32 vsock_registered_transport_cid (const struct vsock_transport * * transport )
557+ {
558+ u32 cid = VMADDR_CID_ANY ;
559+
560+ mutex_lock (& vsock_register_mutex );
561+ if (* transport )
562+ cid = (* transport )-> get_local_cid ();
563+ mutex_unlock (& vsock_register_mutex );
564+
565+ return cid ;
566+ }
567+
534568bool vsock_find_cid (unsigned int cid )
535569{
536- if (transport_g2h && cid == transport_g2h -> get_local_cid ( ))
570+ if (cid == vsock_registered_transport_cid ( & transport_g2h ))
537571 return true;
538572
539573 if (transport_h2g && cid == VMADDR_CID_HOST )
@@ -2536,18 +2570,19 @@ static long vsock_dev_do_ioctl(struct file *filp,
25362570 unsigned int cmd , void __user * ptr )
25372571{
25382572 u32 __user * p = ptr ;
2539- u32 cid = VMADDR_CID_ANY ;
25402573 int retval = 0 ;
2574+ u32 cid ;
25412575
25422576 switch (cmd ) {
25432577 case IOCTL_VM_SOCKETS_GET_LOCAL_CID :
25442578 /* To be compatible with the VMCI behavior, we prioritize the
25452579 * guest CID instead of well-know host CID (VMADDR_CID_HOST).
25462580 */
2547- if (transport_g2h )
2548- cid = transport_g2h -> get_local_cid ();
2549- else if (transport_h2g )
2550- cid = transport_h2g -> get_local_cid ();
2581+ cid = vsock_registered_transport_cid (& transport_g2h );
2582+ if (cid == VMADDR_CID_ANY )
2583+ cid = vsock_registered_transport_cid (& transport_h2g );
2584+ if (cid == VMADDR_CID_ANY )
2585+ cid = vsock_registered_transport_cid (& transport_local );
25512586
25522587 if (put_user (cid , p ) != 0 )
25532588 retval = - EFAULT ;
0 commit comments