Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:

jobs:
cpplint:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04

steps:
- name: Check out Git repository
Expand All @@ -24,7 +24,7 @@ jobs:
clang-format -style=file --verbose --Werror --dry-run ${CPPSOURCES}

pylint:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04

steps:
- name: Check out Git repository
Expand All @@ -42,7 +42,7 @@ jobs:
run: python3 -m black --check --config pyproject.toml .

spelling:
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04

steps:
- name: Check out Git repository
Expand Down
11 changes: 11 additions & 0 deletions include/mscclpp/gpu_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMe
/// @return True if NVLink SHARP (NVLS) is supported, false otherwise.
bool isNvlsSupported();

/// Check if ptr is allocated by cuMemMap.
/// @param ptr The pointer to check.
/// @return True if the pointer is allocated by cuMemMap, false otherwise.
bool isCuMemMapAllocated([[maybe_unused]] void* ptr);

/// Allocates a GPU memory space specialized for communication. The memory is zeroed out. Get the device pointer by
/// `GpuBuffer::data()`.
///
Expand All @@ -224,6 +229,7 @@ class GpuBuffer {
bytes_ = 0;
return;
}
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId_));
#if (CUDA_NVLS_SUPPORTED)
if (isNvlsSupported()) {
size_t gran = detail::getMulticastGranularity(nelems * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED);
Expand Down Expand Up @@ -259,9 +265,14 @@ class GpuBuffer {
/// @return A device pointer to the allocated memory.
T* data() { return memory_.get(); }

/// Returns the device id of the allocated memory.
/// @return The device id.
int deviceId() const { return deviceId_; }

private:
size_t nelems_;
size_t bytes_;
int deviceId_;
std::shared_ptr<T> memory_;
};

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ version = "0.6.0"
[tool.scikit-build]
cmake.version = ">=3.25.0"
cmake.build-type = "Release"
# for dlpack issue: https://github.com/microsoft/vcpkg/pull/44679
cmake.args = ["-DCMAKE_POLICY_VERSION_MINIMUM=3.5"]
build-dir = "build/{wheel_tag}"
wheel.packages = ["python/mscclpp", "python/mscclpp_benchmark"]
wheel.install-dir = "mscclpp"
Expand Down
5 changes: 4 additions & 1 deletion python/mscclpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ include(FetchContent)
FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
FetchContent_MakeAvailable(nanobind)

FetchContent_Declare(dlpack GIT_REPOSITORY https://github.com/dmlc/dlpack.git GIT_TAG v1.1)
FetchContent_MakeAvailable(dlpack)

file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
nanobind_add_module(mscclpp_py ${SOURCES})
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
target_link_libraries(mscclpp_py PRIVATE mscclpp_static ${GPU_LIBRARIES})
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp_static ${GPU_LIBRARIES})
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
install(TARGETS mscclpp_py LIBRARY DESTINATION .)
100 changes: 99 additions & 1 deletion python/mscclpp/gpu_utils_py.cpp
Original file line number Diff line number Diff line change
@@ -1,18 +1,116 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <dlpack/dlpack.h>
#include <nanobind/nanobind.h>
#include <nanobind/stl/shared_ptr.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

#include <mscclpp/gpu_data_types.hpp>
#include <mscclpp/gpu_utils.hpp>

namespace nb = nanobind;
using namespace mscclpp;

constexpr int BYTE_BITS = 8;

/// Returns the DLPack device type matching the GPU platform this module was
/// compiled for (ROCm for AMD HIP builds, CUDA otherwise).
static DLDeviceType getDeviceType() {
#if defined(__HIP_PLATFORM_AMD__)
  // AMD HIP build: report a ROCm device to DLPack consumers.
  constexpr DLDeviceType platformDevice = kDLROCM;
#else
  // Default NVIDIA build: report a CUDA device.
  constexpr DLDeviceType platformDevice = kDLCUDA;
#endif
  return platformDevice;
}

/// Translates a torch dtype name into the corresponding DLPack data type.
/// @param type A torch dtype string (e.g. "torch.float16") as produced by the
///             Python caller; presumably `str(tensor.dtype)` — TODO confirm.
/// @return The matching DLDataType (always 1 lane).
/// @throws Error with ErrorCode::InvalidUsage if the name is not supported.
static DLDataType getDlType(const std::string& type) {
  // Taken by const reference: the string is only read, so copying it per call
  // (as the original by-value parameter did) is wasted work.
  if (type == "torch.float") {
    return DLDataType{kDLFloat, 32, 1};
  } else if (type == "torch.int") {
    return DLDataType{kDLInt, 32, 1};
  } else if (type == "torch.uint32") {
    return DLDataType{kDLUInt, 32, 1};
  } else if (type == "torch.bfloat16") {
    return DLDataType{kDLBfloat, 16, 1};
  } else if (type == "torch.float16") {
    return DLDataType{kDLFloat, 16, 1};
  } else {
    throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
  }
}

/// Wraps a GpuBuffer in a DLPack "dltensor" capsule so Python frameworks can
/// adopt the GPU memory without copying it.
/// @param buffer The buffer to expose; a heap-allocated copy is stored in the
///               managed tensor's manager_ctx to keep the memory alive until
///               the consumer runs the deleter.
/// @param dataType A torch dtype name understood by getDlType.
/// @param shape Tensor shape; if empty, a 1-D shape covering the whole buffer
///              is derived from the dtype's element size.
/// @param strides Tensor strides; if empty, strides is left null (DLPack's
///                convention for a compact row-major tensor).
/// @return A Python capsule owning a DLManagedTensor.
static nb::capsule toDlpack(GpuBuffer<char> buffer, std::string dataType, std::vector<int64_t>& shape,
                            std::vector<int64_t>& strides) {
  DLDataType dtype = getDlType(dataType);
  // DLPack requires ndim >= 1, so even an empty shape allocates one dimension.
  int64_t* tensorShape = shape.size() > 0 ? new int64_t[shape.size()] : new int64_t[1];
  int64_t* tensorStrides = strides.size() > 0 ? new int64_t[strides.size()] : nullptr;
  if (shape.size() == 0) {
    // Element count = byte size / element size (bits rounded up to bytes).
    tensorShape[0] = (int64_t)(buffer.nelems() / ((dtype.bits * dtype.lanes + 7) / BYTE_BITS));
  } else {
    for (size_t i = 0; i < shape.size(); ++i) {
      tensorShape[i] = shape[i];
    }
  }
  for (size_t i = 0; i < strides.size(); ++i) {
    tensorStrides[i] = strides[i];
  }

  DLManagedTensor* dlManagedTensor = new DLManagedTensor();
  dlManagedTensor->dl_tensor.data = buffer.data();
  dlManagedTensor->dl_tensor.device.device_type = getDeviceType();
  dlManagedTensor->dl_tensor.device.device_id = buffer.deviceId();
  dlManagedTensor->dl_tensor.ndim = shape.size() == 0 ? 1 : shape.size();
  dlManagedTensor->dl_tensor.strides = tensorStrides;
  dlManagedTensor->dl_tensor.shape = tensorShape;
  dlManagedTensor->dl_tensor.byte_offset = 0;
  dlManagedTensor->dl_tensor.dtype = dtype;
  // This copy shares ownership of the GPU memory; deleting it in the deleter
  // releases the capsule's reference.
  dlManagedTensor->manager_ctx = new GpuBuffer<char>(buffer);
  dlManagedTensor->deleter = [](DLManagedTensor* self) {
    delete static_cast<GpuBuffer<char>*>(self->manager_ctx);
    self->manager_ctx = nullptr;
    self->dl_tensor.data = nullptr;
    // shape and strides are independently allocated, so free each on its own.
    // (Previously the strides delete was nested under the shape null-check,
    // which would leak strides whenever shape was null.)
    if (self->dl_tensor.shape != nullptr) {
      delete[] self->dl_tensor.shape;
      self->dl_tensor.shape = nullptr;
    }
    if (self->dl_tensor.strides != nullptr) {
      delete[] self->dl_tensor.strides;
      self->dl_tensor.strides = nullptr;
    }
    delete self;
  };

  PyObject* dlCapsule = PyCapsule_New(static_cast<void*>(dlManagedTensor), "dltensor", [](PyObject* capsule) {
    // Per the DLPack protocol, a consumer renames the capsule to
    // "used_dltensor" once it takes ownership; the consumer then becomes
    // responsible for invoking the deleter, not this destructor.
    if (PyCapsule_IsValid(capsule, "used_dltensor")) {
      return;
    }
    if (!PyCapsule_IsValid(capsule, "dltensor")) {
      return;
    }
    DLManagedTensor* managedTensor = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule, "dltensor"));
    if (managedTensor == nullptr) {
      return;
    }
    if (managedTensor->deleter) {
      managedTensor->deleter(managedTensor);
    }
  });
  return nb::steal<nb::capsule>(dlCapsule);
}

/// Registers GPU utility bindings on the Python module: the NVLS capability
/// query and the RawGpuBuffer wrapper (with DLPack export support).
/// @param m The nanobind module to populate.
void register_gpu_utils(nb::module_& m) {
  m.def("is_nvls_supported", &isNvlsSupported);

  // NOTE: the pasted diff showed both the old terminated `.def("data", ...);`
  // line and its replacement; only the un-terminated chained form belongs in
  // the merged source.
  nb::class_<GpuBuffer<char>>(m, "RawGpuBuffer")
      .def(nb::init<size_t>(), nb::arg("nelems"))
      .def("nelems", &GpuBuffer<char>::nelems)
      .def("bytes", &GpuBuffer<char>::bytes)
      // Expose the device pointer as an integer so Python can hand it to other
      // GPU libraries without ever dereferencing it on the host.
      .def("data", [](GpuBuffer<char>& self) { return reinterpret_cast<uintptr_t>(self.data()); })
      .def("device_id", &GpuBuffer<char>::deviceId)
      .def(
          "to_dlpack",
          [](GpuBuffer<char>& self, std::string dataType, std::vector<int64_t> shape, std::vector<int64_t> strides) {
            return toDlpack(self, dataType, shape, strides);
          },
          nb::arg("dataType"), nb::arg("shape") = std::vector<int64_t>(), nb::arg("strides") = std::vector<int64_t>());
}
2 changes: 1 addition & 1 deletion src/executor/execution_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
for (int i = 0; i < operation.nOutputs; i++) {
if (operation.channelType == mscclpp::ChannelType::NVLS) {
BufferType dstBufferType = convertToBufferType(op["dstbuff"]);
operation.nvlsInputIndex =
operation.nvlsOutputIndex =
channelIndexes[{dstBufferType, dstBufferType, ChannelType::NVLS}][op["o_cids"][0]["id"]];
chunkIndexes.push_back((uint32_t)op["dstoff"]);
} else {
Expand Down
17 changes: 17 additions & 0 deletions src/gpu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,21 @@ bool isNvlsSupported() {
return false;
}

bool isCuMemMapAllocated([[maybe_unused]] void* ptr) {
#if defined(__HIP_PLATFORM_AMD__)
return false;
#else
CUmemGenericAllocationHandle handle;
CUresult result = cuMemRetainAllocationHandle(&handle, ptr);
if (result != CUDA_SUCCESS) {
return false;
}
MSCCLPP_CUTHROW(cuMemRelease(handle));
if (!mscclpp::isNvlsSupported()) {
throw mscclpp::Error("cuMemMap is used in env without NVLS support", mscclpp::ErrorCode::InvalidUsage);
}
return true;
#endif
}

} // namespace mscclpp
5 changes: 5 additions & 0 deletions src/nvls.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ void NvlsConnection::Impl::freeBuffer(size_t offset, size_t size) noexcept {
}

std::shared_ptr<char> NvlsConnection::Impl::bindMemory(CUdeviceptr devicePtr, size_t devBuffSize) {
if (!isCuMemMapAllocated((void*)devicePtr)) {
throw Error("This NVLS connection tried to bind a buffer that was not allocated with cuMemMap",
ErrorCode::InvalidUsage);
}

devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_;
size_t offset = allocateBuffer(devBuffSize);
MSCCLPP_CUTHROW(cuMulticastBindAddr(mcHandle_, offset /*mcOffset*/, devicePtr, devBuffSize, 0));
Expand Down
18 changes: 0 additions & 18 deletions src/registered_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,6 @@ CUmemAllocationHandleType getNvlsMemHandleType() {
#endif
}

// Check if ptr is allocated by cuMemMap.
// Returns false on AMD builds, or when the driver cannot retain an allocation
// handle for ptr (i.e. it was not allocated via the cuMemMap APIs).
// Throws mscclpp::Error when ptr is cuMemMap-backed but NVLS is unsupported.
// NOTE(review): this file-local copy is removed by the PR in favor of the
// public mscclpp::isCuMemMapAllocated in gpu_utils.cc.
bool isCuMemMapAllocated([[maybe_unused]] void* ptr) {
#if defined(__HIP_PLATFORM_AMD__)
  return false;
#else
  CUmemGenericAllocationHandle handle;
  // A successful retain proves the allocation came from cuMemMap.
  CUresult result = cuMemRetainAllocationHandle(&handle, ptr);
  if (result != CUDA_SUCCESS) {
    return false;
  }
  // Release the reference taken by the retain call above.
  MSCCLPP_CUTHROW(cuMemRelease(handle));
  if (!mscclpp::isNvlsSupported()) {
    throw mscclpp::Error("cuMemMap is used in env without NVLS support", mscclpp::ErrorCode::InvalidUsage);
  }
  return true;
#endif
}

} // namespace

namespace mscclpp {
Expand Down
Loading