@@ -60,8 +60,7 @@ def worker_fn():
                                      device=get_world_group().device)
     tensor = torch.ones(16, 1024, 1024,
                         dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        tensor = pynccl_comm.all_reduce(tensor)
+    tensor = pynccl_comm.all_reduce(tensor)
     torch.cuda.synchronize()
     assert torch.all(tensor == pynccl_comm.world_size).cpu().item()
 
@@ -82,17 +81,16 @@ def multiple_allreduce_worker_fn():
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with pynccl_comm.change_state(enable=True):
-        # two groups can communicate independently
-        if torch.distributed.get_rank() in [0, 1]:
-            tensor = pynccl_comm.all_reduce(tensor)
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 4).cpu().item()
-        else:
-            tensor = pynccl_comm.all_reduce(tensor)
-            torch.cuda.synchronize()
-            assert torch.all(tensor == 2).cpu().item()
+    # two groups can communicate independently
+    if torch.distributed.get_rank() in [0, 1]:
+        tensor = pynccl_comm.all_reduce(tensor)
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 4).cpu().item()
+    else:
+        tensor = pynccl_comm.all_reduce(tensor)
+        torch.cuda.synchronize()
+        assert torch.all(tensor == 2).cpu().item()
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
@@ -138,9 +136,7 @@ def worker_fn_with_cudagraph():
     # run something in the default stream to initialize torch engine
     a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
     torch.cuda.synchronize()
-    with torch.cuda.graph(
-            graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
-                enable=True):
+    with torch.cuda.graph(graph):
         a_out = pynccl_comm.all_reduce(a)
     torch.cuda.synchronize()
     graph.replay()
@@ -169,8 +165,7 @@ def all_gather_worker_fn():
         for r in range(world_size)
     ]).to(device)
 
-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.all_gather(result, tensor)
+    pynccl_comm.all_gather(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -207,8 +202,7 @@ def reduce_scatter_worker_fn():
     expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size]
                    for tensor in all_tensors).to(device)
 
-    with pynccl_comm.change_state(enable=True):
-        pynccl_comm.reduce_scatter(result, tensor)
+    pynccl_comm.reduce_scatter(result, tensor)
     torch.cuda.synchronize()
     torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)
 
@@ -235,15 +229,13 @@ def send_recv_worker_fn():
     else:
         tensor = torch.empty(16, 1024, 1024,
                              dtype=torch.float32).cuda(pynccl_comm.rank)
-    with pynccl_comm.change_state(enable=True):
-        if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+
+    if pynccl_comm.rank == 0:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     assert torch.all(tensor == 1).cpu().item()
 
@@ -274,15 +266,12 @@ def multiple_send_recv_worker_fn():
                         1024,
                         dtype=torch.float32,
                         device=device)
-    with pynccl_comm.change_state(enable=True):
-        if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor,
-                             dst=(pynccl_comm.rank + 1) %
-                             pynccl_comm.world_size)
-        else:
-            pynccl_comm.recv(tensor,
-                             src=(pynccl_comm.rank - 1) %
-                             pynccl_comm.world_size)
+    if torch.distributed.get_rank() in [0, 1]:
+        pynccl_comm.send(tensor,
+                         dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size)
+    else:
+        pynccl_comm.recv(tensor,
+                         src=(pynccl_comm.rank - 1) % pynccl_comm.world_size)
     torch.cuda.synchronize()
     if torch.distributed.get_rank() in [0, 2]:
         assert torch.all(tensor == 1).cpu().item()