Skip to content

Commit adc9ee5

Browse files
authored
Export mscclpp GpuBuffer to dlpack format (#492)
For mscclpp, to use NVLS we require that the buffer be allocated by mscclpp::GpuBuffer. Because cupy doesn't yet support bfloat16, we export the raw buffer in dlpack format. Users can use this feature to create a buffer with a type supported by PyTorch: ```python buffer = RawGpuBuffer(1024 * 2) # 2 for bfloat16 dl_pack = buffer.to_dlpack(str(torch.bfloat16)) tensor = torch.utils.dlpack.from_dlpack(dl_pack) ```
1 parent 5a7a59f commit adc9ee5

File tree

9 files changed

+142
-24
lines changed

9 files changed

+142
-24
lines changed

.github/workflows/lint.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ on:
77

88
jobs:
99
cpplint:
10-
runs-on: ubuntu-20.04
10+
runs-on: ubuntu-22.04
1111

1212
steps:
1313
- name: Check out Git repository
@@ -24,7 +24,7 @@ jobs:
2424
clang-format -style=file --verbose --Werror --dry-run ${CPPSOURCES}
2525
2626
pylint:
27-
runs-on: ubuntu-20.04
27+
runs-on: ubuntu-22.04
2828

2929
steps:
3030
- name: Check out Git repository
@@ -42,7 +42,7 @@ jobs:
4242
run: python3 -m black --check --config pyproject.toml .
4343

4444
spelling:
45-
runs-on: ubuntu-20.04
45+
runs-on: ubuntu-22.04
4646

4747
steps:
4848
- name: Check out Git repository

include/mscclpp/gpu_utils.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,11 @@ void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMe
202202
/// @return True if NVLink SHARP (NVLS) is supported, false otherwise.
203203
bool isNvlsSupported();
204204

205+
/// Check if ptr is allocated by cuMemMap
206+
/// @param ptr The pointer to check.
207+
/// @return True if the pointer is allocated by cuMemMap, false otherwise.
208+
bool isCuMemMapAllocated([[maybe_unused]] void* ptr);
209+
205210
/// Allocates a GPU memory space specialized for communication. The memory is zeroed out. Get the device pointer by
206211
/// `GpuBuffer::data()`.
207212
///
@@ -224,6 +229,7 @@ class GpuBuffer {
224229
bytes_ = 0;
225230
return;
226231
}
232+
MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId_));
227233
#if (CUDA_NVLS_SUPPORTED)
228234
if (isNvlsSupported()) {
229235
size_t gran = detail::getMulticastGranularity(nelems * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED);
@@ -259,9 +265,14 @@ class GpuBuffer {
259265
/// @return A device pointer to the allocated memory.
260266
T* data() { return memory_.get(); }
261267

268+
/// Returns the device id of the allocated memory.
269+
/// @return The device id.
270+
int deviceId() const { return deviceId_; }
271+
262272
private:
263273
size_t nelems_;
264274
size_t bytes_;
275+
int deviceId_;
265276
std::shared_ptr<T> memory_;
266277
};
267278

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ version = "0.6.0"
1212
[tool.scikit-build]
1313
cmake.version = ">=3.25.0"
1414
cmake.build-type = "Release"
15+
# for dlpack issue: https://github.com/microsoft/vcpkg/pull/44679
16+
cmake.args = ["-DCMAKE_POLICY_VERSION_MINIMUM=3.5"]
1517
build-dir = "build/{wheel_tag}"
1618
wheel.packages = ["python/mscclpp", "python/mscclpp_benchmark"]
1719
wheel.install-dir = "mscclpp"

python/mscclpp/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@ include(FetchContent)
66
FetchContent_Declare(nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git GIT_TAG v1.4.0)
77
FetchContent_MakeAvailable(nanobind)
88

9+
FetchContent_Declare(dlpack GIT_REPOSITORY https://github.com/dmlc/dlpack.git GIT_TAG v1.1)
10+
FetchContent_MakeAvailable(dlpack)
11+
912
file(GLOB_RECURSE SOURCES CONFIGURE_DEPENDS *.cpp)
1013
nanobind_add_module(mscclpp_py ${SOURCES})
1114
set_target_properties(mscclpp_py PROPERTIES OUTPUT_NAME _mscclpp)
12-
target_link_libraries(mscclpp_py PRIVATE mscclpp_static ${GPU_LIBRARIES})
15+
target_link_libraries(mscclpp_py PRIVATE dlpack mscclpp_static ${GPU_LIBRARIES})
1316
target_include_directories(mscclpp_py SYSTEM PRIVATE ${GPU_INCLUDE_DIRS})
1417
install(TARGETS mscclpp_py LIBRARY DESTINATION .)

python/mscclpp/gpu_utils_py.cpp

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,116 @@
1+
// Copyright (c) Microsoft Corporation.
2+
// Licensed under the MIT license.
3+
4+
#include <dlpack/dlpack.h>
15
#include <nanobind/nanobind.h>
26
#include <nanobind/stl/shared_ptr.h>
7+
#include <nanobind/stl/string.h>
8+
#include <nanobind/stl/vector.h>
39

410
#include <mscclpp/gpu_data_types.hpp>
511
#include <mscclpp/gpu_utils.hpp>
612

713
namespace nb = nanobind;
814
using namespace mscclpp;
915

16+
constexpr int BYTE_BITS = 8;
17+
18+
static DLDeviceType getDeviceType() {
19+
#if defined(__HIP_PLATFORM_AMD__)
20+
return kDLROCM;
21+
#else
22+
return kDLCUDA;
23+
#endif
24+
}
25+
26+
static DLDataType getDlType(std::string type) {
27+
if (type == "torch.float") {
28+
return DLDataType{kDLFloat, 32, 1};
29+
} else if (type == "torch.int") {
30+
return DLDataType{kDLInt, 32, 1};
31+
} else if (type == "torch.uint32") {
32+
return DLDataType{kDLUInt, 32, 1};
33+
} else if (type == "torch.bfloat16") {
34+
return DLDataType{kDLBfloat, 16, 1};
35+
} else if (type == "torch.float16") {
36+
return DLDataType{kDLFloat, 16, 1};
37+
} else {
38+
throw Error("Unsupported type: " + type, ErrorCode::InvalidUsage);
39+
}
40+
}
41+
42+
static nb::capsule toDlpack(GpuBuffer<char> buffer, std::string dataType, std::vector<int64_t>& shape,
43+
std::vector<int64_t>& strides) {
44+
DLDataType dtype = getDlType(dataType);
45+
int64_t* tensorShape = shape.size() > 0 ? new int64_t[shape.size()] : new int64_t[1];
46+
int64_t* tensorStrides = strides.size() > 0 ? new int64_t[strides.size()] : nullptr;
47+
if (shape.size() == 0) {
48+
tensorShape[0] = (int64_t)(buffer.nelems() / ((dtype.bits * dtype.lanes + 7) / BYTE_BITS));
49+
} else {
50+
for (size_t i = 0; i < shape.size(); ++i) {
51+
tensorShape[i] = shape[i];
52+
}
53+
}
54+
for (size_t i = 0; i < strides.size(); ++i) {
55+
tensorStrides[i] = strides[i];
56+
}
57+
58+
DLManagedTensor* dlManagedTensor = new DLManagedTensor();
59+
dlManagedTensor->dl_tensor.data = buffer.data();
60+
dlManagedTensor->dl_tensor.device.device_type = getDeviceType();
61+
dlManagedTensor->dl_tensor.device.device_id = buffer.deviceId();
62+
dlManagedTensor->dl_tensor.ndim = shape.size() == 0 ? 1 : shape.size();
63+
dlManagedTensor->dl_tensor.strides = tensorStrides;
64+
dlManagedTensor->dl_tensor.shape = tensorShape;
65+
dlManagedTensor->dl_tensor.byte_offset = 0;
66+
dlManagedTensor->dl_tensor.dtype = dtype;
67+
dlManagedTensor->manager_ctx = new GpuBuffer<char>(buffer);
68+
dlManagedTensor->deleter = [](DLManagedTensor* self) {
69+
delete static_cast<GpuBuffer<char>*>(self->manager_ctx);
70+
self->manager_ctx = nullptr;
71+
self->dl_tensor.data = nullptr;
72+
if (self->dl_tensor.shape != nullptr) {
73+
delete[] self->dl_tensor.shape;
74+
self->dl_tensor.shape = nullptr;
75+
if (self->dl_tensor.strides) {
76+
delete[] self->dl_tensor.strides;
77+
self->dl_tensor.strides = nullptr;
78+
}
79+
}
80+
delete self;
81+
};
82+
83+
PyObject* dlCapsule = PyCapsule_New(static_cast<void*>(dlManagedTensor), "dltensor", [](PyObject* capsule) {
84+
if (PyCapsule_IsValid(capsule, "used_dltensor")) {
85+
return;
86+
}
87+
if (!PyCapsule_IsValid(capsule, "dltensor")) {
88+
return;
89+
}
90+
DLManagedTensor* managedTensor = static_cast<DLManagedTensor*>(PyCapsule_GetPointer(capsule, "dltensor"));
91+
if (managedTensor == nullptr) {
92+
return;
93+
}
94+
if (managedTensor->deleter) {
95+
managedTensor->deleter(managedTensor);
96+
}
97+
});
98+
return nb::steal<nb::capsule>(dlCapsule);
99+
}
100+
10101
void register_gpu_utils(nb::module_& m) {
11102
m.def("is_nvls_supported", &isNvlsSupported);
12103

13104
nb::class_<GpuBuffer<char>>(m, "RawGpuBuffer")
14105
.def(nb::init<size_t>(), nb::arg("nelems"))
15106
.def("nelems", &GpuBuffer<char>::nelems)
16107
.def("bytes", &GpuBuffer<char>::bytes)
17-
.def("data", [](GpuBuffer<char>& self) { return reinterpret_cast<uintptr_t>(self.data()); });
108+
.def("data", [](GpuBuffer<char>& self) { return reinterpret_cast<uintptr_t>(self.data()); })
109+
.def("device_id", &GpuBuffer<char>::deviceId)
110+
.def(
111+
"to_dlpack",
112+
[](GpuBuffer<char>& self, std::string dataType, std::vector<int64_t> shape, std::vector<int64_t> strides) {
113+
return toDlpack(self, dataType, shape, strides);
114+
},
115+
nb::arg("dataType"), nb::arg("shape") = std::vector<int64_t>(), nb::arg("strides") = std::vector<int64_t>());
18116
}

src/executor/execution_plan.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ void ExecutionPlan::Impl::setupOperations(const json& gpus, size_t constSrcOffse
455455
for (int i = 0; i < operation.nOutputs; i++) {
456456
if (operation.channelType == mscclpp::ChannelType::NVLS) {
457457
BufferType dstBufferType = convertToBufferType(op["dstbuff"]);
458-
operation.nvlsInputIndex =
458+
operation.nvlsOutputIndex =
459459
channelIndexes[{dstBufferType, dstBufferType, ChannelType::NVLS}][op["o_cids"][0]["id"]];
460460
chunkIndexes.push_back((uint32_t)op["dstoff"]);
461461
} else {

src/gpu_utils.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,21 @@ bool isNvlsSupported() {
199199
return false;
200200
}
201201

202+
bool isCuMemMapAllocated([[maybe_unused]] void* ptr) {
203+
#if defined(__HIP_PLATFORM_AMD__)
204+
return false;
205+
#else
206+
CUmemGenericAllocationHandle handle;
207+
CUresult result = cuMemRetainAllocationHandle(&handle, ptr);
208+
if (result != CUDA_SUCCESS) {
209+
return false;
210+
}
211+
MSCCLPP_CUTHROW(cuMemRelease(handle));
212+
if (!mscclpp::isNvlsSupported()) {
213+
throw mscclpp::Error("cuMemMap is used in env without NVLS support", mscclpp::ErrorCode::InvalidUsage);
214+
}
215+
return true;
216+
#endif
217+
}
218+
202219
} // namespace mscclpp

src/nvls.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ void NvlsConnection::Impl::freeBuffer(size_t offset, size_t size) noexcept {
193193
}
194194

195195
std::shared_ptr<char> NvlsConnection::Impl::bindMemory(CUdeviceptr devicePtr, size_t devBuffSize) {
196+
if (!isCuMemMapAllocated((void*)devicePtr)) {
197+
throw Error("This NVLS connection tried to bind a buffer that was not allocated with cuMemMap",
198+
ErrorCode::InvalidUsage);
199+
}
200+
196201
devBuffSize = ((devBuffSize + minMcGran_ - 1) / minMcGran_) * minMcGran_;
197202
size_t offset = allocateBuffer(devBuffSize);
198203
MSCCLPP_CUTHROW(cuMulticastBindAddr(mcHandle_, offset /*mcOffset*/, devicePtr, devBuffSize, 0));

src/registered_memory.cc

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -39,24 +39,6 @@ CUmemAllocationHandleType getNvlsMemHandleType() {
3939
#endif
4040
}
4141

42-
// Check if ptr is allocated by cuMemMap
43-
bool isCuMemMapAllocated([[maybe_unused]] void* ptr) {
44-
#if defined(__HIP_PLATFORM_AMD__)
45-
return false;
46-
#else
47-
CUmemGenericAllocationHandle handle;
48-
CUresult result = cuMemRetainAllocationHandle(&handle, ptr);
49-
if (result != CUDA_SUCCESS) {
50-
return false;
51-
}
52-
MSCCLPP_CUTHROW(cuMemRelease(handle));
53-
if (!mscclpp::isNvlsSupported()) {
54-
throw mscclpp::Error("cuMemMap is used in env without NVLS support", mscclpp::ErrorCode::InvalidUsage);
55-
}
56-
return true;
57-
#endif
58-
}
59-
6042
} // namespace
6143

6244
namespace mscclpp {

0 commit comments

Comments
 (0)