99import triton
1010import triton .language as tl
1111import iris
12- from iris ._distributed_helpers import _device_barrier_kernel , extract_group_info
1312
1413
1514BarrierType = Literal ["host" , "device" ]
@@ -55,22 +54,23 @@ def _write_remote_kernel(
5554@pytest .mark .parametrize ("barrier_type" , BARRIER_TYPES )
5655def test_barrier_basic (barrier_type , n ):
5756 shmem = iris .iris (1 << 20 )
58- shmem . barrier ( )
57+ _call_barrier ( shmem , barrier_type )
5958
6059 try :
6160 for _ in range (n ):
6261 _call_barrier (shmem , barrier_type )
6362 finally :
64- shmem . barrier ( )
63+ _call_barrier ( shmem , barrier_type )
6564 del shmem
6665 gc .collect ()
6766
6867
6968@pytest .mark .parametrize ("n" , [1 , 2 , 5 , 10 ])
70- def test_barrier_state_reuse (n ):
69+ @pytest .mark .parametrize ("barrier_type" , BARRIER_TYPES )
70+ def test_barrier_state_reuse (barrier_type , n ):
7171 """Verify device barrier reuses the same flags tensor across calls."""
7272 shmem = iris .iris (1 << 20 )
73- shmem . barrier ( )
73+ _call_barrier ( shmem , barrier_type )
7474
7575 try :
7676 shmem .device_barrier ()
@@ -82,7 +82,7 @@ def test_barrier_state_reuse(n):
8282 shmem .device_barrier ()
8383 assert shmem ._device_barrier_state [None ].data_ptr () == flags_ptr
8484 finally :
85- shmem . barrier ( )
85+ _call_barrier ( shmem , barrier_type )
8686 del shmem
8787 gc .collect ()
8888
@@ -161,13 +161,13 @@ def _cross_rank_graph(
161161 buf ,
162162 result ,
163163):
164- stream = torch .cuda .Stream ()
164+ capture_stream = torch .cuda .Stream ()
165165
166166 if op == "load" :
167167 buf .fill_ (float (rank ))
168168
169169 # Warmup on capture stream.
170- with torch .cuda .stream (stream ):
170+ with torch .cuda .stream (capture_stream ):
171171 for _ in range (num_barriers ):
172172 shmem .device_barrier ()
173173 _read_remote_kernel [(1 ,)](
@@ -180,11 +180,11 @@ def _cross_rank_graph(
180180 )
181181 for _ in range (num_barriers ):
182182 shmem .device_barrier ()
183- stream .synchronize ()
183+ capture_stream .synchronize ()
184184
185185 # Capture.
186186 graph = torch .cuda .CUDAGraph ()
187- with torch .cuda .graph (graph , stream = stream ):
187+ with torch .cuda .graph (graph , stream = capture_stream ):
188188 for _ in range (num_barriers ):
189189 shmem .device_barrier ()
190190 _read_remote_kernel [(1 ,)](
@@ -201,11 +201,11 @@ def _cross_rank_graph(
201201 # Replay with fresh data.
202202 for i in range (rounds ):
203203 val = float (rank + (i + 1 ) * 10 )
204- buf . fill_ ( val )
205- shmem . device_barrier ( )
206-
207- graph .replay ()
208- stream .synchronize ()
204+ with torch . cuda . stream ( capture_stream ):
205+ buf . fill_ ( val )
206+ shmem . device_barrier ()
207+ graph .replay ()
208+ capture_stream .synchronize ()
209209
210210 expected = torch .full (
211211 (N ,),
@@ -218,7 +218,7 @@ def _cross_rank_graph(
218218 buf .fill_ (0.0 )
219219
220220 # Warmup on capture stream.
221- with torch .cuda .stream (stream ):
221+ with torch .cuda .stream (capture_stream ):
222222 for _ in range (num_barriers ):
223223 shmem .device_barrier ()
224224 _write_remote_kernel [(1 ,)](
@@ -231,11 +231,11 @@ def _cross_rank_graph(
231231 )
232232 for _ in range (num_barriers ):
233233 shmem .device_barrier ()
234- stream .synchronize ()
234+ capture_stream .synchronize ()
235235
236236 # Capture.
237237 graph = torch .cuda .CUDAGraph ()
238- with torch .cuda .graph (graph , stream = stream ):
238+ with torch .cuda .graph (graph , stream = capture_stream ):
239239 for _ in range (num_barriers ):
240240 shmem .device_barrier ()
241241 _write_remote_kernel [(1 ,)](
@@ -251,13 +251,15 @@ def _cross_rank_graph(
251251
252252 # Replay and verify.
253253 for _ in range (rounds ):
254- buf . fill_ ( 0.0 )
255- shmem . device_barrier ( )
256-
257- graph .replay ()
258- stream .synchronize ()
254+ with torch . cuda . stream ( capture_stream ):
255+ buf . fill_ ( 0.0 )
256+ shmem . device_barrier ()
257+ graph .replay ()
258+ capture_stream .synchronize ()
259259
260- shmem .device_barrier ()
260+ with torch .cuda .stream (capture_stream ):
261+ shmem .device_barrier ()
262+ capture_stream .synchronize ()
261263 expected = torch .full ((N ,), float (writer ), dtype = torch .float32 , device = "cuda" )
262264 torch .testing .assert_close (buf , expected , rtol = 0 , atol = 0 )
263265
@@ -288,7 +290,7 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3):
288290 )
289291
290292 shmem = iris .iris (1 << 20 )
291- shmem . barrier ( )
293+ _call_barrier ( shmem , barrier_type )
292294 rank = shmem .get_rank ()
293295 num_ranks = shmem .get_num_ranks ()
294296 heap_bases = shmem .get_heap_bases ()
@@ -332,44 +334,6 @@ def test_barrier_cross_rank(barrier_type, op, mode, num_barriers, N, rounds=3):
332334 result ,
333335 )
334336 finally :
335- shmem .barrier ()
336- del shmem
337- gc .collect ()
338-
339-
340- def test_barrier_timeout_assert ():
341- """Verify device_barrier asserts on timeout instead of hanging forever.
342-
343- Only rank 0 calls the barrier kernel. Other ranks skip it, so rank 0
344- spins waiting for them and hits the MAX_SPINS assert.
345- """
346- shmem = iris .iris (1 << 20 )
347- rank = shmem .get_rank ()
348- num_ranks = shmem .get_num_ranks ()
349- heap_bases = shmem .get_heap_bases ()
350-
351- if num_ranks < 2 :
352- pytest .skip ("Need at least 2 ranks" )
353-
354- shmem .barrier ()
355-
356- flags = shmem ._device_barrier_state .setdefault (None , shmem .zeros ((num_ranks ,), dtype = torch .int32 ))
357-
358- try :
359- if rank == 0 :
360- _ , rank_global , world_size , rank_start , rank_stride = extract_group_info (None , rank , num_ranks )
361- _device_barrier_kernel [(1 ,)](
362- flags ,
363- rank_global ,
364- world_size ,
365- rank_start ,
366- rank_stride ,
367- heap_bases ,
368- MAX_SPINS = 1000 ,
369- )
370- with pytest .raises (RuntimeError , match = "device-side assert" ):
371- torch .cuda .synchronize ()
372- finally :
373- shmem .barrier ()
337+ _call_barrier (shmem , barrier_type )
374338 del shmem
375339 gc .collect ()
0 commit comments