-
Notifications
You must be signed in to change notification settings - Fork 80
Expand file tree
/
Copy pathipc_handle.cpp
More file actions
458 lines (389 loc) · 16.4 KB
/
ipc_handle.cpp
File metadata and controls
458 lines (389 loc) · 16.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include "multidevice/ipc_handle.h"
#include "cuda_utils.h"
#include "multidevice/communicator.h"
#include "multidevice/ipc_utils.h"
#include "multidevice/utils.h"
namespace nvfuser {
// Exporting constructor: wraps a local tensor and prepares everything a peer
// needs to map it, namely a CUDA IPC handle for the tensor's allocation, the
// byte offset of the tensor inside that allocation, and a freshly allocated
// semaphore (initialised to kIdle) together with its own IPC handle.
//
// The importing constructor maps the allocation base and re-applies
// offset_from_base_address_ to recover the tensor's address.
IpcHandle::IpcHandle(at::Tensor tensor)
    : ptr_(tensor.data_ptr()),
      rank_(Communicator::getInstance().deviceId()),
      tensor_(tensor) {
  // Find the start of the allocation that contains the tensor. Use a real
  // CUdeviceptr rather than punning &base_address_ (a void**) as one.
  CUdeviceptr allocation_base = 0;
  size_t allocation_size = 0;
  NVFUSER_CUDA_SAFE_CALL(cuMemGetAddressRange(
      &allocation_base,
      &allocation_size,
      reinterpret_cast<CUdeviceptr>(ptr_)));
  base_address_ = reinterpret_cast<void*>(allocation_base);
  offset_from_base_address_ = static_cast<int64_t>(
      static_cast<uint8_t*>(ptr_) - static_cast<uint8_t*>(base_address_));

  // cudaIpcGetMemHandle is documented to take the *base* of an existing
  // cudaMalloc allocation. Export the base, not the possibly-interior tensor
  // data pointer; the importer adds offset_from_base_address_ back.
  NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcGetMemHandle(&ipc_handle_, base_address_));

  // Semaphore owned by this (exporting) rank; freed in the destructor.
  NVFUSER_CUDA_RT_SAFE_CALL(
      cudaMalloc((void**)&semaphore_, sizeof(IpcSemaphore)));
  static_assert(
      sizeof(IpcSemaphore) == sizeof(int),
      "IpcSemaphore must be same size as int");
  // cudaMemset writes a single *byte* value into every byte, which produces
  // IpcSemaphore::kIdle only if all bytes of kIdle are identical (e.g. 0).
  // Copy the enum value itself instead, as the symmetric-memory semaphores
  // elsewhere in this file do.
  const IpcSemaphore init_value = IpcSemaphore::kIdle;
  NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
      semaphore_, &init_value, sizeof(IpcSemaphore), cudaMemcpyHostToDevice));
  NVFUSER_CUDA_RT_SAFE_CALL(
      cudaIpcGetMemHandle(&semaphore_ipc_handle_, semaphore_));
}
// Importing constructor: rebuilds a peer's IpcHandle from the raw bytes that
// the peer published (toBytes of its IpcHandle, see
// IpcHandleCache::exchangeHandles) and maps the peer's memory into this
// process.
//
// Only rank_, the two CUDA IPC handles and offset_from_base_address_ are
// taken from the serialized bytes. base_address_, ptr_ and semaphore_ are
// recomputed here from the mapped handles.
IpcHandle::IpcHandle(std::vector<uint8_t> data) {
  const IpcHandle& exported = fromBytes<IpcHandle>(data);
  rank_ = exported.rank_;
  ipc_handle_ = exported.ipc_handle_;
  semaphore_ipc_handle_ = exported.semaphore_ipc_handle_;
  offset_from_base_address_ = exported.offset_from_base_address_;

  // Map the peer's allocation, then step forward to the tensor's first byte.
  NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
      &base_address_, ipc_handle_, cudaIpcMemLazyEnablePeerAccess));
  uint8_t* const mapped_base = static_cast<uint8_t*>(base_address_);
  ptr_ = static_cast<void*>(mapped_base + offset_from_base_address_);

  // Map the peer's semaphore.
  NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcOpenMemHandle(
      reinterpret_cast<void**>(&semaphore_),
      semaphore_ipc_handle_,
      cudaIpcMemLazyEnablePeerAccess));
}
// Releases what this handle acquired.
//  - Importer (rank_ differs from this process's device id): closes the two
//    IPC mappings opened by the importing constructor.
//  - Exporter (rank_ equals this process's device id): frees the semaphore it
//    cudaMalloc'ed. The tensor memory itself is not freed here.
// NOTE(review): if NVFUSER_CUDA_RT_SAFE_CALL throws on failure, a failure here
// would terminate the program because destructors are implicitly noexcept —
// confirm this is acceptable.
IpcHandle::~IpcHandle() {
  const bool is_exporter = rank_ == Communicator::getInstance().deviceId();
  if (!is_exporter) {
    NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(base_address_));
    NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle((void*)semaphore_));
    return;
  }
  NVFUSER_CUDA_RT_SAFE_CALL(cudaFree((void*)semaphore_));
}
// Builds the TCP-store key under which `rank` publishes its IpcHandle for
// `communication`.
//
// The transfer is named by its data flow (dst, src): a SEND flows from this
// rank to the peer, and otherwise this rank is the destination. As a result,
// both endpoints of one P2P transfer derive the same (dst, src) pair and
// differ only in the trailing `rank`.
std::string IpcHandleCache::getTcpStoreKey(
    P2PCommunication* communication,
    int64_t rank) const {
  const int64_t self = Communicator::getInstance().deviceId();
  const int64_t other =
      expr_evaluator_->evaluate(communication->peer()).as<int64_t>();
  const bool is_send = communication->type() == P2PCommunicationType::SEND;
  const int64_t src = is_send ? self : other;
  const int64_t dst = is_send ? other : self;

  std::string key = "nvfuser_ipc_handle_info_P2PComm_dst=";
  key += std::to_string(dst);
  key += "_src=";
  key += std::to_string(src);
  key += "_rank=";
  key += std::to_string(rank);
  return key;
}
// Exchanges IPC handles with peers for every communication that has no cached
// handle pair yet.
//
// Phase 1 exports a handle for each local buffer and publishes it in the TCP
// store under this rank's key. Phase 2 fetches each peer's handle, deletes the
// store entry, and caches the (local, peer) pair. A final barrier runs only if
// anything was exchanged.
void IpcHandleCache::exchangeHandles(
    const std::vector<P2PCommunication*>& communications) {
  Communicator* communicator = &Communicator::getInstance();
  const int64_t my_rank = communicator->deviceId();

  // Validate every request and keep only the ones not already cached.
  std::vector<P2PCommunication*> pending;
  for (P2PCommunication* communication : communications) {
    NVF_ERROR(
        expr_evaluator_->evaluate(communication->peer()).as<int64_t>() !=
            my_rank,
        "send to self not supported");
    if (find(communication) == nullptr) {
      pending.push_back(communication);
    }
  }

  // Phase 1: export and publish our own handles.
  std::unordered_map<P2PCommunication*, std::unique_ptr<IpcHandle>> exported;
  auto store = communicator->getTcpStore();
  for (P2PCommunication* communication : pending) {
    at::Tensor tensor =
        expr_evaluator_->evaluate(communication->buffer()).as<at::Tensor>();
    NVF_ERROR(
        tensor.is_contiguous(), "IpcHandle only supports contiguous tensors");
    auto local_handle = std::make_unique<IpcHandle>(tensor);
    const std::string key = getTcpStoreKey(communication, my_rank);
    // TODO: use multiSet
    store->set(key, toBytes(*local_handle));
    exported.emplace(communication, std::move(local_handle));
  }

  // Phase 2: import the peers' handles. store->get blocks until the key is
  // published or the store times out.
  for (P2PCommunication* communication : pending) {
    const int64_t peer_rank =
        expr_evaluator_->evaluate(communication->peer()).as<int64_t>();
    const std::string key = getTcpStoreKey(communication, peer_rank);
    // TODO: use multiGet
    auto imported = std::make_unique<IpcHandle>(store->get(key));
    store->deleteKey(key);
    insert(
        communication,
        std::make_unique<P2pIpcHandle>(
            std::move(exported.at(communication)), std::move(imported)));
  }

  if (!pending.empty()) {
    // All ranks must have fetched their peers' handles (and deleted the keys)
    // before any rank calls exchangeHandles again; otherwise a later call can
    // observe stale keys.
    // TODO: precisely select what ranks need to wait on that barrier.
    communicator->barrier();
  }
}
// Convenience constructor: derives the store-key suffix from the
// communication's name and delegates to the (buffer, root, suffix) overload.
SymMemForBroadcast::SymMemForBroadcast(
    Communication* communication,
    int64_t root,
    at::Tensor buffer)
    : SymMemForBroadcast(
          buffer,
          root,
          std::string("for_Communication") +
              std::to_string(communication->name())) {}
// Sets up symmetric memory for a broadcast rooted at `root`.
//
// - `buffer` is wrapped as a SymmetricTensor and its unicast (per-rank)
//   handles are exchanged.
// - A one-element Int semaphore is allocated, initialised to kIdle, and its
//   unicast handles are exchanged.
// - For the Memcpy and Multimem protocols, multicast is additionally set up
//   for both the buffer and the semaphore, with `root` passed as the exporter.
//
// All store keys start with "nvls_export_mcast_handle_" + name_suffix.
SymMemForBroadcast::SymMemForBroadcast(
    at::Tensor buffer,
    int64_t root,
    const std::string& name_suffix) {
  const std::string store_key_prefix =
      "nvls_export_mcast_handle_" + name_suffix;

  // Data buffer: unicast (per-rank) handles.
  buffer_sym_tensor_ = std::make_unique<SymmetricTensor>(buffer);
  buffer_sym_tensor_->setupRemoteHandles(store_key_prefix + "_buffer_unicast");

  // Multicast objects are only created for the protocols that use them.
  const MulticastProtocol protocol = getMulticastProtocol();
  const bool use_multicast = protocol == MulticastProtocol::Memcpy ||
      protocol == MulticastProtocol::Multimem;

  // Data buffer: multicast.
  if (use_multicast) {
    buffer_sym_tensor_->setupMulticast(
        root, store_key_prefix + "_buffer_mcast");
  }

  // Semaphore: a single Int element initialised to kIdle.
  at::Tensor semaphore = SymmetricTensor::allocate(
      /*sizes=*/at::IntArrayRef({1}),
      /*dtype=*/at::ScalarType::Int,
      /*device=*/buffer.device());
  const IpcSemaphore init_value = IpcSemaphore::kIdle;
  NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
      semaphore.data_ptr(),
      &init_value,
      sizeof(IpcSemaphore),
      cudaMemcpyHostToDevice));

  // Semaphore: unicast handles, then optional multicast.
  semaphore_sym_tensor_ = std::make_unique<SymmetricTensor>(semaphore);
  semaphore_sym_tensor_->setupRemoteHandles(store_key_prefix + "_semaphore");
  if (use_multicast) {
    semaphore_sym_tensor_->setupMulticast(
        root, store_key_prefix + "_semaphore_mcast");
  }
}
// Multicast address of the broadcast buffer. The constructor sets up
// multicast only when the protocol is Memcpy or Multimem.
void* SymMemForBroadcast::bufferMulticastPtr() const {
  void* const multicast_address = buffer_sym_tensor_->multicastPtr();
  return multicast_address;
}
// Data pointer of remoteTensor(rank) for the broadcast buffer (presumably
// `rank`'s buffer as mapped into this process — confirm in SymmetricTensor).
void* SymMemForBroadcast::bufferUnicastPtr(int64_t rank) const {
  const auto& rank_view = buffer_sym_tensor_->remoteTensor(rank);
  return rank_view.data_ptr();
}
// Multicast address of the broadcast semaphore. The constructor sets up
// multicast only when the protocol is Memcpy or Multimem.
void* SymMemForBroadcast::semaphoreMulticastPtr() const {
  void* const multicast_address = semaphore_sym_tensor_->multicastPtr();
  return multicast_address;
}
// Data pointer of remoteTensor(rank) for the broadcast semaphore (presumably
// `rank`'s semaphore as mapped into this process — confirm in SymmetricTensor).
void* SymMemForBroadcast::semaphoreUnicastPtr(int64_t rank) const {
  const auto& rank_view = semaphore_sym_tensor_->remoteTensor(rank);
  return rank_view.data_ptr();
}
// Sets up symmetric memory for an allreduce producing `output_buffer`.
//
// - Allocates a symmetric input buffer with the output's sizes, dtype and
//   device, and exchanges its unicast handles.
// - Allocates a symmetric Int tensor of world-size semaphores, all
//   initialised to kIdle, and exchanges its unicast handles.
// - For the Memcpy and Multimem protocols, also sets up multicast for both,
//   with rank 0 as the exporter.
// size_bytes_ records the output buffer's size in bytes.
SymmetricMemoryForAllreduce::SymmetricMemoryForAllreduce(
    Communication* communication,
    at::Tensor output_buffer)
    : size_bytes_(output_buffer.numel() * output_buffer.element_size()) {
  const int64_t num_ranks = Communicator::getInstance().size();
  const std::string key_prefix = std::string("nvls_allreduce_") +
      "for_Communication" + std::to_string(communication->name());

  // Symmetric input buffer shaped like the output.
  at::Tensor input_storage = SymmetricTensor::allocate(
      output_buffer.sizes(),
      output_buffer.scalar_type(),
      output_buffer.device());
  input_sym_tensor_ = std::make_unique<SymmetricTensor>(input_storage);
  input_sym_tensor_->setupRemoteHandles(key_prefix + "_input_unicast");

  const MulticastProtocol protocol = getMulticastProtocol();
  const bool needs_multicast = protocol == MulticastProtocol::Memcpy ||
      protocol == MulticastProtocol::Multimem;
  if (needs_multicast) {
    input_sym_tensor_->setupMulticast(
        /*exporter_rank=*/0, key_prefix + "_input_mcast");
  }

  // One semaphore slot per rank, every slot starting at kIdle.
  at::Tensor semaphore_slots = SymmetricTensor::allocate(
      at::IntArrayRef({num_ranks}),
      at::ScalarType::Int,
      output_buffer.device());
  const std::vector<IpcSemaphore> idle_slots(num_ranks, IpcSemaphore::kIdle);
  NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
      semaphore_slots.data_ptr(),
      idle_slots.data(),
      num_ranks * sizeof(IpcSemaphore),
      cudaMemcpyHostToDevice));
  semaphores_sym_tensor_ = std::make_unique<SymmetricTensor>(semaphore_slots);
  semaphores_sym_tensor_->setupRemoteHandles(
      key_prefix + "_semaphores_unicast");
  if (needs_multicast) {
    semaphores_sym_tensor_->setupMulticast(
        /*exporter_rank=*/0, key_prefix + "_semaphores_mcast");
  }
}
// This rank's local tensor of the symmetric input buffer allocated in the
// constructor.
at::Tensor SymmetricMemoryForAllreduce::inputBuffer() const {
  auto& input = *input_sym_tensor_;
  return input.localTensor();
}
// Multicast address of the symmetric input buffer. Multicast is set up only
// for the Memcpy and Multimem protocols.
void* SymmetricMemoryForAllreduce::multicastPtr() const {
  void* const multicast_address = input_sym_tensor_->multicastPtr();
  return multicast_address;
}
// Address of slot `root_rank` inside the semaphore tensor returned by
// remoteTensor(rank). The tensor holds one IpcSemaphore per rank.
void* SymmetricMemoryForAllreduce::semaphoreUnicastPtr(
    int64_t root_rank,
    int64_t rank) const {
  void* const slots = semaphores_sym_tensor_->remoteTensor(rank).data_ptr();
  const auto slot_offset = root_rank * sizeof(IpcSemaphore);
  return static_cast<uint8_t*>(slots) + slot_offset;
}
// Sets up symmetric memory for a reduce rooted at `root`.
//
// - `buffer` is wrapped as a SymmetricTensor (the caller is assumed to have
//   allocated it symmetrically) and its unicast handles are exchanged.
// - For the Memcpy and Multimem protocols, multicast is set up on it with
//   `root` as the exporter.
// - A one-element Int semaphore, initialised to kIdle, gets unicast handles
//   only.
// size_bytes_ records the buffer's size in bytes.
SymmetricMemoryForReduce::SymmetricMemoryForReduce(
    Communication* communication,
    int64_t root,
    at::Tensor buffer)
    : size_bytes_(buffer.numel() * buffer.element_size()) {
  const std::string key_prefix = std::string("nvls_reduce_") +
      "for_Communication" + std::to_string(communication->name());

  input_sym_tensor_ = std::make_unique<SymmetricTensor>(buffer);
  input_sym_tensor_->setupRemoteHandles(key_prefix + "_input_unicast");
  const MulticastProtocol protocol = getMulticastProtocol();
  const bool needs_multicast = protocol == MulticastProtocol::Memcpy ||
      protocol == MulticastProtocol::Multimem;
  if (needs_multicast) {
    input_sym_tensor_->setupMulticast(root, key_prefix + "_input_mcast");
  }

  // Single semaphore starting at kIdle.
  at::Tensor semaphore = SymmetricTensor::allocate(
      /*sizes=*/at::IntArrayRef({1}),
      /*dtype=*/at::ScalarType::Int,
      /*device=*/buffer.device());
  const IpcSemaphore idle = IpcSemaphore::kIdle;
  NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
      semaphore.data_ptr(), &idle, sizeof(IpcSemaphore), cudaMemcpyHostToDevice));
  semaphore_sym_tensor_ = std::make_unique<SymmetricTensor>(semaphore);
  semaphore_sym_tensor_->setupRemoteHandles(key_prefix + "_semaphore");
}
// This rank's local tensor of the symmetric input buffer wrapped in the
// constructor.
at::Tensor SymmetricMemoryForReduce::inputBuffer() const {
  auto& input = *input_sym_tensor_;
  return input.localTensor();
}
// Multicast address of the input buffer. Multicast is set up only for the
// Memcpy and Multimem protocols.
void* SymmetricMemoryForReduce::multicastPtr() const {
  void* const multicast_address = input_sym_tensor_->multicastPtr();
  return multicast_address;
}
// Data pointer of remoteTensor(rank) for the reduce semaphore (presumably
// `rank`'s semaphore as mapped into this process — confirm in SymmetricTensor).
void* SymmetricMemoryForReduce::semaphoreUnicastPtr(int64_t rank) const {
  const auto& rank_view = semaphore_sym_tensor_->remoteTensor(rank);
  return rank_view.data_ptr();
}
// Sets up symmetric memory for an allgather into `buffer`.
//
// `buffer` is the full output, viewed as `world_size` equal contiguous slices
// (one per rank) of slice_size_bytes_ bytes each. Unicast handles, and for the
// Memcpy and Multimem protocols multicast (rank 0 as exporter), are set up on
// the full buffer and on a world-size Int tensor of semaphores initialised to
// kIdle.
//
// Throws (NVF_ERROR) if buffer.numel() is not a multiple of the world size.
SymMemForAllgather::SymMemForAllgather(
    Communication* communication,
    at::Tensor buffer) {
  Communicator& communicator = Communicator::getInstance();
  const int64_t world_size = communicator.size();

  // Every per-rank slice address is derived from slice_size_bytes_. A
  // remainder would silently truncate the slice and misplace all offsets, so
  // reject it up front, before any collective handle setup.
  NVF_ERROR(
      buffer.numel() % world_size == 0,
      "Allgather buffer with ",
      buffer.numel(),
      " elements cannot be split evenly across ",
      world_size,
      " ranks");
  const int64_t slice_numel = buffer.numel() / world_size;
  slice_size_bytes_ = slice_numel * buffer.element_size();

  // Handles are set up on the full buffer rather than per slice, because
  // setupRemoteHandles requires a VMM-aligned allocation, which slices are not.
  full_buffer_sym_tensor_ = std::make_unique<SymmetricTensor>(buffer);
  const std::string full_buffer_suffix =
      std::to_string(communication->name()) + "_allgather_full";
  full_buffer_sym_tensor_->setupRemoteHandles(
      "nvls_export_mcast_handle_" + full_buffer_suffix + "_buffer_unicast");

  const MulticastProtocol protocol = getMulticastProtocol();
  const bool use_multicast = protocol == MulticastProtocol::Memcpy ||
      protocol == MulticastProtocol::Multimem;
  if (use_multicast) {
    full_buffer_sym_tensor_->setupMulticast(
        /*exporter_rank=*/0,
        "nvls_export_mcast_handle_" + full_buffer_suffix + "_buffer_mcast");
  }

  // One semaphore per rank in a single symmetric tensor, all starting at kIdle.
  at::Tensor semaphores = SymmetricTensor::allocate(
      at::IntArrayRef({world_size}), at::ScalarType::Int, buffer.device());
  std::vector<IpcSemaphore> init_values(world_size, IpcSemaphore::kIdle);
  NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
      semaphores.data_ptr(),
      init_values.data(),
      world_size * sizeof(IpcSemaphore),
      cudaMemcpyHostToDevice));
  semaphores_sym_tensor_ = std::make_unique<SymmetricTensor>(semaphores);
  semaphores_sym_tensor_->setupRemoteHandles(
      "nvls_export_mcast_handle_" + full_buffer_suffix + "_semaphores_unicast");
  if (use_multicast) {
    semaphores_sym_tensor_->setupMulticast(
        /*exporter_rank=*/0,
        "nvls_export_mcast_handle_" + full_buffer_suffix + "_semaphores_mcast");
  }
}
// Multicast address of `root_rank`'s slice of the full allgather buffer.
void* SymMemForAllgather::bufferMulticastPtr(int64_t root_rank) const {
  void* const full_buffer = full_buffer_sym_tensor_->multicastPtr();
  const auto slice_offset = root_rank * slice_size_bytes_;
  return static_cast<uint8_t*>(full_buffer) + slice_offset;
}
// Address of `root_rank`'s slice inside the full buffer returned by
// remoteTensor(rank).
void* SymMemForAllgather::bufferUnicastPtr(int64_t root_rank, int64_t rank)
    const {
  void* const full_buffer =
      full_buffer_sym_tensor_->remoteTensor(rank).data_ptr();
  const auto slice_offset = root_rank * slice_size_bytes_;
  return static_cast<uint8_t*>(full_buffer) + slice_offset;
}
// Multicast address of semaphore slot `root_rank` (one IpcSemaphore per rank).
void* SymMemForAllgather::semaphoreMulticastPtr(int64_t root_rank) const {
  void* const slots = semaphores_sym_tensor_->multicastPtr();
  const auto slot_offset = root_rank * sizeof(IpcSemaphore);
  return static_cast<uint8_t*>(slots) + slot_offset;
}
// Address of semaphore slot `root_rank` inside the semaphore tensor returned
// by remoteTensor(rank).
void* SymMemForAllgather::semaphoreUnicastPtr(int64_t root_rank, int64_t rank)
    const {
  void* const slots = semaphores_sym_tensor_->remoteTensor(rank).data_ptr();
  const auto slot_offset = root_rank * sizeof(IpcSemaphore);
  return static_cast<uint8_t*>(slots) + slot_offset;
}
// Returns the cached symmetric-memory handle for `key`, creating and caching
// one on first use.
//
// The handle type is chosen from key.expr: a SymmetricContiguousView, or a
// Communication of type Broadcast, Allgather, Allreduce or Reduce. Any other
// expr or communication type raises NVF_ERROR. The returned pointer is owned
// by this cache.
SymmetricMemoryHandle* SymmetricMemoryHandleCache::get(KeyType key) {
  const auto existing = handles_.find(key);
  if (existing != handles_.end()) {
    return existing->second.get();
  }

  std::unique_ptr<SymmetricMemoryHandle> handle;
  if (auto* contig_view =
          dynamic_cast<hir::SymmetricContiguousView*>(key.expr)) {
    handle = std::make_unique<SymMemForContiguousView>(key.buffer, contig_view);
  } else if (auto* comm = dynamic_cast<Communication*>(key.expr)) {
    switch (comm->type()) {
      case CommunicationType::Broadcast:
        handle =
            std::make_unique<SymMemForBroadcast>(comm, key.root, key.buffer);
        break;
      case CommunicationType::Allgather:
        handle = std::make_unique<SymMemForAllgather>(comm, key.buffer);
        break;
      case CommunicationType::Allreduce:
        handle =
            std::make_unique<SymmetricMemoryForAllreduce>(comm, key.buffer);
        break;
      case CommunicationType::Reduce:
        handle = std::make_unique<SymmetricMemoryForReduce>(
            comm, key.root, key.buffer);
        break;
      default:
        NVF_ERROR(
            false,
            "Unsupported communication type for multicast handle: ",
            comm->type());
    }
  } else {
    NVF_ERROR(
        false, "Unsupported expr type for symmetric memory handle: ", key.expr);
  }

  auto inserted = handles_.emplace(key, std::move(handle));
  return inserted.first->second.get();
}
// Wraps `in_tensor` as a SymmetricTensor, exchanges its remote handles, sets
// up a contiguous view over them, and stores that view in tensor_. Store keys
// are derived from "contig_view_" + the view expression's name.
SymMemForContiguousView::SymMemForContiguousView(
    at::Tensor in_tensor,
    hir::SymmetricContiguousView* contig_view) {
  const std::string tag =
      std::string("contig_view_") + std::to_string(contig_view->name());
  sym_tensor_ = std::make_unique<SymmetricTensor>(in_tensor);
  // Remote handles are exchanged before the view is set up (presumably the
  // view is built from them — confirm in SymmetricTensor).
  sym_tensor_->setupRemoteHandles(tag + "_remote_handles");
  sym_tensor_->setupContiguousView(tag);
  tensor_ = sym_tensor_->getContiguousView();
}
} // namespace nvfuser