1 change: 1 addition & 0 deletions src/include/registered_memory.hpp
@@ -55,6 +55,7 @@ struct RegisteredMemory::Impl {
bool isCuMemMapAlloc;
TransportFlags transports;
std::vector<TransportInfo> transportInfos;
std::shared_ptr<void> peerHandle;

// Only used for IB transport
std::unordered_map<Transport, std::unique_ptr<const IbMr>> ibMrMap;
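Context for the new member: peerHandle lets RegisteredMemory::Impl share ownership of an opened peer IPC mapping instead of closing it by hand in the destructor (see the registered_memory.cc changes below). A minimal standalone sketch of that ownership pattern, with hypothetical openMapping/closeMapping stand-ins for the CUDA IPC calls:

#include <cstdlib>
#include <memory>

// Stand-ins for cudaIpcOpenMemHandle / cudaIpcCloseMemHandle (hypothetical).
void* openMapping() { return std::malloc(16); }
void closeMapping(void* p) { std::free(p); }

struct Impl {
  std::shared_ptr<void> peerHandle;  // releases the mapping when the last owner disappears
};

int main() {
  Impl a;
  a.peerHandle = std::shared_ptr<void>(openMapping(), closeMapping);
  Impl b = a;            // both Impls now share the same mapping
  a.peerHandle.reset();  // mapping stays open: b still references it
  b.peerHandle.reset();  // last owner gone -> closeMapping runs exactly once
  return 0;
}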
64 changes: 55 additions & 9 deletions src/registered_memory.cc
@@ -7,6 +7,7 @@
#include <unistd.h>

#include <algorithm>
#include <cstring>
#include <mscclpp/gpu_utils.hpp>

#include "api.h"
@@ -29,6 +30,21 @@
} while (false)

namespace {

// Custom hash and equality for cudaIpcMemHandle_t
struct CudaIpcMemHandleHash {
size_t operator()(const cudaIpcMemHandle_t& handle) const {
std::string_view view(handle.reserved, sizeof(handle.reserved));
return std::hash<std::string_view>{}(view);
}
};

struct CudaIpcMemHandleEqual {
bool operator()(const cudaIpcMemHandle_t& lhs, const cudaIpcMemHandle_t& rhs) const noexcept {
return std::memcmp(lhs.reserved, rhs.reserved, sizeof(lhs.reserved)) == 0;
}
};

CUmemAllocationHandleType getNvlsMemHandleType() {
#if (CUDA_NVLS_API_AVAILABLE)
if (mscclpp::detail::nvlsCompatibleMemHandleType & CU_MEM_HANDLE_TYPE_FABRIC) {
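The two functors above exist because cudaIpcMemHandle_t is an opaque byte blob with neither a std::hash specialization nor an operator==, both of which std::unordered_map requires for its key type. A standalone sketch of how such functors plug into an unordered_map, using a hypothetical FakeIpcHandle so it builds without the CUDA headers:

#include <cstring>
#include <functional>
#include <string_view>
#include <unordered_map>

// Stand-in for cudaIpcMemHandle_t: an opaque, fixed-size byte blob.
struct FakeIpcHandle { char reserved[64]; };

struct FakeIpcHandleHash {
  size_t operator()(const FakeIpcHandle& h) const {
    return std::hash<std::string_view>{}(std::string_view(h.reserved, sizeof(h.reserved)));
  }
};

struct FakeIpcHandleEqual {
  bool operator()(const FakeIpcHandle& a, const FakeIpcHandle& b) const noexcept {
    return std::memcmp(a.reserved, b.reserved, sizeof(a.reserved)) == 0;
  }
};

int main() {
  // Hash and equality are supplied explicitly because the key type provides neither.
  std::unordered_map<FakeIpcHandle, int, FakeIpcHandleHash, FakeIpcHandleEqual> m;
  FakeIpcHandle h{};
  m[h] = 42;
  return m.at(h) == 42 ? 0 : 1;
}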
@@ -41,6 +57,43 @@ CUmemAllocationHandleType getNvlsMemHandleType() {
#endif
}

std::shared_ptr<void> getPeerMemoryHandle(cudaIpcMemHandle_t ipcHandle) {
void* addr;
auto deleter = [](void* p) {
cudaError_t err = cudaIpcCloseMemHandle(p);
if (err != cudaSuccess) {
WARN("Failed to close CUDA IPC handle at pointer %p: %s", p, cudaGetErrorString(err));
} else {
INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", p);
}
};
#if defined(__HIP_PLATFORM_AMD__)
static std::unordered_map<cudaIpcMemHandle_t, std::weak_ptr<void>, CudaIpcMemHandleHash, CudaIpcMemHandleEqual>
peerMemoryHandleMap;
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
auto it = peerMemoryHandleMap.find(ipcHandle);
if (it != peerMemoryHandleMap.end()) {
if (auto ptr = it->second.lock()) {
return ptr;
} else {
peerMemoryHandleMap.erase(it);
}
}
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&addr, ipcHandle, cudaIpcMemLazyEnablePeerAccess));
std::shared_ptr<void> ptr = std::shared_ptr<void>(addr, [ipcHandle, deleter](void* p) {
deleter(p);
std::lock_guard<std::mutex> lock(mutex);
peerMemoryHandleMap.erase(ipcHandle);
});
peerMemoryHandleMap[ipcHandle] = ptr;
return ptr;
#else
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&addr, ipcHandle, cudaIpcMemLazyEnablePeerAccess));
return std::shared_ptr<void>(addr, deleter);
#endif
}

} // namespace

namespace mscclpp {
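The HIP-only branch of getPeerMemoryHandle above caches mappings in a static map of weak_ptrs, so deserializing the same IPC handle twice reuses one mapping, and the custom deleter both closes the handle and evicts the cache entry once the last owner is gone. A stripped-down sketch of that weak_ptr cache, with int keys and fakeOpen/fakeClose standing in for the CUDA IPC calls (all names here are illustrative, not part of this PR):

#include <cassert>
#include <memory>
#include <mutex>
#include <unordered_map>

// Stand-ins for cudaIpcOpenMemHandle / cudaIpcCloseMemHandle.
static int gOpenCount = 0;
void* fakeOpen() { ++gOpenCount; return new int(0); }
void fakeClose(void* p) { --gOpenCount; delete static_cast<int*>(p); }

std::shared_ptr<void> getMapping(int key) {
  static std::unordered_map<int, std::weak_ptr<void>> cache;
  static std::mutex mtx;
  std::lock_guard<std::mutex> lock(mtx);
  if (auto it = cache.find(key); it != cache.end()) {
    if (auto p = it->second.lock()) return p;  // still alive: reuse the mapping
    cache.erase(it);                           // expired: drop the stale entry
  }
  auto p = std::shared_ptr<void>(fakeOpen(), [key](void* q) {
    fakeClose(q);
    std::lock_guard<std::mutex> lock(mtx);  // remove ourselves once the last owner is gone
    cache.erase(key);
  });
  cache[key] = p;
  return p;
}

int main() {
  auto a = getMapping(7);
  auto b = getMapping(7);   // same key -> same underlying mapping, no second open
  assert(a.get() == b.get() && gOpenCount == 1);
  a.reset();
  b.reset();                // last owner gone -> fakeClose runs, cache entry erased
  assert(gOpenCount == 0);
  return 0;
}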
@@ -256,8 +309,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>::const_iterator& begin,
throw Error("Unexpected error", ErrorCode::InternalError);
#endif // !(CUDA_NVLS_API_AVAILABLE)
} else if (getHostHash() == this->hostHash) {
MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
this->data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
this->peerHandle = getPeerMemoryHandle(entry.cudaIpcBaseHandle);
this->data = static_cast<char*>(this->peerHandle.get()) + entry.cudaIpcOffsetFromBase;
}
}
if (this->data != nullptr) {
@@ -291,13 +344,6 @@ RegisteredMemory::Impl::~Impl() {
MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size));
MSCCLPP_CULOG_WARN(cuMemRelease(handle));
MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size));
} else {
cudaError_t err = cudaIpcCloseMemHandle(base);
if (err != cudaSuccess) {
WARN("Failed to close CUDA IPC handle at pointer %p: %s", base, cudaGetErrorString(err));
} else {
INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base);
}
}
data = nullptr;
fileDesc = -1;
1 change: 0 additions & 1 deletion test/mp_unit/executor_tests.cc
@@ -49,7 +49,6 @@ void ExecutorTest::TearDown() {

TEST_F(ExecutorTest, TwoNodesAllreduce) {
if (gEnv->worldSize != 2 || gEnv->nRanksPerNode != 2) {
GTEST_SKIP() << "This test requires world size to be 2 and ranks per node to be 2";
return;
}
std::string executablePath = getExecutablePath();