@@ -55,22 +55,23 @@ def _write_remote_kernel(
5555@pytest .mark .parametrize ("barrier_type" , BARRIER_TYPES )
5656def test_barrier_basic (barrier_type , n ):
5757 shmem = iris .iris (1 << 20 )
58- shmem . barrier ( )
58+ _call_barrier ( shmem , barrier_type )
5959
6060 try :
6161 for _ in range (n ):
6262 _call_barrier (shmem , barrier_type )
6363 finally :
64- shmem . barrier ( )
64+ _call_barrier ( shmem , barrier_type )
6565 del shmem
6666 gc .collect ()
6767
6868
6969@pytest .mark .parametrize ("n" , [1 , 2 , 5 , 10 ])
70- def test_barrier_state_reuse (n ):
70+ @pytest .mark .parametrize ("barrier_type" , BARRIER_TYPES )
71+ def test_barrier_state_reuse (barrier_type , n ):
7172 """Verify device barrier reuses the same flags tensor across calls."""
7273 shmem = iris .iris (1 << 20 )
73- shmem . barrier ( )
74+ _call_barrier ( shmem , barrier_type )
7475
7576 try :
7677 shmem .device_barrier ()
@@ -82,7 +83,7 @@ def test_barrier_state_reuse(n):
8283 shmem .device_barrier ()
8384 assert shmem ._device_barrier_state [None ].data_ptr () == flags_ptr
8485 finally :
85- shmem . barrier ( )
86+ _call_barrier ( shmem , barrier_type )
8687 del shmem
8788 gc .collect ()
8889
@@ -161,13 +162,13 @@ def _cross_rank_graph(
161162 buf ,
162163 result ,
163164):
164- stream = torch .cuda .Stream ()
165+ capture_stream = torch .cuda .Stream ()
165166
166167 if op == "load" :
167168 buf .fill_ (float (rank ))
168169
169170 # Warmup on capture stream.
170- with torch .cuda .stream (stream ):
171+ with torch .cuda .stream (capture_stream ):
171172 for _ in range (num_barriers ):
172173 shmem .device_barrier ()
173174 _read_remote_kernel [(1 ,)](
@@ -180,11 +181,11 @@ def _cross_rank_graph(
180181 )
181182 for _ in range (num_barriers ):
182183 shmem .device_barrier ()
183- stream .synchronize ()
184+ capture_stream .synchronize ()
184185
185186 # Capture.
186187 graph = torch .cuda .CUDAGraph ()
187- with torch .cuda .graph (graph , stream = stream ):
188+ with torch .cuda .graph (graph , stream = capture_stream ):
188189 for _ in range (num_barriers ):
189190 shmem .device_barrier ()
190191 _read_remote_kernel [(1 ,)](
@@ -201,11 +202,11 @@ def _cross_rank_graph(
201202 # Replay with fresh data.
202203 for i in range (rounds ):
203204 val = float (rank + (i + 1 ) * 10 )
204- buf . fill_ ( val )
205- shmem . device_barrier ( )
206-
207- graph .replay ()
208- stream .synchronize ()
205+ with torch . cuda . stream ( capture_stream ):
206+ buf . fill_ ( val )
207+ shmem . device_barrier ()
208+ graph .replay ()
209+ capture_stream .synchronize ()
209210
210211 expected = torch .full (
211212 (N ,),
@@ -218,7 +219,7 @@ def _cross_rank_graph(
218219 buf .fill_ (0.0 )
219220
220221 # Warmup on capture stream.
221- with torch .cuda .stream (stream ):
222+ with torch .cuda .stream (capture_stream ):
222223 for _ in range (num_barriers ):
223224 shmem .device_barrier ()
224225 _write_remote_kernel [(1 ,)](
@@ -231,11 +232,11 @@ def _cross_rank_graph(
231232 )
232233 for _ in range (num_barriers ):
233234 shmem .device_barrier ()
234- stream .synchronize ()
235+ capture_stream .synchronize ()
235236
236237 # Capture.
237238 graph = torch .cuda .CUDAGraph ()
238- with torch .cuda .graph (graph , stream = stream ):
239+ with torch .cuda .graph (graph , stream = capture_stream ):
239240 for _ in range (num_barriers ):
240241 shmem .device_barrier ()
241242 _write_remote_kernel [(1 ,)](
@@ -251,13 +252,15 @@ def _cross_rank_graph(
251252
252253 # Replay and verify.
253254 for _ in range (rounds ):
254- buf . fill_ ( 0.0 )
255- shmem . device_barrier ( )
256-
257- graph .replay ()
258- stream .synchronize ()
255+ with torch . cuda . stream ( capture_stream ):
256+ buf . fill_ ( 0.0 )
257+ shmem . device_barrier ()
258+ graph .replay ()
259+ capture_stream .synchronize ()
259260
260- shmem .device_barrier ()
261+ with torch .cuda .stream (capture_stream ):
262+ shmem .device_barrier ()
263+ capture_stream .synchronize ()
261264 expected = torch .full ((N ,), float (writer ), dtype = torch .float32 , device = "cuda" )
262265 torch .testing .assert_close (buf , expected , rtol = 0 , atol = 0 )
263266
@@ -288,7 +291,7 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3):
288291 )
289292
290293 shmem = iris .iris (1 << 20 )
291- shmem . barrier ( )
294+ _call_barrier ( shmem , barrier_type )
292295 rank = shmem .get_rank ()
293296 num_ranks = shmem .get_num_ranks ()
294297 heap_bases = shmem .get_heap_bases ()
@@ -332,12 +335,13 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3):
332335 result ,
333336 )
334337 finally :
335- shmem . barrier ( )
338+ _call_barrier ( shmem , barrier_type )
336339 del shmem
337340 gc .collect ()
338341
339342
340- def test_barrier_timeout_assert ():
343+ @pytest .mark .parametrize ("barrier_type" , BARRIER_TYPES )
344+ def test_barrier_timeout_assert (barrier_type ):
341345 """Verify device_barrier asserts on timeout instead of hanging forever.
342346
343347 Only rank 0 calls the barrier kernel. Other ranks skip it, so rank 0
@@ -351,7 +355,7 @@ def test_barrier_timeout_assert():
351355 if num_ranks < 2 :
352356 pytest .skip ("Need at least 2 ranks" )
353357
354- shmem . barrier ( )
358+ _call_barrier ( shmem , barrier_type )
355359
356360 flags = shmem ._device_barrier_state .setdefault (None , shmem .zeros ((num_ranks ,), dtype = torch .int32 ))
357361
@@ -370,6 +374,8 @@ def test_barrier_timeout_assert():
370374 with pytest .raises (RuntimeError , match = "device-side assert" ):
371375 torch .cuda .synchronize ()
372376 finally :
373- shmem .barrier ()
377+ # No barrier here: rank 0's GPU is dead after the intentional
378+ # device-side assert. Any GPU sync (NCCL or device_barrier)
379+ # will hang or crash.
374380 del shmem
375381 gc .collect ()
0 commit comments