@@ -228,7 +228,8 @@ def tcgen5_mma_kernel(input_desc, XBLOCK: ttgl.constexpr, FAILURE: ttgl.constexp
228228 mbarrier .init (bar .index (0 ), count = 1 )
229229 mbarrier .init (bar .index (1 ), count = 1 )
230230
231- blackwell .tcgen05_mma (smemA , smemB .permute ([1 , 0 ]), acc , mbarriers = [bar .index (0 )])
231+ blackwell .tcgen05_mma (smemA , smemB .permute ([1 , 0 ]), acc )
232+ blackwell .tcgen05_commit (bar .index (0 ))
232233
233234 if not FAILURE :
234235 mbarrier .wait (bar .index (0 ), 0 )
@@ -285,32 +286,6 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
285286 tcgen5_mma_kernel [(1 , )](input_desc , XBLOCK , FAILURE = FAILURE , MEM_ACCESS_KIND = MEM_ACCESS_KIND , num_warps = 4 )
286287
287288
288- @gluon .jit
289- def tcgen5_mma_multibar_kernel (input_desc , XBLOCK : ttgl .constexpr , BUF_IDX : ttgl .constexpr , BAR_IDX : ttgl .constexpr ):
290- acc_layout : ttgl .constexpr = blackwell .TensorMemoryLayout ([XBLOCK , XBLOCK ], unpacked = True , cta_split_num = [1 , 1 ])
291- blocked_layout : ttgl .constexpr = ttgl .BlockedLayout (size_per_thread = [1 , XBLOCK ], threads_per_warp = [32 , 1 ],
292- warps_per_cta = [4 , 1 ], order = [0 , 1 ])
293- smemA = ttgl .allocate_shared_memory (ttgl .float16 , [XBLOCK , XBLOCK ], input_desc .layout )
294- smemB = ttgl .allocate_shared_memory (ttgl .float16 , [XBLOCK , XBLOCK ], input_desc .layout )
295- bar = ttgl .allocate_shared_memory (ttgl .int64 , [4 , 1 ], mbarrier .MBarrierLayout ())
296- acc = blackwell .allocate_tensor_memory (ttgl .float32 , [2 , XBLOCK , XBLOCK ], acc_layout )
297- for i in range (4 ):
298- mbarrier .init (bar .index (i ), count = 1 )
299-
300- blackwell .tcgen05_mma (smemA , smemB .permute ([1 , 0 ]), acc .index (0 ), mbarriers = [bar .index (0 ),
301- bar .index (1 )],
302- mbarrier_preds = [False , True ])
303- blackwell .tcgen05_mma (smemA , smemB .permute ([1 , 0 ]), acc .index (1 ), mbarriers = [bar .index (2 )])
304- blackwell .tcgen05_commit (bar .index (3 ))
305-
306- mbarrier .wait (bar .index (BAR_IDX ), 0 )
307-
308- acc .index (BUF_IDX ).store (ttgl .full ([XBLOCK , XBLOCK ], 42 , ttgl .float32 , blocked_layout ))
309-
310- for i in range (4 ):
311- mbarrier .invalidate (bar .index (i ))
312-
313-
314289@gluon .jit
315290def warpgroup_mma_kernel (input , XBLOCK : ttgl .constexpr , FAILURE : ttgl .constexpr ):
316291 smem_layout : ttgl .constexpr = ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 16 , rank = 2 )
@@ -405,6 +380,32 @@ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
405380 warpgroup_mma_kernel [(1 , )](input , XBLOCK , FAILURE = FAILURE )
406381
407382
383+ @gluon .jit
384+ def tcgen5_mma_multibar_kernel (input_desc , XBLOCK : ttgl .constexpr , BUF_IDX : ttgl .constexpr , BAR_IDX : ttgl .constexpr ):
385+ acc_layout : ttgl .constexpr = blackwell .TensorMemoryLayout ([XBLOCK , XBLOCK ], unpacked = True , cta_split_num = [1 , 1 ])
386+ blocked_layout : ttgl .constexpr = ttgl .BlockedLayout (size_per_thread = [1 , XBLOCK ], threads_per_warp = [32 , 1 ],
387+ warps_per_cta = [4 , 1 ], order = [0 , 1 ])
388+ smemA = ttgl .allocate_shared_memory (ttgl .float16 , [XBLOCK , XBLOCK ], input_desc .layout )
389+ smemB = ttgl .allocate_shared_memory (ttgl .float16 , [XBLOCK , XBLOCK ], input_desc .layout )
390+ bar = ttgl .allocate_shared_memory (ttgl .int64 , [4 , 1 ], mbarrier .MBarrierLayout ())
391+ acc = blackwell .allocate_tensor_memory (ttgl .float32 , [2 , XBLOCK , XBLOCK ], acc_layout )
392+ for i in range (4 ):
393+ mbarrier .init (bar .index (i ), count = 1 )
394+
395+ blackwell .tcgen05_mma (smemA , smemB .permute ([1 , 0 ]), acc .index (0 ), mbarriers = [bar .index (0 ),
396+ bar .index (1 )],
397+ mbarrier_preds = [False , True ])
398+ blackwell .tcgen05_mma (smemA , smemB .permute ([1 , 0 ]), acc .index (1 ), mbarriers = [bar .index (2 )])
399+ blackwell .tcgen05_commit (bar .index (3 ))
400+
401+ mbarrier .wait (bar .index (BAR_IDX ), 0 )
402+
403+ acc .index (BUF_IDX ).store (ttgl .full ([XBLOCK , XBLOCK ], 42 , ttgl .float32 , blocked_layout ))
404+
405+ for i in range (4 ):
406+ mbarrier .invalidate (bar .index (i ))
407+
408+
408409@pytest .mark .skipif (not is_cuda () or torch .cuda .get_device_capability ()[0 ] < 10 , reason = "Requires blackwell or newer" )
409410@pytest .mark .parametrize ("BUF_IDX" , [0 , 1 ])
410411@pytest .mark .parametrize ("BAR_IDX" , [0 , 1 , 2 , 3 ])
0 commit comments