@@ -407,6 +407,8 @@ EXPORT_SYMBOL_GPL(vsock_enqueue_accept);
407
407
408
408
static bool vsock_use_local_transport (unsigned int remote_cid )
409
409
{
410
+ lockdep_assert_held (& vsock_register_mutex );
411
+
410
412
if (!transport_local )
411
413
return false;
412
414
@@ -464,6 +466,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
464
466
465
467
remote_flags = vsk -> remote_addr .svm_flags ;
466
468
469
+ mutex_lock (& vsock_register_mutex );
470
+
467
471
switch (sk -> sk_type ) {
468
472
case SOCK_DGRAM :
469
473
new_transport = transport_dgram ;
@@ -479,12 +483,15 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
479
483
new_transport = transport_h2g ;
480
484
break ;
481
485
default :
482
- return - ESOCKTNOSUPPORT ;
486
+ ret = - ESOCKTNOSUPPORT ;
487
+ goto err ;
483
488
}
484
489
485
490
if (vsk -> transport ) {
486
- if (vsk -> transport == new_transport )
487
- return 0 ;
491
+ if (vsk -> transport == new_transport ) {
492
+ ret = 0 ;
493
+ goto err ;
494
+ }
488
495
489
496
/* transport->release() must be called with sock lock acquired.
490
497
* This path can only be taken during vsock_connect(), where we
@@ -508,8 +515,16 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
508
515
/* We increase the module refcnt to prevent the transport unloading
509
516
* while there are open sockets assigned to it.
510
517
*/
511
- if (!new_transport || !try_module_get (new_transport -> module ))
512
- return - ENODEV ;
518
+ if (!new_transport || !try_module_get (new_transport -> module )) {
519
+ ret = - ENODEV ;
520
+ goto err ;
521
+ }
522
+
523
+ /* It's safe to release the mutex after a successful try_module_get().
524
+ * Whichever transport `new_transport` points at, it won't go away until
525
+ * the last module_put() below or in vsock_deassign_transport().
526
+ */
527
+ mutex_unlock (& vsock_register_mutex );
513
528
514
529
if (sk -> sk_type == SOCK_SEQPACKET ) {
515
530
if (!new_transport -> seqpacket_allow ||
@@ -528,12 +543,31 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk)
528
543
vsk -> transport = new_transport ;
529
544
530
545
return 0 ;
546
+ err :
547
+ mutex_unlock (& vsock_register_mutex );
548
+ return ret ;
531
549
}
532
550
EXPORT_SYMBOL_GPL (vsock_assign_transport );
533
551
552
+ /*
553
+ * Provide safe access to static transport_{h2g,g2h,dgram,local} callbacks.
554
+ * Otherwise we may race with module removal. Do not use on `vsk->transport`.
555
+ */
556
+ static u32 vsock_registered_transport_cid (const struct vsock_transport * * transport )
557
+ {
558
+ u32 cid = VMADDR_CID_ANY ;
559
+
560
+ mutex_lock (& vsock_register_mutex );
561
+ if (* transport )
562
+ cid = (* transport )-> get_local_cid ();
563
+ mutex_unlock (& vsock_register_mutex );
564
+
565
+ return cid ;
566
+ }
567
+
534
568
bool vsock_find_cid (unsigned int cid )
535
569
{
536
- if (transport_g2h && cid == transport_g2h -> get_local_cid ( ))
570
+ if (cid == vsock_registered_transport_cid ( & transport_g2h ))
537
571
return true;
538
572
539
573
if (transport_h2g && cid == VMADDR_CID_HOST )
@@ -2536,18 +2570,19 @@ static long vsock_dev_do_ioctl(struct file *filp,
2536
2570
unsigned int cmd , void __user * ptr )
2537
2571
{
2538
2572
u32 __user * p = ptr ;
2539
- u32 cid = VMADDR_CID_ANY ;
2540
2573
int retval = 0 ;
2574
+ u32 cid ;
2541
2575
2542
2576
switch (cmd ) {
2543
2577
case IOCTL_VM_SOCKETS_GET_LOCAL_CID :
2544
2578
/* To be compatible with the VMCI behavior, we prioritize the
2545
2579
* guest CID instead of well-know host CID (VMADDR_CID_HOST).
2546
2580
*/
2547
- if (transport_g2h )
2548
- cid = transport_g2h -> get_local_cid ();
2549
- else if (transport_h2g )
2550
- cid = transport_h2g -> get_local_cid ();
2581
+ cid = vsock_registered_transport_cid (& transport_g2h );
2582
+ if (cid == VMADDR_CID_ANY )
2583
+ cid = vsock_registered_transport_cid (& transport_h2g );
2584
+ if (cid == VMADDR_CID_ANY )
2585
+ cid = vsock_registered_transport_cid (& transport_local );
2551
2586
2552
2587
if (put_user (cid , p ) != 0 )
2553
2588
retval = - EFAULT ;
0 commit comments