Skip to content

Commit 68d395d

Browse files
kwen2501pytorchmergebot
authored andcommitted
[3/N][SymmMem] Expose offset field from handle (pytorch#161532)
As titled, so that kernels relying on direct pointers can use base address and `hdl.offset` to access remote memory. Pull Request resolved: pytorch#161532 Approved by: https://github.com/ngimel ghstack dependencies: pytorch#161470, pytorch#161471
1 parent 4ed71d5 commit 68d395d

File tree

4 files changed

+32
-0
lines changed

4 files changed

+32
-0
lines changed

test/distributed/test_nvshmem.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,29 @@ def test_mempool_compute_ops(self) -> None:
117117
expected = torch.mm(x0, w)
118118
self.assertEqual(y, expected)
119119

120+
@skipIfRocm
121+
def test_handle_offset(self) -> None:
122+
"""
123+
Test if handle offset is correctly set.
124+
"""
125+
self._init_device()
126+
group_name = dist.group.WORLD.group_name
127+
symm_mem.enable_symm_mem_for_group(group_name)
128+
129+
dtype = torch.float
130+
numel = 1024
131+
allocator = symm_mem.get_mempool_allocator(self.device)
132+
mempool = torch.cuda.MemPool(allocator)
133+
134+
with torch.cuda.use_mem_pool(mempool):
135+
x0 = torch.empty(numel, dtype=dtype, device=self.device)
136+
x1 = torch.empty_like(x0)
137+
138+
hdl0 = symm_mem.rendezvous(x0, group=group_name)
139+
hdl1 = symm_mem.rendezvous(x1, group=group_name)
140+
self.assertEqual(hdl0.offset, 0)
141+
self.assertEqual(hdl1.offset, x0.untyped_storage().nbytes())
142+
120143
@skipIfRocm
121144
def test_nvshmem_put(self) -> None:
122145
self._init_device()

torch/csrc/distributed/c10d/init.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,6 +1170,7 @@ This class does not support ``__members__`` property.)");
11701170
.def_property_readonly("buffer_size", &SymmetricMemory::get_buffer_size)
11711171
.def_property_readonly(
11721172
"signal_pad_size", &SymmetricMemory::get_signal_pad_size)
1173+
.def_property_readonly("offset", &SymmetricMemory::get_offset)
11731174
.def(
11741175
"get_buffer",
11751176
&SymmetricMemory::get_buffer,

torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
195195
return nullptr;
196196
}
197197

198+
size_t get_offset() override {
199+
return offset_;
200+
}
201+
198202
at::Tensor get_buffer(
199203
int rank,
200204
c10::IntArrayRef sizes,

torch/csrc/distributed/c10d/symm_mem/SymmetricMemory.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
5050
virtual size_t get_buffer_size() = 0;
5151
virtual size_t get_signal_pad_size() = 0;
5252

53+
virtual size_t get_offset() {
54+
TORCH_CHECK(false, "NYI");
55+
}
56+
5357
virtual bool has_multicast_support() = 0;
5458
virtual void* get_multicast_ptr() = 0;
5559

0 commit comments

Comments
 (0)