diff --git a/comms/torchcomms/nccl/CMakeLists.txt b/comms/torchcomms/nccl/CMakeLists.txt index 37af2f55..9a82b4fa 100644 --- a/comms/torchcomms/nccl/CMakeLists.txt +++ b/comms/torchcomms/nccl/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(torchcomms_comms_nccl MODULE ${TORCHCOMMS_NCCL_SOURCES} ${TORCHCOMMS_CUDA_API_SOURCE} ) +target_compile_definitions(torchcomms_comms_nccl PRIVATE TORCHCOMMS_CONDA_BUILD) set_target_properties(torchcomms_comms_nccl PROPERTIES PREFIX "" OUTPUT_NAME "_comms_nccl" diff --git a/comms/torchcomms/nccl/TorchCommNCCLCCA.cpp b/comms/torchcomms/nccl/TorchCommNCCLCCA.cpp index 5b158656..90df8cd8 100644 --- a/comms/torchcomms/nccl/TorchCommNCCLCCA.cpp +++ b/comms/torchcomms/nccl/TorchCommNCCLCCA.cpp @@ -38,7 +38,11 @@ void CachingAllocatorHookImpl::registerMemPreHook() { int device = c10::cuda::current_device(); // We assume no mem pool and no comm has been created yet, we just loop up the // snapshot of the default pool for the current device. +#ifdef TORCHCOMMS_CONDA_BUILD + auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); +#else auto snapshot = c10::cuda::CUDACachingAllocator::snapshot({device, 0}); +#endif for (const auto& segmentInfo : snapshot.segments) { // NOLINTNEXTLINE(performance-no-int-to-ptr) void* addr = reinterpret_cast<void*>(segmentInfo.address); diff --git a/comms/torchcomms/ncclx/CMakeLists.txt b/comms/torchcomms/ncclx/CMakeLists.txt index 069f3ba8..f3f67820 100644 --- a/comms/torchcomms/ncclx/CMakeLists.txt +++ b/comms/torchcomms/ncclx/CMakeLists.txt @@ -61,7 +61,11 @@ add_library(torchcomms_comms_ncclx MODULE ${TORCHCOMMS_NCCLX_SOURCES} ${TORCHCOMMS_CUDA_API_SOURCE} ) -target_compile_definitions(torchcomms_comms_ncclx PRIVATE MOCK_SCUBA_DATA CTRAN_DISABLE_TCPDM) +target_compile_definitions(torchcomms_comms_ncclx PRIVATE + MOCK_SCUBA_DATA + CTRAN_DISABLE_TCPDM + TORCHCOMMS_CONDA_BUILD +) set_target_properties(torchcomms_comms_ncclx PROPERTIES PREFIX "" OUTPUT_NAME "_comms_ncclx" diff --git 
a/comms/torchcomms/ncclx/TorchCommNCCLXCCA.cpp b/comms/torchcomms/ncclx/TorchCommNCCLXCCA.cpp index f832d2ef..9c6c19d8 100644 --- a/comms/torchcomms/ncclx/TorchCommNCCLXCCA.cpp +++ b/comms/torchcomms/ncclx/TorchCommNCCLXCCA.cpp @@ -38,7 +38,11 @@ void CachingAllocatorHookImpl::registerMemPreHook() { int device = c10::cuda::current_device(); // We assume no mem pool and no comm has been created yet, we just loop up the // snapshot of the default pool for the current device. +#ifdef TORCHCOMMS_CONDA_BUILD + auto snapshot = c10::cuda::CUDACachingAllocator::snapshot(); +#else auto snapshot = c10::cuda::CUDACachingAllocator::snapshot({device, 0}); +#endif for (const auto& segmentInfo : snapshot.segments) { // NOLINTNEXTLINE(performance-no-int-to-ptr) void* addr = reinterpret_cast<void*>(segmentInfo.address);