@@ -517,11 +517,6 @@ static void init_multicast_for_block(
   using McHandleType =
       std::conditional_t<use_fabric_handle, CUmemFabricHandle, int>;
 
-  McHandleType invalidator;
-  std::memset(&invalidator, UINT8_MAX, sizeof(McHandleType));
-
-  // Phase 1: export handle (rank 0 only)
-  McHandleType mc_exported_handle{};
   if (rank == 0) {
     CUmulticastObjectProp mc_prop{};
     mc_prop.numDevices = world_size;
@@ -530,82 +525,68 @@ static void init_multicast_for_block(
 
     // create a multicast object, which acts as a handle that allows multiple
     // devices or processes to access the same memory allocation coherently.
-    try {
-      C10_CUDA_DRIVER_CHECK(
-          driver_api->cuMulticastCreate_(&mc_handle, &mc_prop));
-      // using the CUDA Driver API to export a multicast object into a POSIX file
-      // descriptor.
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
-          &mc_exported_handle, mc_handle, handleType, 0));
-    } catch (const std::exception& e) {
-      // Allow peers gracefully skip multicast initialization by sending -1
-      mc_exported_handle = invalidator;
+    auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
+    if (err != CUDA_SUCCESS) {
+      const char* err_str;
+      CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str);
+      if (get_error_str_err != CUDA_SUCCESS) {
+        err_str = "unknown cuda driver error";
+      }
       LOG(WARNING)
-          << "SymmetricMemory: fail to export multicast handle.\n"
-          << e.what();
+          << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str
+          << "\". Gracefully skipping multicast initialization. "
+          << "However, this is unexpected. Please report the issue on GitHub.";
+      // Allow peers gracefully skip multicast initialization by sending -1
+      // TODO: allow graceful skip for fabric
+      if constexpr (!use_fabric_handle) {
+        ipc_channel.broadcast_fds(rank, 0, pids, -1);
+      }
+      return;
     }
-  }
-
-  // Phase 2: Exchange handle
-  McHandleType recv_handle;
-  if constexpr (!use_fabric_handle) {
-    recv_handle = ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle);
-  } else {
-    // TODO implement storeExchange.broadcast
-    auto gathered_handles = storeExchange.all_gather(store, rank, world_size, mc_exported_handle);
-    recv_handle = std::move(gathered_handles[0]);
-  }
-
-  // Check exchange result
-  if (memcmp(&recv_handle, &invalidator, sizeof(McHandleType)) == 0) {
-    LOG(WARNING) << "Gracefully skipping multicast initialization.";
-    return;
-  }
 
-  // Flip to true after all CUDA steps finish
-  bool success_end = false;
+    McHandleType mc_exported_handle;
+    // using the CUDA Driver API to export a multicast object into a POSIX file
+    // descriptor.
+    C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
+        &mc_exported_handle, mc_handle, handleType, 0));
+    if constexpr (!use_fabric_handle) {
+      ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle);
+      // Ref count is incremented as soon as SCM_RIGHTS send happens
+      close(mc_exported_handle);
+    } else {
+      // TODO implement storeExchange.broadcast
+      storeExchange.all_gather(store, rank, world_size, mc_exported_handle);
+    }
 
-  // Phase 3: Import handle (non-0 ranks only)
-  if (rank != 0) {
+  } else {
     if constexpr (!use_fabric_handle) {
+      int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
+      if (mc_fd == -1) {
+        return;
+      }
       // Convert back to a handle from the broadcasted POSIX file descriptor.
-      C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_(
+      C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
           &mc_handle,
-          (void*)(uintptr_t)recv_handle,
-          CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), check_all);
+          (void*)(uintptr_t)mc_fd,
+          CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
+      close(mc_fd);
     } else {
-      C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_(
-          &mc_handle, (void*)&(recv_handle), CU_MEM_HANDLE_TYPE_FABRIC), check_all);
+      CUmemFabricHandle null_handle{};
+      auto mc_handles =
+          storeExchange.all_gather(store, rank, world_size, null_handle);
+      C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
+          &mc_handle, (void*)&(mc_handles[0]), CU_MEM_HANDLE_TYPE_FABRIC));
     }
   }
 
-  // Phase 4: Bind memory
   // All rank adds their physical allocation to the multicast object
-  C10_CUDA_DRIVER_CHECK_GOTO(
-      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx), check_all);
-  C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMulticastBindMem_(
-      mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0), check_all);
-
-  success_end = true;
-
-  check_all:
-  // Whether all ranks have succeeded
-  bool all_succeed = true;
-  auto rank_successes = storeExchange.all_gather(store, rank, world_size, success_end);
-  for (int r = 0; r < world_size; ++r) {
-    all_succeed &= rank_successes[r];
-  }
-  // Close the file descriptor before exit
-  if constexpr (!use_fabric_handle) {
-    close(recv_handle);
-  }
-  if (!all_succeed) {
-    LOG(WARNING) << "Gracefully skipping multicast initialization.";
-    return;
-  }
+  C10_CUDA_DRIVER_CHECK(
+      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
+  C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
+      mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0));
 
-  // Phase 5: Map to virtual memory
   map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
+  storeExchange.barrier(store, rank, world_size);
 #endif
 }
 
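Note on the new rank-0 error path: the try/catch around C10_CUDA_DRIVER_CHECK is replaced by a direct inspection of the CUresult from cuMulticastCreate_, so a failure to create the multicast object turns into a logged warning plus a graceful skip instead of a thrown exception. Below is a minimal standalone sketch of that pattern against the raw CUDA driver API rather than PyTorch's driver_api wrapper; the device/context setup, property values, and the warn_and_skip helper are illustrative assumptions, not the actual code:

#include <cuda.h>
#include <cstdio>

// Hypothetical stand-in for LOG(WARNING): report the failure without throwing.
static void warn_and_skip(CUresult err) {
  const char* err_str = nullptr;
  // Fall back to a placeholder if cuGetErrorString itself fails.
  if (cuGetErrorString(err, &err_str) != CUDA_SUCCESS || err_str == nullptr) {
    err_str = "unknown cuda driver error";
  }
  std::fprintf(stderr,
               "cuMulticastCreate failed with: \"%s\". "
               "Skipping multicast initialization.\n",
               err_str);
}

int main() {
  cuInit(0);
  CUdevice dev;
  CUcontext ctx;
  cuDeviceGet(&dev, 0);
  cuDevicePrimaryCtxRetain(&ctx, dev);
  cuCtxSetCurrent(ctx);

  CUmulticastObjectProp mc_prop{};
  mc_prop.numDevices = 2;          // assumed world size
  mc_prop.size = 2 * 1024 * 1024;  // assumed; real code rounds up to cuMulticastGetGranularity
  mc_prop.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  CUmemGenericAllocationHandle mc_handle;
  CUresult err = cuMulticastCreate(&mc_handle, &mc_prop);
  if (err != CUDA_SUCCESS) {
    warn_and_skip(err);  // graceful skip instead of a throwing CHECK
    return 0;
  }
  // ... export the handle and continue as in the diff ...
  return 0;
}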
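The comment "Ref count is incremented as soon as SCM_RIGHTS send happens" is why rank 0 may close(mc_exported_handle) right after broadcast_fds: sendmsg() with SCM_RIGHTS installs a new reference to the open file description for the receiver, so the sender's descriptor is no longer needed. (The -1 failure sentinel presumably travels as ordinary payload, since an invalid descriptor cannot be passed through SCM_RIGHTS.) A sketch of that send-then-close pattern over a connected Unix-domain socket; send_fd is a hypothetical helper, not the actual broadcast_fds implementation:

#include <sys/socket.h>
#include <unistd.h>
#include <cstring>

// Hypothetical helper: transfer `fd` over connected Unix-domain socket `sock`.
static int send_fd(int sock, int fd) {
  char payload = 0;  // at least one byte of regular data must accompany the fd
  iovec iov{};
  iov.iov_base = &payload;
  iov.iov_len = sizeof(payload);

  char cbuf[CMSG_SPACE(sizeof(int))] = {};
  msghdr msg{};
  msg.msg_iov = &iov;
  msg.msg_iovlen = 1;
  msg.msg_control = cbuf;
  msg.msg_controllen = sizeof(cbuf);

  cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
  cmsg->cmsg_level = SOL_SOCKET;
  cmsg->cmsg_type = SCM_RIGHTS;
  cmsg->cmsg_len = CMSG_LEN(sizeof(int));
  std::memcpy(CMSG_DATA(cmsg), &fd, sizeof(int));

  if (sendmsg(sock, &msg, 0) < 0) {
    return -1;
  }
  // The kernel has already duplicated the reference for the receiver, so the
  // sender's copy can be closed right away, mirroring the close() that follows
  // broadcast_fds in the diff.
  close(fd);
  return 0;
}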
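map_block itself is not shown in this diff. Assuming it follows the usual CUDA virtual-memory-management sequence (reserve an address range, map the handle, grant device access), a hedged sketch of what mapping the bound multicast handle might look like; map_multicast_va is an illustrative name, not the actual helper:

#include <cuda.h>

static CUresult map_multicast_va(CUdeviceptr* out_ptr,
                                 CUmemGenericAllocationHandle mc_handle,
                                 size_t size,
                                 int device_idx) {
  // Reserve a virtual address range large enough for the block.
  CUdeviceptr ptr = 0;
  CUresult err = cuMemAddressReserve(&ptr, size, /*alignment=*/0, /*addr=*/0, /*flags=*/0);
  if (err != CUDA_SUCCESS) {
    return err;
  }
  // Back the reserved range with the multicast handle.
  err = cuMemMap(ptr, size, /*offset=*/0, mc_handle, /*flags=*/0);
  if (err != CUDA_SUCCESS) {
    return err;
  }
  // Grant the local device read/write access to the new mapping.
  CUmemAccessDesc desc{};
  desc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  desc.location.id = device_idx;
  desc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  err = cuMemSetAccess(ptr, size, &desc, /*count=*/1);
  if (err != CUDA_SUCCESS) {
    return err;
  }
  *out_ptr = ptr;
  return CUDA_SUCCESS;
}

The storeExchange.barrier added at the end of the function presumably keeps any rank from touching mc_addr before every peer has added its device and bound its allocation to the multicast object.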