@@ -1302,7 +1302,7 @@ def test_async_dot_blackwell_2cta_tma(device):
13021302 with pytest .raises (Exception ) as e :
13031303 run_async_dot_blackwell_2cta_tma (device , False , 128 )
13041304 assert isinstance (e .value , triton .CompilationError ), "expecting a compilation error"
1305- assert ' only supports M=128 per CTA for pair-CTA mma' in e .value .error_message
1305+ assert " only supports M=128 per CTA for pair-CTA mma" in e .value .error_message
13061306
13071307
13081308def run_async_dot_blackwell_2cta_tma (device , A_TMEM , SAMPLE_M ):
@@ -2703,6 +2703,84 @@ def test_wait_arrive_ws(BLOCK_SIZE, device):
27032703 and (ttgir .count ("default {" ) == 1 ) and (ttgir .count ("partition0" ) == 1 )), f"TTGIR { ttgir } "
27042704
27052705
@triton.jit
def tlx_square_warp_barrier(
    x_ptr,
    z_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NUM_WARPS: tl.constexpr,
):
    """
    Test warp barrier: all threads arrive independently (no leader pattern).
    Uses alloc_warp_barrier instead of alloc_barriers.

    The kernel squares a vector while synchronizing three times through a
    single warp barrier; the companion test checks that exactly three
    per-thread arrive/wait rounds appear in the generated IR.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # One warp barrier shared by all NUM_WARPS warps of the CTA.
    bars = tlx.alloc_warp_barrier(num_barriers=1, num_warps=NUM_WARPS)
    bar = tlx.local_view(bars, 0)

    x = tl.load(x_ptr + offsets, mask=mask)

    # Round 1: every thread arrives on its own, then waits on phase 0.
    phase = 0
    tlx.barrier_arrive(bar=bar)
    tlx.barrier_wait(bar=bar, phase=phase)

    z = x * x

    # Round 2: mbarrier phase flips after each completed round.
    phase = phase ^ 1
    tlx.barrier_arrive(bar=bar)
    tlx.barrier_wait(bar=bar, phase=phase)

    tl.store(z_ptr + offsets, z, mask=mask)

    # Round 3: phase has flipped back, so waiting on literal 0 matches it.
    phase = phase ^ 1
    tlx.barrier_arrive(bar=bar)
    tlx.barrier_wait(bar=bar, phase=0)
2744+
@pytest.mark.skipif(not is_hopper_or_newer(), reason="Need Hopper or newer")
@pytest.mark.parametrize("BLOCK_SIZE", [(1024)])
@pytest.mark.parametrize("num_warps", [4])
def test_alloc_warp_barrier(BLOCK_SIZE, num_warps, device):
    """Square a vector with a kernel synchronized via tlx.alloc_warp_barrier
    and verify both the numerical result and the per-thread barrier lowering
    in the generated TTGIR and LLIR."""
    torch.manual_seed(0)
    size = 98432
    x = torch.rand(size, device=device)
    z = torch.empty_like(x)
    n_elements = x.numel()

    def grid(meta):
        return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]), )

    kernel = tlx_square_warp_barrier[grid](
        x,
        z,
        n_elements,
        BLOCK_SIZE,
        num_warps,
        num_warps=num_warps,
    )

    # Numerical correctness: kernel output must match x squared.
    torch.testing.assert_close(z, x * x, check_dtype=False)

    # Verify IR uses arrive_barrier with perThread attribute
    ttgir = kernel.asm["ttgir"]
    assert ttgir.count("ttng.init_barrier") == 1, f"Expected 1 init_barrier in TTGIR:\n{ttgir}"
    assert ttgir.count("ttng.arrive_barrier") == 3, f"Expected 3 arrive_barrier in TTGIR:\n{ttgir}"
    assert ttgir.count("perThread") == 3, f"Expected 3 perThread attrs in TTGIR:\n{ttgir}"
    assert ttgir.count("ttng.wait_barrier") == 3, f"Expected 3 wait_barrier in TTGIR:\n{ttgir}"

    # Verify LLIR: perThread arrives use per-thread lowering (no leader predicate)
    llir = kernel.asm["llir"]
    # Per-thread arrive emits unpredicated: mbarrier.arrive.shared::cta.b64 _, [$0]
    assert "mbarrier.arrive.shared::cta.b64 _, [$0]" in llir, (
        f"Expected unpredicated per-thread mbarrier.arrive in LLIR:\n{llir}")
    # Leader pattern would emit predicated: @$0 mbarrier.arrive
    assert "@$0 mbarrier.arrive" not in llir, f"Unexpected leader-predicated mbarrier.arrive in LLIR:\n{llir}"
2783+
27062784@pytest .mark .skipif (not is_hopper_or_newer (), reason = "Need Hopper or newer" )
27072785def test_barrier_live_range (device ):
27082786
@@ -6382,13 +6460,13 @@ def bulk_copy_kernel(
63826460 ttgir = kernel .asm ["ttgir" ]
63836461 assert "ttg.async_copy_global_to_local" in ttgir , "Expected async_copy_global_to_local in TTGIR"
63846462 assert "useBulk = true" in ttgir , "Expected useBulk = true in TTGIR"
6385- assert "ttng.async_store" in ttgir , ( "Expected async_store in TTGIR" )
6463+ assert "ttng.async_store" in ttgir , "Expected async_store in TTGIR"
63866464
63876465 # Verify PTX contains the bulk copy instructions
63886466 ptx = kernel .asm ["ptx" ]
63896467 assert "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes" in ptx , (
63906468 "Expected cp.async.bulk gmem->smem in PTX" )
6391- assert "cp.async.bulk.global.shared::cta.bulk_group" in ptx , ( "Expected cp.async.bulk smem->gmem in PTX" )
6469+ assert "cp.async.bulk.global.shared::cta.bulk_group" in ptx , "Expected cp.async.bulk smem->gmem in PTX"
63926470
63936471 # Verify correctness
63946472 torch .testing .assert_close (src , dst )
0 commit comments