File tree Expand file tree Collapse file tree 4 files changed +32
-0
lines changed
torch/csrc/distributed/c10d Expand file tree Collapse file tree 4 files changed +32
-0
lines changed Original file line number Diff line number Diff line change @@ -117,6 +117,29 @@ def test_mempool_compute_ops(self) -> None:
117
117
expected = torch .mm (x0 , w )
118
118
self .assertEqual (y , expected )
119
119
120
+ @skipIfRocm
121
+ def test_handle_offset (self ) -> None :
122
+ """
123
+ Test if handle offset is correctly set.
124
+ """
125
+ self ._init_device ()
126
+ group_name = dist .group .WORLD .group_name
127
+ symm_mem .enable_symm_mem_for_group (group_name )
128
+
129
+ dtype = torch .float
130
+ numel = 1024
131
+ allocator = symm_mem .get_mempool_allocator (self .device )
132
+ mempool = torch .cuda .MemPool (allocator )
133
+
134
+ with torch .cuda .use_mem_pool (mempool ):
135
+ x0 = torch .empty (numel , dtype = dtype , device = self .device )
136
+ x1 = torch .empty_like (x0 )
137
+
138
+ hdl0 = symm_mem .rendezvous (x0 , group = group_name )
139
+ hdl1 = symm_mem .rendezvous (x1 , group = group_name )
140
+ self .assertEqual (hdl0 .offset , 0 )
141
+ self .assertEqual (hdl1 .offset , x0 .untyped_storage ().nbytes ())
142
+
120
143
@skipIfRocm
121
144
def test_nvshmem_put (self ) -> None :
122
145
self ._init_device ()
Original file line number Diff line number Diff line change @@ -1170,6 +1170,7 @@ This class does not support ``__members__`` property.)");
1170
1170
.def_property_readonly (" buffer_size" , &SymmetricMemory::get_buffer_size)
1171
1171
.def_property_readonly (
1172
1172
" signal_pad_size" , &SymmetricMemory::get_signal_pad_size)
1173
+ .def_property_readonly (" offset" , &SymmetricMemory::get_offset)
1173
1174
.def (
1174
1175
" get_buffer" ,
1175
1176
&SymmetricMemory::get_buffer,
Original file line number Diff line number Diff line change @@ -195,6 +195,10 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
195
195
return nullptr ;
196
196
}
197
197
198
+ size_t get_offset () override {
199
+ return offset_;
200
+ }
201
+
198
202
at::Tensor get_buffer (
199
203
int rank,
200
204
c10::IntArrayRef sizes,
Original file line number Diff line number Diff line change @@ -50,6 +50,10 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target {
50
50
virtual size_t get_buffer_size () = 0;
51
51
virtual size_t get_signal_pad_size () = 0;
52
52
53
+ virtual size_t get_offset () {
54
+ TORCH_CHECK (false , " NYI" );
55
+ }
56
+
53
57
virtual bool has_multicast_support () = 0;
54
58
virtual void * get_multicast_ptr () = 0;
55
59
You can’t perform that action at this time.
0 commit comments