Commit 4f6f23d

Binyang2014 and chhwang authored

Use smart pointer for IB structure (#585)

Change to use smart pointers for the IB structures. Registered memory now owns its ibMr; ibCtx no longer holds the reference.

- Use smart pointers for IbQp and IbMr
- Update the MemoryChannel API to keep a local RegisteredMemory
- Close the fd when the RegisteredMemory is released

Co-authored-by: Changho Hwang <[email protected]>

1 parent d55ac96 commit 4f6f23d

23 files changed (+175, -118 lines)
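Only five of the 23 changed files appear below; the IB wrapper sources themselves are not part of this view. As background for the diffs that follow, here is a minimal C++ sketch of the ownership model the commit message describes. The types and members are hypothetical stand-ins, since the actual IB code is not shown on this page:

#include <memory>

// Hypothetical stand-ins for mscclpp's InfiniBand wrappers (the real IB files
// are among the 23 changed but are not shown here).
struct IbMr {
  ~IbMr() { /* deregister the memory region, e.g. via ibv_dereg_mr */ }
};

// Before this commit, the IB context kept the reference to each IbMr.
// Afterwards, the state behind a RegisteredMemory owns its IbMr through a
// smart pointer: when the last copy of the RegisteredMemory is destroyed, the
// MR is deregistered and, per the commit message, the associated fd is closed.
struct RegisteredMemoryState {
  std::unique_ptr<IbMr> ibMr;  // owned by the registered memory, not by the IB context
};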

apps/nccl/src/nccl.cu (29 additions, 24 deletions)

@@ -199,6 +199,7 @@ struct ncclComm {
   std::unordered_map<channelKey, ChannelInfo> channelScratchInfos;
   std::unordered_map<channelKey, NvlsChannelInfo> channelNvlsInfos;
   std::shared_ptr<char> scratchBuff;
+  mscclpp::RegisteredMemory registeredScratchMemory;
   std::vector<mscclpp::RegisteredMemory> remoteScratchRegMemories;
   std::vector<ChannelInfo> channelInfos;

@@ -268,30 +269,29 @@ static Op getReduceOp(ncclRedOp_t op) {
 }
 
 static std::vector<mscclpp::RegisteredMemory> setupRemoteMemories(std::shared_ptr<mscclpp::Communicator> comm, int rank,
-                                                                  void* buff, size_t bytes,
-                                                                  mscclpp::TransportFlags transport) {
+                                                                  mscclpp::RegisteredMemory localMemory) {
   std::vector<mscclpp::RegisteredMemory> remoteMemories;
-  mscclpp::RegisteredMemory memory = comm->registerMemory(buff, bytes, transport);
   std::vector<std::shared_future<mscclpp::RegisteredMemory>> remoteRegMemoryFutures;
   for (int i = 0; i < comm->bootstrap()->getNranks(); i++) {
     if (i == rank) continue;
     remoteRegMemoryFutures.push_back(comm->recvMemory(i));
-    comm->sendMemory(memory, i);
+    comm->sendMemory(localMemory, i);
   }
   std::transform(remoteRegMemoryFutures.begin(), remoteRegMemoryFutures.end(), std::back_inserter(remoteMemories),
                  [](const auto& future) { return future.get(); });
   return remoteMemories;
 }
 
 static std::vector<mscclpp::MemoryChannel> setupMemoryChannels(
-    ncclComm_t comm, const std::vector<mscclpp::RegisteredMemory>& remoteMemories, void* src) {
+    ncclComm_t comm, const std::vector<mscclpp::RegisteredMemory>& remoteMemories,
+    mscclpp::RegisteredMemory localMemory) {
   std::vector<mscclpp::MemoryChannel> channels;
   std::vector<std::shared_ptr<mscclpp::MemoryDevice2DeviceSemaphore>>& memorySemaphores = comm->memorySemaphores;
   size_t nConnections = comm->connections.size();
   for (size_t idx = 0; idx < NUM_CHANNELS_PER_CONNECTION; ++idx) {
     for (size_t cid = 0; cid < nConnections; ++cid) {
       if (comm->connections[cid]->transport() == mscclpp::Transport::CudaIpc) {
-        channels.emplace_back(memorySemaphores[idx * nConnections + cid], remoteMemories[cid], src, nullptr);
+        channels.emplace_back(memorySemaphores[idx * nConnections + cid], remoteMemories[cid], localMemory, nullptr);
       }
     }
   }

@@ -432,8 +432,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
   if (count * ncclTypeSize(datatype) <= (1 << 20) || mscclpp::isNvlsSupported()) {
     auto sendIt = comm->channelScratchInfos.find(sendKey);
     if (sendIt == comm->channelScratchInfos.end()) {
+      mscclpp::RegisteredMemory localMemory =
+          comm->comm->registerMemory((void*)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
       std::vector<mscclpp::MemoryChannel> channels =
-          setupMemoryChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
+          setupMemoryChannels(comm, comm->remoteScratchRegMemories, localMemory);
       ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles(channels)};
       sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
     }

@@ -444,8 +446,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
 
   auto sendIt = comm->channelInInfos.find(sendKey);
   if (sendIt == comm->channelInInfos.end()) {
+    mscclpp::RegisteredMemory localMemory =
+        comm->comm->registerMemory((void*)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
     std::vector<mscclpp::MemoryChannel> channels =
-        setupMemoryChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
+        setupMemoryChannels(comm, comm->remoteScratchRegMemories, localMemory);
     ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles(channels)};
     sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
   }

@@ -457,10 +461,10 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
     recvBasePtr = (CUdeviceptr)recvbuff;
     offsetOut = 0;
   }
-  remoteMemories =
-      setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
-  std::vector<mscclpp::MemoryChannel> outChannels =
-      setupMemoryChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
+  mscclpp::RegisteredMemory localMemory =
+      comm->comm->registerMemory((void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
+  remoteMemories = setupRemoteMemories(comm->comm, rank, localMemory);
+  std::vector<mscclpp::MemoryChannel> outChannels = setupMemoryChannels(comm, remoteMemories, localMemory);
   ChannelInfo channelInfo{outChannels, setupMemoryChannelDeviceHandles(outChannels)};
   recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
   if (mscclppDisableChannelCache == true) {

@@ -552,10 +556,10 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
     recvBasePtr = (CUdeviceptr)recvbuff;
     offsetOut = 0;
   }
-  std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
-      comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
-  std::vector<mscclpp::MemoryChannel> channels =
-      setupMemoryChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
+  mscclpp::RegisteredMemory localMemory =
+      comm->comm->registerMemory((void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
+  std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(comm->comm, rank, localMemory);
+  std::vector<mscclpp::MemoryChannel> channels = setupMemoryChannels(comm, remoteMemories, localMemory);
   std::vector<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
   std::transform(channels.begin(), channels.end(), std::back_inserter(memoryChannelDeviceHandles),
                  [](const mscclpp::MemoryChannel& memoryChannel) { return mscclpp::deviceHandle(memoryChannel); });

@@ -577,8 +581,10 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
 #else
   auto sendIt = comm->channelInInfos.find(sendKey);
   if (sendIt == comm->channelInInfos.end()) {
+    mscclpp::RegisteredMemory localMemory =
+        comm->comm->registerMemory((void*)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
     std::vector<mscclpp::MemoryChannel> channels =
-        setupMemoryChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
+        setupMemoryChannels(comm, comm->remoteScratchRegMemories, localMemory);
     ChannelInfo channelInfo{channels, setupMemoryChannelDeviceHandles(channels)};
     sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
   }

@@ -629,8 +635,9 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt
   commPtr->buffFlag = 0;
   commPtr->numScratchBuff = 2;
   commPtr->scratchBuff = mscclpp::GpuBuffer<char>(SCRATCH_SIZE).memory();
-  commPtr->remoteScratchRegMemories =
-      setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
+  commPtr->registeredScratchMemory =
+      commPtr->comm->registerMemory(commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc);
+  commPtr->remoteScratchRegMemories = setupRemoteMemories(commPtr->comm, rank, commPtr->registeredScratchMemory);
 
   commPtr->deviceFlag7 = mscclpp::detail::gpuCallocShared<uint32_t>(7);
   commPtr->deviceFlag28 = mscclpp::detail::gpuCallocShared<uint32_t>(28);

@@ -935,12 +942,10 @@ NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff
 
   auto it = comm->channelOutInfos.find(recvKey);
   if (it == comm->channelOutInfos.end()) {
-    // std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
-    //     comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
-    // std::vector<mscclpp::MemoryChannel> channels =
-    //     setupMemoryChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
+    mscclpp::RegisteredMemory localMemory =
+        comm->comm->registerMemory((void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
     std::vector<mscclpp::MemoryChannel> channels =
-        setupMemoryChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)recvBasePtr));
+        setupMemoryChannels(comm, comm->remoteScratchRegMemories, localMemory);
     std::vector<mscclpp::DeviceHandle<mscclpp::MemoryChannel>> memoryChannelDeviceHandles;
     std::transform(channels.begin(), channels.end(), std::back_inserter(memoryChannelDeviceHandles),
                    [](const mscclpp::MemoryChannel& memoryChannel) { return mscclpp::deviceHandle(memoryChannel); });
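Condensed from the hunks above, the pattern every call site now follows is: register the local buffer once, then thread the resulting RegisteredMemory through both helpers. The wrapper function below is a hypothetical name for illustration; the calls inside it are the ones from the diff, with caching and error handling omitted:

// Sketch only: buildChannelsForBuffer is not in the source; ncclComm_t,
// setupRemoteMemories, and setupMemoryChannels are the ones defined above.
static std::vector<mscclpp::MemoryChannel> buildChannelsForBuffer(ncclComm_t comm, int rank, void* base,
                                                                  size_t bytes) {
  // Register once; the same RegisteredMemory is exchanged with peers and then
  // used as the local side of every channel.
  mscclpp::RegisteredMemory localMemory = comm->comm->registerMemory(base, bytes, mscclpp::Transport::CudaIpc);
  std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(comm->comm, rank, localMemory);
  return setupMemoryChannels(comm, remoteMemories, localMemory);
}

Because setupRemoteMemories no longer registers internally, the caller holds the local RegisteredMemory, which is what allows the channel (next file) to keep it, and the ibMr it owns, alive.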

include/mscclpp/memory_channel.hpp (5 additions, 5 deletions)

@@ -48,7 +48,7 @@ struct BaseMemoryChannel {
 struct MemoryChannel : public BaseMemoryChannel {
  private:
   RegisteredMemory dst_;
-  void* src_;
+  RegisteredMemory src_;
   void* packetBuffer_;
 
  public:
@@ -58,19 +58,19 @@ struct MemoryChannel : public BaseMemoryChannel {
   /// Constructor.
   /// @param semaphore The semaphore used to synchronize the communication.
   /// @param dst Registered memory of the destination.
-  /// @param src The source memory address.
+  /// @param src Registered memory of the source.
   /// @param packetBuffer A buffer used to store packets. @p packetBuffer is optional and if it is nullptr,
   /// unpackPacket() and unpackPackets() methods are not available.
-  MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, void* src,
+  MemoryChannel(std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore, RegisteredMemory dst, RegisteredMemory src,
                 void* packetBuffer = nullptr);
 
   /// Constructor.
   /// @param semaphore The semaphore used to synchronize the communication.
   /// @param dst Registered memory of the destination.
-  /// @param src The source memory address.
+  /// @param src Registered memory of the source.
   /// @param packetBuffer A buffer used to store packets. @p packetBuffer is optional and if it is nullptr,
   /// unpackPacket() and unpackPackets() methods are not available.
-  MemoryChannel(const Semaphore& semaphore, RegisteredMemory dst, void* src, void* packetBuffer = nullptr);
+  MemoryChannel(const Semaphore& semaphore, RegisteredMemory dst, RegisteredMemory src, void* packetBuffer = nullptr);
 
   /// Device-side handle for MemoryChannel.
   using DeviceHandle = MemoryChannelDeviceHandle;
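A hypothetical call site for the new signature; `semaphore`, `remoteMemory`, and `localMemory` are assumed to come from the usual Communicator setup and are not names from this diff:

// src is now a RegisteredMemory rather than a raw void*, so the channel keeps
// the source registration (and whatever resources it owns, such as an IB
// memory region) alive for as long as the channel exists.
mscclpp::MemoryChannel channel(semaphore, /*dst=*/remoteMemory, /*src=*/localMemory);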

python/mscclpp/comm.py (45 additions, 24 deletions)

@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 
 from __future__ import annotations
-from typing import Type
+from typing import Tuple, Type
 
 import cupy as cp
 from ._mscclpp import (
@@ -109,18 +109,7 @@ def make_connection(
     def register_tensor_with_connections(
         self, tensor: Type[cp.ndarray] | Type[np.ndarray], connections: dict[int, Connection]
     ) -> dict[int, RegisteredMemory]:
-        transport_flags = TransportFlags()
-        for rank in connections:
-            transport_flags |= connections[rank].transport()
-        data_ptr = (
-            tensor.data.ptr
-            if isinstance(tensor, cp.ndarray)
-            else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
-        )
-        tensor_size = (
-            tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
-        )
-        local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
+        local_reg_memory = self.register_local_memory(tensor, connections)
         all_registered_memories = {}
         all_registered_memories[self.my_rank] = local_reg_memory
         future_memories = {}
@@ -131,6 +120,19 @@ def register_tensor_with_connections(
             all_registered_memories[rank] = future_memories[rank].get()
         return all_registered_memories
 
+    def _register_memory_with_connections(
+        self, memory: RegisteredMemory, connections: dict[int, Connection]
+    ) -> dict[int, RegisteredMemory]:
+        all_registered_memories = {}
+        all_registered_memories[self.my_rank] = memory
+        future_memories = {}
+        for rank in connections:
+            self.communicator.send_memory(memory, rank)
+            future_memories[rank] = self.communicator.recv_memory(rank)
+        for rank in connections:
+            all_registered_memories[rank] = future_memories[rank].get()
+        return all_registered_memories
+
     def make_semaphore(
         self,
         connections: dict[int, Connection],
@@ -145,31 +147,36 @@ def make_memory_channels(self, tensor: cp.ndarray, connections: dict[int, Connec
         semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(tensor, connections)
         channels = {}
-        tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
         for rank in connections:
-            channels[rank] = MemoryChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
+            channels[rank] = MemoryChannel(
+                semaphores[rank], registered_memories[rank], registered_memories[self.my_rank]
+            )
         return channels
 
     def make_memory_channels_with_scratch(
         self,
         tensor: cp.ndarray,
-        scratchTensor: cp.ndarray,
+        registeredScratchBuffer: RegisteredMemory,
         connections: dict[int, Connection],
     ) -> dict[int, MemoryChannel]:
         semaphores = self.make_semaphore(connections, MemoryDevice2DeviceSemaphore)
-        registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
+        registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
        channels = {}
         tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
-        scratch_data_ptr = scratchTensor.data_ptr() if is_torch_tensor(scratchTensor) else scratchTensor.data.ptr
+        tensor_size = (
+            tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
+        )
+        local_registered_memory = self.communicator.register_memory(tensor_data_ptr, tensor_size, TransportFlags())
+        scratch_data_ptr = registeredScratchBuffer.data()
         for rank in connections:
             channels[rank] = MemoryChannel(
-                semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr
+                semaphores[rank], registered_memories[rank], local_registered_memory, scratch_data_ptr
             )
         return channels
 
     def make_port_channels(
         self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
-    ) -> dict[int, MemoryChannel]:
+    ) -> dict[int, PortChannel]:
         semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(tensor, connections)
         memory_ids = {}
@@ -187,9 +194,9 @@ def make_port_channels_with_scratch(
         self,
         proxy_service: ProxyService,
         tensor: cp.ndarray,
-        scratchTensor: cp.ndarray,
+        registeredScratchBuffer: RegisteredMemory,
         connections: dict[int, Connection],
-    ) -> dict[int, MemoryChannel]:
+    ) -> dict[int, PortChannel]:
         transport_flags = TransportFlags()
         for rank in connections:
             transport_flags |= connections[rank].transport()
@@ -204,7 +211,7 @@ def make_port_channels_with_scratch(
         local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
 
         semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
-        registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
+        registered_memories = self._register_memory_with_connections(registeredScratchBuffer, connections)
         memory_ids = {}
         semaphore_ids = {}
         for rank in registered_memories:
@@ -221,7 +228,7 @@ def make_port_channels_with_scratch(
 
     def register_semaphore_with_proxy(
         self, proxy_service: ProxyService, connections: dict[int, Connection]
-    ) -> dict[int, MemoryChannel]:
+    ) -> dict[int, PortChannel]:
         semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
         semaphore_ids = {}
         for rank in semaphores:
@@ -239,3 +246,17 @@ def register_memory_with_proxy(
         for rank in registered_memories:
             memory_ids[rank] = proxy_service.add_memory(registered_memories[rank])
         return memory_ids
+
+    def register_local_memory(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> RegisteredMemory:
+        transport_flags = TransportFlags()
+        for rank in connections:
+            transport_flags |= connections[rank].transport()
+        data_ptr = (
+            tensor.data.ptr
+            if isinstance(tensor, cp.ndarray)
+            else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
+        )
+        tensor_size = (
+            tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
+        )
+        return self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
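The new `_register_memory_with_connections` helper is the Python counterpart of the exchange `setupRemoteMemories` performs in nccl.cu above: post a receive per peer, send the already-registered local memory, then resolve the futures. A condensed C++ sketch of that same pattern, assuming `comm` is a connected `mscclpp::Communicator` and `localMemory` is already registered:

std::vector<std::shared_future<mscclpp::RegisteredMemory>> futures;
for (int peer = 0; peer < comm->bootstrap()->getNranks(); ++peer) {
  if (peer == rank) continue;
  futures.push_back(comm->recvMemory(peer));  // the peer's handle arrives asynchronously
  comm->sendMemory(localMemory, peer);        // share the local registration with the peer
}
std::vector<mscclpp::RegisteredMemory> remoteMemories;
for (auto& f : futures) remoteMemories.push_back(f.get());  // block until all are exchanged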

python/mscclpp/core_py.cpp (1 addition, 1 deletion)

@@ -126,7 +126,7 @@ void register_core(nb::module_& m) {
 
   nb::class_<RegisteredMemory>(m, "RegisteredMemory")
       .def(nb::init<>())
-      .def("data", &RegisteredMemory::data)
+      .def("data", [](RegisteredMemory& self) { return reinterpret_cast<uintptr_t>(self.data()); })
       .def("size", &RegisteredMemory::size)
       .def("transports", &RegisteredMemory::transports)
       .def("serialize", &RegisteredMemory::serialize)

python/mscclpp/memory_channel_py.cpp (3 additions, 5 deletions)

@@ -28,13 +28,11 @@ void register_memory_channel(nb::module_& m) {
       .def(nb::init<>())
       .def("__init__",
           [](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
-             RegisteredMemory dst,
-             uintptr_t src) { new (memoryChannel) MemoryChannel(semaphore, dst, reinterpret_cast<void*>(src)); })
+             RegisteredMemory dst, RegisteredMemory src) { new (memoryChannel) MemoryChannel(semaphore, dst, src); })
       .def("__init__",
           [](MemoryChannel* memoryChannel, std::shared_ptr<MemoryDevice2DeviceSemaphore> semaphore,
-             RegisteredMemory dst, uintptr_t src, uintptr_t packet_buffer) {
-            new (memoryChannel)
-                MemoryChannel(semaphore, dst, reinterpret_cast<void*>(src), reinterpret_cast<void*>(packet_buffer));
+             RegisteredMemory dst, RegisteredMemory src, uintptr_t packet_buffer) {
+            new (memoryChannel) MemoryChannel(semaphore, dst, src, reinterpret_cast<void*>(packet_buffer));
           })
       .def("device_handle", &MemoryChannel::deviceHandle);
