Commit 8cb8b6c

kwen2501 authored and pytorchmergebot committed
[SymmMem] Skip multicast init if any CUDA call fails (pytorch#168049)
Pull Request resolved: pytorch#168049
Approved by: https://github.com/fduwjj
1 parent 2b92b31 commit 8cb8b6c

File tree

2 files changed: +83 additions, -48 deletions

c10/cuda/driver_api.h
torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu


c10/cuda/driver_api.h

Lines changed: 16 additions & 0 deletions
@@ -20,6 +20,22 @@
   }                                                                    \
 } while (0)
 
+#define C10_CUDA_DRIVER_CHECK_GOTO(EXPR, NEXT)                          \
+  do {                                                                  \
+    CUresult __err = EXPR;                                              \
+    if (__err != CUDA_SUCCESS) {                                        \
+      const char* err_str;                                              \
+      CUresult get_error_str_err [[maybe_unused]] =                     \
+          c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \
+      if (get_error_str_err != CUDA_SUCCESS) {                          \
+        TORCH_WARN("CUDA driver error: unknown error");                 \
+      } else {                                                          \
+        TORCH_WARN("CUDA driver error: ", err_str);                     \
+      }                                                                 \
+      goto NEXT;                                                        \
+    }                                                                   \
+  } while (0)
+
 // The integer in the second column specifies the requested CUDA Driver API
 // version. The dynamic loader will accept a driver with a newer version, but it
 // ensures that the requested symbol exists in *at least* the specified version
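
Usage note (illustrative, not part of the diff): C10_CUDA_DRIVER_CHECK_GOTO mirrors the existing C10_CUDA_DRIVER_CHECK, but on failure it emits a TORCH_WARN and jumps to a caller-supplied label instead of throwing, so the surrounding function can fall through to shared cleanup or agreement code. Below is a minimal sketch of that pattern, assuming a CUDA build where c10/cuda/driver_api.h is available; the function try_add_device and the cleanup label are hypothetical names, not from this commit.

// Sketch only: try_add_device and the `cleanup` label are illustrative names.
// Requires a CUDA build where c10/cuda/driver_api.h and the multicast driver
// wrappers are available.
#include <c10/cuda/driver_api.h>

static bool try_add_device(CUmemGenericAllocationHandle mc_handle, int device_idx) {
  bool ok = false;
  auto* driver_api = c10::cuda::DriverAPI::get();
  // On failure this warns and jumps to `cleanup` instead of throwing, so the
  // caller can still take part in any collective "did everyone succeed?" step.
  C10_CUDA_DRIVER_CHECK_GOTO(
      driver_api->cuMulticastAddDevice_(mc_handle, device_idx), cleanup);
  ok = true;  // reached only if the driver call returned CUDA_SUCCESS
cleanup:
  return ok;
}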

torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu

Lines changed: 67 additions & 48 deletions
@@ -517,6 +517,11 @@ static void init_multicast_for_block(
   using McHandleType =
       std::conditional_t<use_fabric_handle, CUmemFabricHandle, int>;
 
+  McHandleType invalidator;
+  std::memset(&invalidator, UINT8_MAX, sizeof(McHandleType));
+
+  // Phase 1: export handle (rank 0 only)
+  McHandleType mc_exported_handle{};
   if (rank == 0) {
     CUmulticastObjectProp mc_prop{};
     mc_prop.numDevices = world_size;
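
Aside (standalone sketch, not from the diff): the invalidator above is an all-0xFF sentinel of the handle type; rank 0 sends it when export fails, and peers detect it with a byte-wise compare before importing. The FakeFabricHandle type below is a hypothetical stand-in for CUmemFabricHandle or a POSIX fd.

// Standalone illustration of the sentinel pattern; FakeFabricHandle is a
// made-up stand-in and not a CUDA type.
#include <cstdint>
#include <cstring>
#include <iostream>

struct FakeFabricHandle {
  uint8_t data[64];
};

int main() {
  FakeFabricHandle invalidator;
  // All-0xFF bytes are assumed never to occur in a valid exported handle.
  std::memset(&invalidator, UINT8_MAX, sizeof(FakeFabricHandle));

  FakeFabricHandle received = invalidator;  // pretend rank 0 broadcast this
  if (std::memcmp(&received, &invalidator, sizeof(FakeFabricHandle)) == 0) {
    std::cout << "sentinel received: skip multicast initialization\n";
  } else {
    std::cout << "valid handle received: proceed with import\n";
  }
  return 0;
}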
@@ -525,68 +530,82 @@
 
     // create a multicast object, which acts as a handle that allows multiple
     // devices or processes to access the same memory allocation coherently.
-    auto err = driver_api->cuMulticastCreate_(&mc_handle, &mc_prop);
-    if (err != CUDA_SUCCESS) {
-      const char* err_str;
-      CUresult get_error_str_err = driver_api->cuGetErrorString_(err, &err_str);
-      if (get_error_str_err != CUDA_SUCCESS) {
-        err_str = "unknown cuda driver error";
-      }
-      LOG(WARNING)
-          << "SymmetricMemory: cuMulticastCreate failed with: \"" << err_str
-          << "\". Gracefully skipping multicast initialization. "
-          << "However, this is unexpected. Please report the issue on GitHub.";
+    try {
+      C10_CUDA_DRIVER_CHECK(
+          driver_api->cuMulticastCreate_(&mc_handle, &mc_prop));
+      // using the CUDA Driver API to export a multicast object into a POSIX file
+      // descriptor.
+      C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
+          &mc_exported_handle, mc_handle, handleType, 0));
+    } catch (const std::exception& e) {
       // Allow peers gracefully skip multicast initialization by sending -1
-      // TODO: allow graceful skip for fabric
-      if constexpr (!use_fabric_handle) {
-        ipc_channel.broadcast_fds(rank, 0, pids, -1);
-      }
-      return;
-    }
-
-    McHandleType mc_exported_handle;
-    // using the CUDA Driver API to export a multicast object into a POSIX file
-    // descriptor.
-    C10_CUDA_DRIVER_CHECK(driver_api->cuMemExportToShareableHandle_(
-        &mc_exported_handle, mc_handle, handleType, 0));
-    if constexpr (!use_fabric_handle) {
-      ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle);
-      // Ref count is incremented as soon as SCM_RIGHTS send happens
-      close(mc_exported_handle);
-    } else {
-      // TODO implement storeExchange.broadcast
-      storeExchange.all_gather(store, rank, world_size, mc_exported_handle);
+      mc_exported_handle = invalidator;
+      LOG(WARNING)
+          << "SymmetricMemory: fail to export multicast handle.\n"
+          << e.what();
     }
+  }
 
+  // Phase 2: Exchange handle
+  McHandleType recv_handle;
+  if constexpr (!use_fabric_handle) {
+    recv_handle = ipc_channel.broadcast_fds(rank, 0, pids, mc_exported_handle);
   } else {
+    // TODO implement storeExchange.broadcast
+    auto gathered_handles = storeExchange.all_gather(store, rank, world_size, mc_exported_handle);
+    recv_handle = std::move(gathered_handles[0]);
+  }
+
+  // Check exchange result
+  if (memcmp(&recv_handle, &invalidator, sizeof(McHandleType)) == 0) {
+    LOG(WARNING) << "Gracefully skipping multicast initialization.";
+    return;
+  }
+
+  // Flip to true after all CUDA steps finish
+  bool success_end = false;
+
+  // Phase 3: Import handle (non-0 ranks only)
+  if (rank != 0) {
     if constexpr (!use_fabric_handle) {
-      int mc_fd = ipc_channel.broadcast_fds(rank, 0, pids, -1);
-      if (mc_fd == -1) {
-        return;
-      }
       // Convert back to a handle from the broadcasted POSIX file descriptor.
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
+      C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_(
           &mc_handle,
-          (void*)(uintptr_t)mc_fd,
-          CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
-      close(mc_fd);
+          (void*)(uintptr_t)recv_handle,
+          CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), check_all);
     } else {
-      CUmemFabricHandle null_handle{};
-      auto mc_handles =
-          storeExchange.all_gather(store, rank, world_size, null_handle);
-      C10_CUDA_DRIVER_CHECK(driver_api->cuMemImportFromShareableHandle_(
-          &mc_handle, (void*)&(mc_handles[0]), CU_MEM_HANDLE_TYPE_FABRIC));
+      C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMemImportFromShareableHandle_(
+          &mc_handle, (void*)&(recv_handle), CU_MEM_HANDLE_TYPE_FABRIC), check_all);
     }
   }
 
+  // Phase 4: Bind memory
   // All rank adds their physical allocation to the multicast object
-  C10_CUDA_DRIVER_CHECK(
-      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx));
-  C10_CUDA_DRIVER_CHECK(driver_api->cuMulticastBindMem_(
-      mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0));
+  C10_CUDA_DRIVER_CHECK_GOTO(
+      driver_api->cuMulticastAddDevice_(mc_handle, block->device_idx), check_all);
+  C10_CUDA_DRIVER_CHECK_GOTO(driver_api->cuMulticastBindMem_(
+      mc_handle, 0, block->alloc_ref->handle, 0, block->block_size, 0), check_all);
+
+  success_end = true;
 
+check_all:
+  // Whether all ranks have succeeded
+  bool all_succeed = true;
+  auto rank_successes = storeExchange.all_gather(store, rank, world_size, success_end);
+  for (int r = 0; r < world_size; ++r) {
+    all_succeed &= rank_successes[r];
+  }
+  // Close the file descriptor before exit
+  if constexpr (!use_fabric_handle) {
+    close(recv_handle);
+  }
+  if (!all_succeed) {
+    LOG(WARNING) << "Gracefully skipping multicast initialization.";
+    return;
+  }
+
+  // Phase 5: Map to virtual memory
   map_block(&mc_addr, mc_handle, block->block_size, block->device_idx);
-  storeExchange.barrier(store, rank, world_size);
 #endif
 }
