
Commit 8dd5aa9

kwen2501 authored and pytorchmergebot committed
[1/N][SymmMem] Add offset to handle, cache on base address (pytorch#161470)
For the kernels that need peer pointers directly, the rendezvous handle should allow user to get the offset of tensor wrt to base allocation address. Thus the need to add an `offset` field to SymmMem handle. But we don't want to cache all the handles just bc they have different offsets, hence the search and cache logic below: (i) At rendezvous, the search key is still `x.storage().data_ptr()`, like now, but it should do search in 2 parts - one is just dictionary lookup, like today, if that failed, it needs to search `allocations_` to see if the storage ptr falls in one of the segments. This is possible as we have all segments recorded during alloc. (ii) If this segment hasn't been rendezvoused, we rendezvous it, cache it in the `symm_mem_` map with its base address as key. (iii) We still need to return a handle for the current tensor, with a corresponding offset. This handle will be a shallow copy of the base handle, with the offset adjusted. Some impl details: (i.1) If we find a matching allocation, we can immediately use the allocation base address to do a re-search in `symm_mem_`. (iii.1) To make the handle copy shallow, we move the common information -- base ptrs, base signal pad, etc -- to a structure referenced by both handles. The structure is called `NVSHMEMPeerAllocInfo`. A copy of handle just adds one more `intrusive_ptr` to it. The handle copy constructor accepts an `offset` argument. Test: Existing tests should not fail. Pull Request resolved: pytorch#161470 Approved by: https://github.com/ngimel
1 parent 8ff9485 commit 8dd5aa9

File tree

1 file changed (+111, -41 lines)

torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu

Lines changed: 111 additions & 41 deletions
```diff
@@ -43,21 +43,24 @@ struct NVSHMEMAllocation {
   }
 };
 
-class NVSHMEMSymmetricMemory : public SymmetricMemory {
+// A class to hold the base pointers and signal pad pointers for a group of
+// peers. One `NVSHMEMPeerAllocInfo` object can be shared by multiple
+// `NVSHMEMSymmetricMemory` objects when the latter reside on the same
+// allocation and rendezvous over the same group. (The `NVSHMEMSymmetricMemory`
+// objects may have different offsets compared to the base address.)
+class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target {
  public:
-  NVSHMEMSymmetricMemory(
+  NVSHMEMPeerAllocInfo(
       std::shared_ptr<NVSHMEMAllocation> allocation,
       const std::string& group_name)
-      : allocation_(allocation),
-        buffer_size_(allocation->buffer_size),
-        device_idx_(allocation->device_idx),
-        group_name_(group_name) {
+      : base_ptr_(allocation->ptr),
+        buffer_size_(allocation->buffer_size) {
     // For logging only
     static int exchanged_n_times = 0;
-    c10::cuda::CUDAGuard guard(device_idx_);
+    c10::cuda::CUDAGuard guard(allocation->device_idx);
 
     auto global_rank = get_group_info("0").rank;
-    GroupInfo& group_info = get_group_info(group_name_);
+    GroupInfo& group_info = get_group_info(group_name);
     auto store = group_info.store;
     rank_ = group_info.rank;
     world_size_ = group_info.world_size;
@@ -70,15 +73,15 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
     if (rank_ == 0) {
       LOG(INFO) << "[rank " << rank_ << "]"
                 << " rank_to_global_rank: " << group_info.rank_to_global_rank
-                << ", group_name: " << group_name_
+                << ", group_name: " << group_name
                 << ", exchanged_n_times: " << exchanged_n_times;
     }
   }
   TORCH_INTERNAL_ASSERT(!group_info.rank_to_global_rank.empty());
   rank_to_global_rank_ = group_info.rank_to_global_rank;
   for (int r = 0; r < world_size_; ++r) {
     buffers_.push_back(nvshmem_ptr(
-        allocation->ptr, rank_to_global_rank_[r]));
+        base_ptr_, rank_to_global_rank_[r]));
   }
 
   // TODO: use the same allocation for signal pad
@@ -114,28 +117,68 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
         cudaMemcpyHostToDevice));
   }
 
+ private:
+  void* base_ptr_;
+  size_t buffer_size_;
+  int rank_;
+  int world_size_;
+  std::vector<void*> buffers_;
+  std::vector<void*> signal_pads_;
+  void** buffers_dev_;
+  void** signal_pads_dev_;
+  std::vector<int> rank_to_global_rank_;
+  int* rank_to_global_rank_dev_;
+
+  friend class NVSHMEMSymmetricMemory;
+};
+
+class NVSHMEMSymmetricMemory : public SymmetricMemory {
+ public:
+  NVSHMEMSymmetricMemory(
+      std::shared_ptr<NVSHMEMAllocation> allocation,
+      const std::string& group_name)
+      : allocation_(allocation),
+        device_idx_(allocation->device_idx),
+        group_name_(group_name) {
+    // A handle stores two types of info:
+    // (i) allocation's base ptrs and base signal pads, ours and peers'
+    pai_ = c10::make_intrusive<NVSHMEMPeerAllocInfo>(allocation, group_name);
+    // (ii) offset of tensor compared to base ptr (in bytes)
+    offset_ = 0;
+  }
+
+  // Exact copy is not needed / supported
+  NVSHMEMSymmetricMemory(const NVSHMEMSymmetricMemory& other) = delete;
+
+  // Copy with offset is allowed. This is mostly a shallow copy that shares
+  // the pointer to the `NVSHMEMPeerAllocInfo` created by `other`.
+  NVSHMEMSymmetricMemory(const NVSHMEMSymmetricMemory& other, size_t offset)
+      : allocation_(other.allocation_), device_idx_(other.device_idx_), group_name_(other.group_name_), pai_(other.pai_) {
+    offset_ = offset;
+  }
+
   ~NVSHMEMSymmetricMemory() override{
       // TODO
   };
 
   std::vector<void*> get_buffer_ptrs() override {
-    return buffers_;
+    return pai_->buffers_;
   }
 
   std::vector<void*> get_signal_pad_ptrs() override {
-    return signal_pads_;
+    return pai_->signal_pads_;
   }
 
   void** get_buffer_ptrs_dev() override {
-    return buffers_dev_;
+    return pai_->buffers_dev_;
   }
 
   void** get_signal_pad_ptrs_dev() override {
-    return signal_pads_dev_;
+    return pai_->signal_pads_dev_;
   }
 
   size_t get_buffer_size() override {
-    return buffer_size_;
+    return pai_->buffer_size_;
   }
 
   size_t get_signal_pad_size() override {
@@ -166,13 +209,13 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
     const auto element_size = c10::elementSize(dtype);
     const auto req_size = (numel + storage_offset) * element_size;
     TORCH_CHECK(
-        req_size <= buffer_size_,
+        req_size <= allocation_->buffer_size,
         "NVSHMEMSymmetricMemory::get_buffer: the requested size (",
         req_size,
         " bytes) exceeds the allocated size (",
-        buffer_size_,
+        allocation_->buffer_size,
         " bytes)");
-    auto data_ptr = reinterpret_cast<uint8_t*>(buffers_[rank]) +
+    auto data_ptr = reinterpret_cast<uint8_t*>(pai_->buffers_[rank]) +
         storage_offset * element_size;
     auto device = c10::Device(c10::DeviceType::CUDA, device_idx_);
     auto options = at::TensorOptions().dtype(dtype).device(device);
@@ -216,7 +259,7 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
         " bytes) exceeds the allocated size (",
         signal_pad_size,
         " bytes)");
-    auto data_ptr = reinterpret_cast<uint8_t*>(signal_pads_[rank]) +
+    auto data_ptr = reinterpret_cast<uint8_t*>(pai_->signal_pads_[rank]) +
         storage_offset * element_size;
     auto device = c10::Device(c10::DeviceType::CUDA, device_idx_);
     auto options = at::TensorOptions().dtype(*dtype).device(device);
@@ -239,35 +282,27 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
   }
 
   int get_rank() override {
-    return rank_;
+    return pai_->rank_;
   }
 
   int get_world_size() override {
-    return world_size_;
+    return pai_->world_size_;
   }
 
-  virtual const std::vector<int>& get_rank_to_global_rank() override {
-    return rank_to_global_rank_;
+  const std::vector<int>& get_rank_to_global_rank() override {
+    return pai_->rank_to_global_rank_;
   };
 
   int* get_rank_to_global_rank_dev() override {
-    return rank_to_global_rank_dev_;
+    return pai_->rank_to_global_rank_dev_;
   };
 
  private:
   std::shared_ptr<NVSHMEMAllocation> allocation_;
-  size_t buffer_size_;
-  std::vector<void*> buffers_;
-  std::vector<void*> signal_pads_;
   int device_idx_;
-  int rank_;
-  int world_size_;
-  void** buffers_dev_;
-  void** signal_pads_dev_;
   std::string group_name_;
-
-  std::vector<int> rank_to_global_rank_;
-  int* rank_to_global_rank_dev_;
+  c10::intrusive_ptr<NVSHMEMPeerAllocInfo> pai_;
+  size_t offset_{0}; // in bytes
 };
 
 // Bootstrap based on user's setting for NCCL
@@ -379,13 +414,48 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
       return it->second;
     }
   }
-  auto it = allocations_.find(ptr);
-  TORCH_CHECK(it != allocations_.end());
-  auto symm_mem =
-      c10::make_intrusive<NVSHMEMSymmetricMemory>(it->second, *group_name);
+  // In case of MemPool, tensor.storage().data_ptr() may not match
+  // exactly an allocation's base address. Thus we perform the search by
+  // testing if the former is within an allocation's range.
+  auto alloc_it = std::find_if(allocations_.begin(), allocations_.end(),
+      [&](const auto& pair){
+        auto& allocation = pair.second;
+        auto ptr_int = reinterpret_cast<uintptr_t>(ptr);
+        auto base_ptr = reinterpret_cast<uintptr_t>(allocation->ptr);
+        return ptr_int >= base_ptr && ptr_int < base_ptr + allocation->buffer_size; });
+  TORCH_CHECK(alloc_it != allocations_.end(),
+      "Pointer not within any SymmetricMemory allocation, "
+      "is the tensor allocated from SymmetricMemory?");
+
+  auto& allocation = alloc_it->second;
+
+  // Search again using the allocation base ptr (which is the key we use for caching, see below)
+  auto it = symm_mems_.find(std::make_tuple(allocation->ptr, *group_name));
+  c10::intrusive_ptr<NVSHMEMSymmetricMemory> symm_mem;
+  if (it != symm_mems_.end()) {
+    // Base allocation has been rendezvoused
+    symm_mem = it->second;
+  } else {
+    // Create a new rendezvous
+    symm_mem =
+        c10::make_intrusive<NVSHMEMSymmetricMemory>(allocation, *group_name);
+  }
+
+  // Cache the rendezvous using the allocation's base address as key
+  symm_mems_[std::make_tuple(allocation->ptr, *group_name)] = symm_mem;
 
-  symm_mems_[std::make_tuple(ptr, *group_name)] = symm_mem;
-  return symm_mem;
+  // TODO: change the `ptr` below to `tensor.data_ptr()` when adding support
+  // for user slice/view operations. For MemPool support,
+  // `tensor.storage().data_ptr()` is fine (today's `ptr`).
+
+  // If the tensor's ptr happens to be the same as the allocation ptr
+  if (ptr == allocation->ptr) {
+    return symm_mem;
+  } else {
+    // Return a copy of the SymmetricMemory with an offset. This is a
+    // "shallow" copy adjusting the offset field in the handle.
+    return c10::make_intrusive<NVSHMEMSymmetricMemory>(*symm_mem, (uintptr_t)ptr - (uintptr_t)allocation->ptr);
+  }
 };
 
 bool has_multicast_support(int device_idx) override {
@@ -403,7 +473,7 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
 
  private:
   std::unordered_map<void*, std::shared_ptr<NVSHMEMAllocation>> allocations_;
-  std::map<std::tuple<void*, std::string>, c10::intrusive_ptr<SymmetricMemory>>
+  std::map<std::tuple<void*, std::string>, c10::intrusive_ptr<NVSHMEMSymmetricMemory>>
       symm_mems_;
 };
```

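Continuing the standalone sketch from above, a hypothetical usage showing the caching behavior the patch implements: two pointers into the same segment share one cached rendezvous and differ only in offset. The 1 MiB host array is a stand-in invented for illustration, not an actual NVSHMEM segment.

```cpp
// Usage of the sketch above (hypothetical host-memory stand-in for a segment).
int main() {
  static char segment[1 << 20]; // pretend this is one symmetric segment
  auto alloc = std::make_shared<Allocation>(Allocation{segment, sizeof(segment)});
  allocations_[segment] = alloc;

  auto h0 = rendezvous(segment, "0");        // base address -> offset 0
  auto h1 = rendezvous(segment + 4096, "0"); // mid-segment  -> offset 4096

  assert(h0->allocation == h1->allocation); // shared alloc info
  assert(h0->offset == 0 && h1->offset == 4096);
  assert(symm_mems_.size() == 1); // only one rendezvous was cached
  return 0;
}
```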