@@ -517,6 +517,11 @@ static void init_multicast_for_block(
   using McHandleType =
       std::conditional_t<use_fabric_handle, CUmemFabricHandle, int>;
 
+  McHandleType invalidator;
+  std::memset(&invalidator, UINT8_MAX, sizeof(McHandleType));
+
+  // Phase 1: export handle (rank 0 only)
+  McHandleType mc_exported_handle{};
   if (rank == 0) {
     CUmulticastObjectProp mc_prop{};
     mc_prop.numDevices = world_size;
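A note on the sentinel introduced above: filling the handle with 0xFF bytes yields an in-band "export failed" value that works for both handle types; for the POSIX-fd case (McHandleType = int) it is simply -1, matching the value the old code broadcast on failure. A minimal self-contained sketch of the same idea, with illustrative helper names:

#include <cstdint>
#include <cstring>

// Sketch of the sentinel pattern from the diff: an all-0xFF handle value
// that a failed rank can send through the normal handle-exchange path.
template <typename HandleT>
HandleT make_invalidator() {
  HandleT h;
  std::memset(&h, UINT8_MAX, sizeof(HandleT));
  return h;
}

// Peers detect the sentinel with a byte-wise compare.
template <typename HandleT>
bool is_invalidated(const HandleT& h) {
  const HandleT inv = make_invalidator<HandleT>();
  return std::memcmp(&h, &inv, sizeof(HandleT)) == 0;
}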
@@ -525,68 +530,82 @@ static void init_multicast_for_block(
 
     // create a multicast object, which acts as a handle that allows multiple
     // devices or processes to access the same memory allocation coherently.
-    auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
-    if (err != CUDA_SUCCESS) {
-      const char* err_str;
-      CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str);
-      if (get_error_str_err != CUDA_SUCCESS) {
-        err_str = "unknown cuda driver error";
-      }
-      LOG(WARNING)
-          << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str
-          << "\". Gracefully skipping multicast initialization. "
-          << "However, this is unexpected. Please report the issue on GitHub.";
+    try {
+      C10_CUDA_DRIVER_CHECK(
+          driver_api->cuMulticastCreate_(&mc_handle, &mc_prop));
+      // Export the multicast object into an OS-shareable handle (a POSIX
+      // file descriptor or a fabric handle) via the CUDA driver API.
+      C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
+          &mc_exported_handle, mc_handle, handleType, 0));
+    } catch (const std::exception& e) {
       // Allow peers to gracefully skip multicast initialization by sending -1
-      // TODO: allow graceful skip for fabric
-      if constexpr (!use_fabric_handle) {
-        ipc_channel.broadcast_fds(rank, 0, pids, -1);
-      }
-      return;
-    }
-
-    McHandleType mc_exported_handle;
-    // using the CUDA Driver API to export a multicast object into a POSIX file
-    // descriptor.
-    C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
-        &mc_exported_handle, mc_handle, handleType, 0));
-    if constexpr (!use_fabric_handle) {
-      ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle);
-      // Ref count is incremented as soon as the SCM_RIGHTS send happens
-      close(mc_exported_handle);
-    } else {
-      // TODO: implement storeExchange.broadcast
-      storeExchange.all_gather(store, rank, world_size, mc_exported_handle);
+      mc_exported_handle = invalidator;
+      LOG(WARNING)
+          << "SymmetricMemory: failed to export multicast handle.\n"
+          << e.what();
     }
+  }
 
+  // Phase 2: exchange handle
+  McHandleType recv_handle;
+  if constexpr (!use_fabric_handle) {
+    recv_handle = ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle);
   } else {
+    // TODO: implement storeExchange.broadcast
+    auto gathered_handles = storeExchange.all_gather(store, rank, world_size, mc_exported_handle);
+    recv_handle = std::move(gathered_handles[0]);
+  }
+
+  // Check the exchange result
+  if (memcmp(&recv_handle, &invalidator, sizeof(McHandleType)) == 0) {
+    LOG(WARNING) << "Gracefully skipping multicast initialization.";
+    return;
+  }
+
+  // Flipped to true only after all CUDA steps finish
+  bool success_end = false;
+
+  // Phase 3: import handle (non-zero ranks only)
+  if (rank != 0) {
     if constexpr (!use_fabric_handle) {
-      int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
-      if (mc_fd == -1) {
-        return;
-      }
       // Convert back to a handle from the broadcasted POSIX file descriptor.
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
+      C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_(
           &mc_handle,
-          (void*)(uintptr_t)mc_fd,
-          CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
-      close(mc_fd);
+          (void*)(uintptr_t)recv_handle,
+          CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), check_all);
     } else {
-      CUmemFabricHandle null_handle{};
-      auto mc_handles =
-          storeExchange.all_gather(store, rank, world_size, null_handle);
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
-          &mc_handle, (void*)&(mc_handles[0]), CU_MEM_HANDLE_TYPE_FABRIC));
+      C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_(
+          &mc_handle, (void*)&(recv_handle), CU_MEM_HANDLE_TYPE_FABRIC), check_all);
     }
   }
 
+  // Phase 4: bind memory
   // All ranks add their physical allocation to the multicast object
-  C10_CUDA_DRIVER_CHECK(
-      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
-  C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
-      mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0));
+  C10_CUDA_DRIVER_CHECK_GOTO(
+      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx), check_all);
+  C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMulticastBindMem_(
+      mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0), check_all);
+
+  success_end = true;
 
+check_all:
+  // Check whether all ranks succeeded
+  bool all_succeed = true;
+  auto rank_successes = storeExchange.all_gather(store, rank, world_size, success_end);
+  for (int r = 0; r < world_size; ++r) {
+    all_succeed &= rank_successes[r];
+  }
+  // Close the file descriptor before exiting
+  if constexpr (!use_fabric_handle) {
+    close(recv_handle);
+  }
+  if (!all_succeed) {
+    LOG(WARNING) << "Gracefully skipping multicast initialization.";
+    return;
+  }
+
+  // Phase 5: map to virtual memory
   map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
-  storeExchange.barrier(store, rank, world_size);
 #endif
 }
 
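The error handling above hinges on C10_CUDA_DRIVER_CHECK_GOTO: instead of throwing, a failing driver call jumps to check_all, so the rank still reaches the collective success check rather than leaving its peers blocked in all_gather. A rough sketch of what such a goto-on-error check can look like (the real c10 macro may differ; fprintf stands in for the logging):

#include <cuda.h>
#include <cstdio>

// Sketch of a goto-on-error CUDA driver check. On failure it logs the
// driver error string and jumps to the given label instead of throwing.
#define DRIVER_CHECK_GOTO(expr, label)                       \
  do {                                                       \
    CUresult err_ = (expr);                                  \
    if (err_ != CUDA_SUCCESS) {                              \
      const char* msg_ = nullptr;                            \
      if (cuGetErrorString(err_, &msg_) != CUDA_SUCCESS) {   \
        msg_ = "unknown cuda driver error";                  \
      }                                                      \
      std::fprintf(stderr, "CUDA driver error: %s\n", msg_); \
      goto label;                                            \
    }                                                        \
  } while (0)

This is also why success_end flips to true only after the last driver call: any earlier failure skips the assignment and lands at check_all with the flag still false.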
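The check_all block is what makes the skip collective: every rank contributes its local success flag, and the multicast object is mapped only if all ranks succeeded, so no rank is left using a half-initialized object while others fell back. A minimal sketch of that consensus under an assumed byte-wise all-gather (exchange_all_gather is a hypothetical stand-in for storeExchange.all_gather):

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical exchange primitive: gathers one byte from every rank over
// the rendezvous store and returns them indexed by rank.
std::vector<uint8_t> exchange_all_gather(int rank, int world_size, uint8_t local);

// Multicast is enabled only if every rank reports success; otherwise all
// ranks skip it together and fall back to non-multicast symmetric memory.
bool all_ranks_succeeded(bool local_ok, int rank, int world_size) {
  auto flags = exchange_all_gather(rank, world_size, local_ok ? 1 : 0);
  return std::all_of(
      flags.begin(), flags.end(), [](uint8_t f) { return f != 0; });
}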