2 changes: 1 addition & 1 deletion python/mscclpp/utils.py
@@ -123,7 +123,7 @@ def _compile_cuda(self, source_file, output_file, std_version="c++17"):
         if self.macros:
             command += self.macros
         try:
-            subprocess.run(command, capture_output=True, text=True, check=True, bufsize=1)
+            subprocess.run(command, capture_output=True, text=True, check=True, bufsize=1, stdin=subprocess.DEVNULL)
             with open(f"{self._tempdir.name}/{output_file}", "rb") as f:
                 return f.read()
         except subprocess.CalledProcessError as e:
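Note on the change above: passing stdin=subprocess.DEVNULL keeps the compiler child process from inheriting the test runner's stdin, presumably to avoid the bad-fd issues noted in conftest.py below, since an MPI launcher may close or repurpose fd 0 for non-root ranks. A minimal sketch of the same pattern; the echo command is a placeholder standing in for the nvcc invocation that _compile_cuda assembles:

# Sketch only: ["echo", "compiled"] stands in for the real compiler command.
import subprocess


def run_tool(command):
    # stdin=DEVNULL prevents the child from inheriting fd 0, which may be
    # invalid in processes spawned by an MPI launcher.
    result = subprocess.run(
        command,
        capture_output=True,
        text=True,
        check=True,
        stdin=subprocess.DEVNULL,
    )
    return result.stdout


print(run_tool(["echo", "compiled"]))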
27 changes: 27 additions & 0 deletions python/test/conftest.py
@@ -0,0 +1,27 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import mpi4py
+import os
+import sys
+
+mpi4py.rc.initialize = False
+mpi4py.rc.finalize = True
+
+import cupy as cp
+from mpi4py import MPI
+
+
+def pytest_configure(config):
+    """Initialize MPI before test collection."""
+    if not MPI.Is_initialized():
+        MPI.Init()
+    shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
+    N_GPUS_PER_NODE = shm_comm.size
+    shm_comm.Free()
+    cp.cuda.Device(MPI.COMM_WORLD.rank % N_GPUS_PER_NODE).use()
+
+    # only print from rank 0 to avoid bad fd issues
+    if MPI.COMM_WORLD.rank != 0:
+        sys.stdout = open(os.devnull, "w")
+        sys.stderr = open(os.devnull, "w")
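Because pytest_configure runs before test collection, every test in this directory can assume MPI is live and the rank's GPU is already selected. A minimal smoke test under that assumption (the test name and payload are hypothetical, not part of this PR):

# Hypothetical smoke test; relies on the conftest.py above having initialized
# MPI and bound this rank to a GPU. Launch under MPI, e.g.:
#   mpirun -np 2 python -m pytest -q
import cupy as cp
from mpi4py import MPI


def test_mpi_and_gpu_ready():
    assert MPI.Is_initialized()
    x = cp.ones(4, dtype=cp.float32)  # allocated on the device chosen in conftest
    total = MPI.COMM_WORLD.allreduce(float(x.sum()))  # op defaults to SUM
    assert total == 4.0 * MPI.COMM_WORLD.size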
60 changes: 23 additions & 37 deletions python/test/mscclpp_mpi.py
@@ -1,40 +1,16 @@
 # Copyright (c) Microsoft Corporation.
-# Licensed under the MIT license.
+# Licensed under the MIT License.
 
-import atexit
 import logging
 
-import cupy as cp
 import mpi4py
 
 mpi4py.rc.initialize = False
 mpi4py.rc.finalize = False
 
 from mpi4py import MPI
 import pytest
 
-N_GPUS_PER_NODE = 8
-
 logging.basicConfig(level=logging.INFO)
 
 
-def init_mpi():
-    if not MPI.Is_initialized():
-        MPI.Init()
-    shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
-    N_GPUS_PER_NODE = shm_comm.size
-    shm_comm.Free()
-    cp.cuda.Device(MPI.COMM_WORLD.rank % N_GPUS_PER_NODE).use()
-
-
-# Define a function to finalize MPI
-def finalize_mpi():
-    if MPI.Is_initialized():
-        MPI.Finalize()
-
-
-# Register the function to be called on exit
-atexit.register(finalize_mpi)
+_mpi_group_cache = {}
 
 
 class MpiGroup:
@@ -46,13 +22,25 @@ def __init__(self, ranks: list = []):
             group = world_group.Incl(ranks)
             self.comm = MPI.COMM_WORLD.Create(group)
 
+    def __del__(self):
+        if self.comm != MPI.COMM_NULL and MPI.Is_initialized() and not MPI.Is_finalized():
+            self.comm.Free()
+
 
 @pytest.fixture
 def mpi_group(request: pytest.FixtureRequest):
-    MPI.COMM_WORLD.barrier()
-    if request.param is None:
-        pytest.skip(f"Skip for rank {MPI.COMM_WORLD.rank}")
-    yield request.param
+    mpi_group_obj = request.param
+    should_skip = mpi_group_obj.comm == MPI.COMM_NULL
+
+    try:
+        if should_skip:
+            pytest.skip(f"Skip for rank {MPI.COMM_WORLD.rank}")
+        yield request.param
+    finally:
+        if MPI.Is_initialized() and not MPI.Is_finalized():
+            MPI.COMM_WORLD.barrier()
 
 
 def parametrize_mpi_groups(*tuples: tuple):
@@ -62,14 +50,12 @@ def decorator(func):
             if MPI.COMM_WORLD.size < group_size:
                 logging.warning(f"MPI.COMM_WORLD.size < {group_size}, skip")
                 continue
-            mpi_group = MpiGroup(list(range(group_size)))
-            if mpi_group.comm == MPI.COMM_NULL:
-                mpi_groups.append(None)
-            else:
-                mpi_groups.append(mpi_group)
+            ranks = list(range(group_size))
+            ranks_key = tuple(ranks)
+            if ranks_key not in _mpi_group_cache:
+                _mpi_group_cache[ranks_key] = MpiGroup(ranks)
+
+            mpi_groups.append(_mpi_group_cache[ranks_key])
         return pytest.mark.parametrize("mpi_group", mpi_groups, indirect=True)(func)
 
     return decorator
-
-
-init_mpi()
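With the cache above, each distinct rank set builds one MpiGroup that is reused across all parametrized tests, and the barrier now runs in the fixture's finally block so ranks stay synchronized even when a test raises or is skipped. An illustrative consumer of the decorator (the test body is hypothetical, not from this PR):

# Hypothetical usage of parametrize_mpi_groups / the mpi_group fixture above.
from mscclpp_mpi import MpiGroup, parametrize_mpi_groups


@parametrize_mpi_groups(2, 4)
def test_group_membership(mpi_group: MpiGroup):
    # Ranks whose communicator is COMM_NULL were skipped by the fixture;
    # members see a communicator of the requested size.
    assert mpi_group.comm.size in (2, 4)
    assert 0 <= mpi_group.comm.rank < mpi_group.comm.size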
4 changes: 3 additions & 1 deletion src/include/logger.hpp
@@ -21,7 +21,7 @@
 namespace mscclpp {
 
 typedef enum : unsigned int { NONE = 0, DEBUG, INFO, WARN, ERROR } LogLevel;
-typedef enum : std::size_t { ENV = 0, NET, CONN, EXEC, NCCL, COUNT } LogSubsys;
+typedef enum : std::size_t { ENV = 0, NET, CONN, P2P, EXEC, NCCL, COUNT } LogSubsys;
 
 namespace detail {
 
@@ -61,6 +61,8 @@ constexpr std::string_view logSubsysToString(LogSubsys subsys) {
       return "NET";
     case LogSubsys::CONN:
       return "CONN";
+    case LogSubsys::P2P:
+      return "P2P";
     case LogSubsys::EXEC:
       return "EXEC";
     case LogSubsys::NCCL:
1 change: 1 addition & 0 deletions src/include/registered_memory.hpp
@@ -55,6 +55,7 @@ struct RegisteredMemory::Impl {
   bool isCuMemMapAlloc;
   TransportFlags transports;
   std::vector<TransportInfo> transportInfos;
+  std::shared_ptr<void> peerMemHandle;
 
   // Only used for IB transport
   std::unordered_map<Transport, std::unique_ptr<const IbMr>> ibMrMap;
100 changes: 75 additions & 25 deletions src/registered_memory.cc
@@ -7,28 +7,45 @@
 #include <unistd.h>
 
 #include <algorithm>
+#include <cstring>
 #include <mscclpp/gpu_utils.hpp>
+#include <unordered_map>
 
 #include "api.h"
 #include "context.hpp"
-#include "debug.h"
+#include "logger.hpp"
 #include "serialization.hpp"
 #include "unix_socket.hpp"
 #include "utils_internal.hpp"
 
-#define MSCCLPP_CULOG_WARN(cmd)                               \
-  do {                                                        \
-    CUresult err = cmd;                                       \
-    if (err != CUDA_SUCCESS) {                                \
-      const char* errStr;                                     \
-      if (cuGetErrorString(err, &errStr) != CUDA_SUCCESS) {   \
-        errStr = "failed to get error string";                \
-      }                                                       \
-      WARN("Call to " #cmd " failed, error is %s", errStr);   \
-    }                                                         \
+#define MSCCLPP_CULOG_WARN(cmd)                                          \
+  do {                                                                   \
+    CUresult err = cmd;                                                  \
+    if (err != CUDA_SUCCESS) {                                           \
+      const char* errStr;                                                \
+      if (cuGetErrorString(err, &errStr) != CUDA_SUCCESS) {              \
+        errStr = "failed to get error string";                           \
+      }                                                                  \
+      WARN(mscclpp::P2P, "Call to " #cmd " failed, error is ", errStr);  \
+    }                                                                    \
 } while (false)
 
 namespace {
 
+// Custom hash and equality for cudaIpcMemHandle_t
+struct CudaIpcMemHandleHash {
+  size_t operator()(const cudaIpcMemHandle_t& handle) const {
+    std::string_view view(handle.reserved, sizeof(handle.reserved));
+    return std::hash<std::string_view>{}(view);
+  }
+};
+
+struct CudaIpcMemHandleEqual {
+  bool operator()(const cudaIpcMemHandle_t& lhs, const cudaIpcMemHandle_t& rhs) const noexcept {
+    return std::memcmp(lhs.reserved, rhs.reserved, sizeof(lhs.reserved)) == 0;
+  }
+};
+
 CUmemAllocationHandleType getNvlsMemHandleType() {
 #if (CUDA_NVLS_API_AVAILABLE)
   if (mscclpp::detail::nvlsCompatibleMemHandleType & CU_MEM_HANDLE_TYPE_FABRIC) {
@@ -41,6 +58,46 @@ CUmemAllocationHandleType getNvlsMemHandleType() {
 #endif
 }
 
+std::shared_ptr<void> getPeerMemoryHandle(cudaIpcMemHandle_t ipcHandle) {
+  void* addr;
+  auto deleter = [](void* p) {
+    cudaError_t err = cudaIpcCloseMemHandle(p);
+    if (err != cudaSuccess) {
+      WARN(mscclpp::P2P, "Failed to close CUDA IPC handle at pointer ", std::hex, p, ": ", cudaGetErrorString(err));
+    } else {
+      INFO(mscclpp::P2P, "Closed CUDA IPC handle at pointer ", std::hex, p);
+    }
+  };
+#if defined(__HIP_PLATFORM_AMD__)
+  // Unlike Nvidia, ROCm will not reuse the same ipc handle for the same memory region.
+  // We cache the opened ipc handles to avoid opening them multiple times. (May exceed system limit on
+  // vm.max_map_count)
+  static auto peerMemoryHandleMap = std::make_shared<
+      std::unordered_map<cudaIpcMemHandle_t, std::weak_ptr<void>, CudaIpcMemHandleHash, CudaIpcMemHandleEqual>>();
+  static auto mutex = std::make_shared<std::mutex>();
+  std::lock_guard<std::mutex> lock(*mutex);
+  auto it = peerMemoryHandleMap->find(ipcHandle);
+  if (it != peerMemoryHandleMap->end()) {
+    if (auto ptr = it->second.lock()) {
+      return ptr;
+    } else {
+      peerMemoryHandleMap->erase(it);
+    }
+  }
+  MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&addr, ipcHandle, cudaIpcMemLazyEnablePeerAccess));
+  std::shared_ptr<void> ptr =
+      std::shared_ptr<void>(addr, [ipcHandle, deleter, m = mutex, map = peerMemoryHandleMap](void* p) {
+        deleter(p);
+        std::lock_guard<std::mutex> lock(*m);
+        map->erase(ipcHandle);
+      });
+  peerMemoryHandleMap->emplace(ipcHandle, ptr);
+  return ptr;
+#else
+  MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&addr, ipcHandle, cudaIpcMemLazyEnablePeerAccess));
+  return std::shared_ptr<void>(addr, deleter);
+#endif
+}
+
 }  // namespace
 
 namespace mscclpp {
@@ -93,7 +150,7 @@ RegisteredMemory::Impl::Impl(void* data, size_t size, TransportFlags transports,
     transportInfo.ibLocal = true;
     transportInfo.ibMrInfo = this->ibMrMap[ibTransport]->getInfo();
     this->transportInfos.push_back(transportInfo);
-    INFO(MSCCLPP_NET, "IB mr for address %p with size %ld is registered", data, size);
+    INFO(NET, "IB mr for address ", data, " with size ", size, " is registered");
   };
   if (transports.has(Transport::IB0)) addIb(Transport::IB0);
   if (transports.has(Transport::IB1)) addIb(Transport::IB1);
@@ -227,8 +284,8 @@ RegisteredMemory::Impl::Impl(const std::vector<char>::const_iterator& begin,
       // TODO: only open handle if in same MNNVL domain
       CUresult err = cuMemImportFromShareableHandle(&handle, entry.shareableHandle, getNvlsMemHandleType());
       if (err != CUDA_SUCCESS) {
-        INFO(MSCCLPP_P2P, "Failed to import shareable handle from host: 0x%lx, may not be in the same MNNVL domain",
-             hostHash);
+        INFO(P2P, "Failed to import shareable handle from host: 0x", std::hex, hostHash,
+             ", may not be in the same MNNVL domain");
         return;
       }
     } else {
@@ -237,7 +294,7 @@
     } else {
      int fd =
          UnixSocketClient::instance().requestFd(UnixSocketServer::generateSocketPath(entry.rootPid), entry.rootFd);
-      INFO(MSCCLPP_P2P, "Get file descriptor %d from peer 0x%lx", fd, hostHash);
+      INFO(P2P, "Get file descriptor ", fd, " from peer 0x", std::hex, hostHash);
       MSCCLPP_CUTHROW(cuMemImportFromShareableHandle(&handle, reinterpret_cast<void*>(fd),
                                                      CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
       close(fd);
@@ -256,12 +313,12 @@
       throw Error("Unexpected error", ErrorCode::InternalError);
 #endif  // !(CUDA_NVLS_API_AVAILABLE)
     } else if (getHostHash() == this->hostHash) {
-      MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess));
-      this->data = static_cast<char*>(base) + entry.cudaIpcOffsetFromBase;
+      this->peerMemHandle = getPeerMemoryHandle(entry.cudaIpcBaseHandle);
+      this->data = static_cast<char*>(this->peerMemHandle.get()) + entry.cudaIpcOffsetFromBase;
     }
   }
   if (this->data != nullptr) {
-    INFO(MSCCLPP_P2P, "Opened CUDA IPC handle at pointer %p", this->data);
+    INFO(P2P, "Opened CUDA IPC handle at pointer ", this->data);
   }
 }
 
@@ -291,13 +348,6 @@ RegisteredMemory::Impl::~Impl() {
       MSCCLPP_CULOG_WARN(cuMemUnmap((CUdeviceptr)base, size));
      MSCCLPP_CULOG_WARN(cuMemRelease(handle));
       MSCCLPP_CULOG_WARN(cuMemAddressFree((CUdeviceptr)base, size));
-    } else {
-      cudaError_t err = cudaIpcCloseMemHandle(base);
-      if (err != cudaSuccess) {
-        WARN("Failed to close CUDA IPC handle at pointer %p: %s", base, cudaGetErrorString(err));
-      } else {
-        INFO(MSCCLPP_P2P, "Closed CUDA IPC handle at pointer %p", base);
-      }
     }
     data = nullptr;
     fileDesc = -1;
8 changes: 3 additions & 5 deletions test/mp_unit/executor_tests.cc
@@ -21,6 +21,9 @@ std::string getExecutablePath() {
 }  // namespace
 
 void ExecutorTest::SetUp() {
+  if (gEnv->worldSize != 2 || gEnv->nRanksPerNode != 2) {
+    GTEST_SKIP() << "This test requires world size to be 2 and ranks per node to be 2";
+  }
   MultiProcessTest::SetUp();
 
   MSCCLPP_CUDATHROW(cudaSetDevice(rankToLocalRank(gEnv->rank)));
@@ -43,15 +46,10 @@ void ExecutorTest::TearDown() {
     NpKit::Dump(npkitDumpDir);
     NpKit::Shutdown();
   }
+  executor.reset();
   MultiProcessTest::TearDown();
 }
 
 TEST_F(ExecutorTest, TwoNodesAllreduce) {
-  if (gEnv->worldSize != 2 || gEnv->nRanksPerNode != 2) {
-    GTEST_SKIP() << "This test requires world size to be 2 and ranks per node to be 2";
-    return;
-  }
   std::string executablePath = getExecutablePath();
   std::filesystem::path path = executablePath;
  std::filesystem::path executionFilesPath =