@@ -43,21 +43,24 @@ struct NVSHMEMAllocation {
43
43
}
44
44
};
45
45
46
- class NVSHMEMSymmetricMemory : public SymmetricMemory {
46
+ // A class to hold the base pointers and signal pad pointers for a group of
47
+ // peers. One `NVSHMEMPeerAllocInfo` object can be shared by multiple
48
+ // `NVSHMEMSymmetricMemory` objects when latter reside on the same allocation
49
+ // and rendezvous over the same group. (The `NVSHMEMSymmetricMemory` objects may
50
+ // have different offsets compared to the base address.)
51
+ class NVSHMEMPeerAllocInfo : public c10 ::intrusive_ptr_target {
47
52
public:
48
- NVSHMEMSymmetricMemory (
53
+ NVSHMEMPeerAllocInfo (
49
54
std::shared_ptr<NVSHMEMAllocation> allocation,
50
55
const std::string& group_name)
51
- : allocation_(allocation),
52
- buffer_size_ (allocation->buffer_size),
53
- device_idx_(allocation->device_idx),
54
- group_name_(group_name) {
56
+ : base_ptr_(allocation->ptr),
57
+ buffer_size_ (allocation->buffer_size) {
55
58
// For logging only
56
59
static int exchanged_n_times = 0 ;
57
- c10::cuda::CUDAGuard guard (device_idx_ );
60
+ c10::cuda::CUDAGuard guard (allocation-> device_idx );
58
61
59
62
auto global_rank = get_group_info (" 0" ).rank ;
60
- GroupInfo& group_info = get_group_info (group_name_ );
63
+ GroupInfo& group_info = get_group_info (group_name );
61
64
auto store = group_info.store ;
62
65
rank_ = group_info.rank ;
63
66
world_size_ = group_info.world_size ;
@@ -70,15 +73,15 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
70
73
if (rank_ == 0 ) {
71
74
LOG (INFO) << " [rank " << rank_ << " ]"
72
75
<< " rank_to_global_rank: " << group_info.rank_to_global_rank
73
- << " , group_name: " << group_name_
76
+ << " , group_name: " << group_name
74
77
<< " , exchanged_n_times: " << exchanged_n_times;
75
78
}
76
79
}
77
80
TORCH_INTERNAL_ASSERT (!group_info.rank_to_global_rank .empty ());
78
81
rank_to_global_rank_ = group_info.rank_to_global_rank ;
79
82
for (int r = 0 ; r < world_size_; ++r) {
80
83
buffers_.push_back (nvshmem_ptr (
81
- allocation-> ptr , rank_to_global_rank_[r]));
84
+ base_ptr_ , rank_to_global_rank_[r]));
82
85
}
83
86
84
87
// TODO: use the same allocation for signal pad
@@ -114,28 +117,68 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
114
117
cudaMemcpyHostToDevice));
115
118
}
116
119
120
+ private:
121
+ void * base_ptr_;
122
+ size_t buffer_size_;
123
+ int rank_;
124
+ int world_size_;
125
+ std::vector<void *> buffers_;
126
+ std::vector<void *> signal_pads_;
127
+ void ** buffers_dev_;
128
+ void ** signal_pads_dev_;
129
+ std::vector<int > rank_to_global_rank_;
130
+ int * rank_to_global_rank_dev_;
131
+
132
+ friend class NVSHMEMSymmetricMemory ;
133
+ };
134
+
135
+ class NVSHMEMSymmetricMemory : public SymmetricMemory {
136
+ public:
137
+ NVSHMEMSymmetricMemory (
138
+ std::shared_ptr<NVSHMEMAllocation> allocation,
139
+ const std::string& group_name)
140
+ : allocation_(allocation),
141
+ device_idx_ (allocation->device_idx),
142
+ group_name_(group_name) {
143
+ // A handle stores two types of info:
144
+ // (i) allocation's base ptrs and base signal pads, ours and peers'
145
+ pai_ = c10::make_intrusive<NVSHMEMPeerAllocInfo>(allocation, group_name);
146
+ // (ii) offset of tensor compared to base ptr (in byte)
147
+ offset_ = 0 ;
148
+ }
149
+
150
+ // Exact copy is not needed / supported
151
+ NVSHMEMSymmetricMemory (const NVSHMEMSymmetricMemory& other) = delete;
152
+
153
+ // Copy with offset is allowed
154
+ // This is mostly a shallow copy that shares the pointer to `NVSHMEMPeerAllocInfo` which has been created by `other`
155
+ NVSHMEMSymmetricMemory (const NVSHMEMSymmetricMemory& other, size_t offset)
156
+ : allocation_(other.allocation_), device_idx_(other.device_idx_), group_name_(other.group_name_), pai_(other.pai_) {
157
+ offset_ = offset;
158
+ }
159
+
117
160
~NVSHMEMSymmetricMemory () override {
118
161
// TODO
119
162
};
120
163
121
164
std::vector<void *> get_buffer_ptrs () override {
122
- return buffers_;
165
+ return pai_-> buffers_ ;
123
166
}
124
167
125
168
std::vector<void *> get_signal_pad_ptrs () override {
126
- return signal_pads_;
169
+ return pai_-> signal_pads_ ;
127
170
}
128
171
129
172
void ** get_buffer_ptrs_dev () override {
130
- return buffers_dev_;
173
+ return pai_-> buffers_dev_ ;
131
174
}
132
175
133
176
void ** get_signal_pad_ptrs_dev () override {
134
- return signal_pads_dev_;
177
+ return pai_-> signal_pads_dev_ ;
135
178
}
136
179
137
180
size_t get_buffer_size () override {
138
- return buffer_size_;
181
+ return pai_-> buffer_size_ ;
139
182
}
140
183
141
184
size_t get_signal_pad_size () override {
@@ -166,13 +209,13 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
166
209
const auto element_size = c10::elementSize (dtype);
167
210
const auto req_size = (numel + storage_offset) * element_size;
168
211
TORCH_CHECK (
169
- req_size <= buffer_size_ ,
212
+ req_size <= allocation_-> buffer_size ,
170
213
" NVSHMEMSymmetricMemory::get_buffer: the requested size (" ,
171
214
req_size,
172
215
" bytes) exceeds the allocated size (" ,
173
- buffer_size_ ,
216
+ allocation_-> buffer_size ,
174
217
" bytes)" );
175
- auto data_ptr = reinterpret_cast <uint8_t *>(buffers_[rank]) +
218
+ auto data_ptr = reinterpret_cast <uint8_t *>(pai_-> buffers_ [rank]) +
176
219
storage_offset * element_size;
177
220
auto device = c10::Device (c10::DeviceType::CUDA, device_idx_);
178
221
auto options = at::TensorOptions ().dtype (dtype).device (device);
@@ -216,7 +259,7 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
216
259
" bytes) exceeds the allocated size (" ,
217
260
signal_pad_size,
218
261
" bytes)" );
219
- auto data_ptr = reinterpret_cast <uint8_t *>(signal_pads_[rank]) +
262
+ auto data_ptr = reinterpret_cast <uint8_t *>(pai_-> signal_pads_ [rank]) +
220
263
storage_offset * element_size;
221
264
auto device = c10::Device (c10::DeviceType::CUDA, device_idx_);
222
265
auto options = at::TensorOptions ().dtype (*dtype).device (device);
@@ -239,35 +282,27 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
239
282
}
240
283
241
284
int get_rank () override {
242
- return rank_;
285
+ return pai_-> rank_ ;
243
286
}
244
287
245
288
int get_world_size () override {
246
- return world_size_;
289
+ return pai_-> world_size_ ;
247
290
}
248
291
249
- virtual const std::vector<int >& get_rank_to_global_rank () override {
250
- return rank_to_global_rank_;
292
+ const std::vector<int >& get_rank_to_global_rank () override {
293
+ return pai_-> rank_to_global_rank_ ;
251
294
};
252
295
253
296
int * get_rank_to_global_rank_dev () override {
254
- return rank_to_global_rank_dev_;
297
+ return pai_-> rank_to_global_rank_dev_ ;
255
298
};
256
299
257
300
private:
258
301
std::shared_ptr<NVSHMEMAllocation> allocation_;
259
- size_t buffer_size_;
260
- std::vector<void *> buffers_;
261
- std::vector<void *> signal_pads_;
262
302
int device_idx_;
263
- int rank_;
264
- int world_size_;
265
- void ** buffers_dev_;
266
- void ** signal_pads_dev_;
267
303
std::string group_name_;
268
-
269
- std::vector<int > rank_to_global_rank_;
270
- int * rank_to_global_rank_dev_;
304
+ c10::intrusive_ptr<NVSHMEMPeerAllocInfo> pai_;
305
+ size_t offset_{0 }; // in byte
271
306
};
272
307
273
308
// Bootstrap based on user's setting for NCCL
@@ -379,13 +414,48 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
379
414
return it->second ;
380
415
}
381
416
}
382
- auto it = allocations_.find (ptr);
383
- TORCH_CHECK (it != allocations_.end ());
384
- auto symm_mem =
385
- c10::make_intrusive<NVSHMEMSymmetricMemory>(it->second , *group_name);
417
+ // In case of MemPool, tensor.storage().data_ptr() may not match
418
+ // exactly an allocation's base address. Thus we perform the search by
419
+ // testing if the former is within an allocation's range.
420
+ auto alloc_it = std::find_if (allocations_.begin (), allocations_.end (),
421
+ [&](const auto & pair){
422
+ auto & allocation = pair.second ;
423
+ auto ptr_int = reinterpret_cast <uintptr_t >(ptr);
424
+ auto base_ptr = reinterpret_cast <uintptr_t >(allocation->ptr );
425
+ return ptr_int >= base_ptr && ptr_int < base_ptr + allocation->buffer_size ; });
426
+ TORCH_CHECK (alloc_it != allocations_.end (),
427
+ " Pointer not within any SymmetricMemory allocation, "
428
+ " is the tensor allocated from SymmetricMemory?" );
429
+
430
+ auto & allocation = alloc_it->second ;
431
+
432
+ // Search again using allocation base ptr (which is the key we use for caching, see below)
433
+ auto it = symm_mems_.find (std::make_tuple (allocation->ptr , *group_name));
434
+ c10::intrusive_ptr<NVSHMEMSymmetricMemory> symm_mem;
435
+ if (it != symm_mems_.end ()) {
436
+ // Base allocation has been rendezvoused
437
+ symm_mem = it->second ;
438
+ } else {
439
+ // Create a new rendezvous
440
+ symm_mem =
441
+ c10::make_intrusive<NVSHMEMSymmetricMemory>(allocation, *group_name);
442
+ }
443
+
444
+ // Cache rendezvous using allocation's base address as key
445
+ symm_mems_[std::make_tuple (allocation->ptr , *group_name)] = symm_mem;
386
446
387
- symm_mems_[std::make_tuple (ptr, *group_name)] = symm_mem;
388
- return symm_mem;
447
+ // TODO: change the `ptr` below to `tensor.data_ptr()` when adding support
448
+ // for user slice/view operations. For MemPool support,
449
+ // `tensor.storate().data_ptr()` is fine (today's `ptr`).
450
+
451
+ // If the tensor's ptr happen to be the same as allocation ptr
452
+ if (ptr == allocation->ptr ) {
453
+ return symm_mem;
454
+ } else {
455
+ // Return a copy of the SymmetricMemory with an offset. This is a
456
+ // "shallow" copy adjusting the offset field in the handle.
457
+ return c10::make_intrusive<NVSHMEMSymmetricMemory>(*symm_mem, (uintptr_t )ptr - (uintptr_t )allocation->ptr );
458
+ }
389
459
};
390
460
391
461
bool has_multicast_support (int device_idx) override {
@@ -403,7 +473,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
403
473
404
474
private:
405
475
std::unordered_map<void *, std::shared_ptr<NVSHMEMAllocation>> allocations_;
406
- std::map<std::tuple<void *, std::string>, c10::intrusive_ptr<SymmetricMemory >>
476
+ std::map<std::tuple<void *, std::string>, c10::intrusive_ptr<NVSHMEMSymmetricMemory >>
407
477
symm_mems_;
408
478
};
409
479
0 commit comments