 import torch
+import math
 import pytest
 import re
 from itertools import product
@@ -126,9 +127,9 @@ def test_async_copy_mbarrier(device):


 @gluon.jit
-def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr,
-                         block_layout: ttgl.constexpr, mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr,
-                         shared_layout_b: ttgl.constexpr, acc_dtype: ttgl.constexpr, ASYNC: ttgl.constexpr):
+def mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, block_layout: ttgl.constexpr,
+               mma_layout: ttgl.constexpr, shared_layout_a: ttgl.constexpr, shared_layout_b: ttgl.constexpr,
+               acc_dtype: ttgl.constexpr, ASYNC: ttgl.constexpr, USE_TCGEN05: ttgl.constexpr):
     a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, block_layout))[:, None]
     a_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :]
     b_offs_k = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None]
@@ -143,14 +144,37 @@ def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttg

     smem_a = ttgl.allocate_shared_memory(operand_dtype, [M, K], shared_layout_a, a_tile)
     smem_b = ttgl.allocate_shared_memory(operand_dtype, [K, N], shared_layout_b, b_tile)
-
     fence_async_shared()

-    acc = ttgl.zeros([M, N], dtype=acc_dtype, layout=mma_layout)
-    acc = hopper.warpgroup_mma(smem_a, smem_b, acc, is_async=ASYNC)
+    if USE_TCGEN05:
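+        # Blackwell path: keep the accumulator in tensor memory and use a single-count
+        # mbarrier to observe when the tcgen05 MMA has finished.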
+        tmem_layout: ttgl.constexpr = TensorMemoryLayout((M, N), col_stride=32 // acc_dtype.primitive_bitwidth)
+
+        num_warps: ttgl.constexpr = ttgl.num_warps()
+        tmem_reg_layout: ttgl.constexpr = get_tmem_32x32b_reg_layout(
+            M=M,
+            N=N,
+            shape=[M, N],
+            num_warps=num_warps,
+        )
+
+        mma_barrier = ttgl.allocate_shared_memory(ttgl.int64, [1], mbarrier.MBarrierLayout())
+        mbarrier.init(mma_barrier, count=1)
+
+        acc_zero = ttgl.zeros([M, N], dtype=acc_dtype, layout=tmem_reg_layout)
+        acc_tmem = allocate_tensor_memory(acc_dtype, [M, N], tmem_layout, acc_zero)

-    if ASYNC:
-        acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])
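+        # Issue the MMA from shared memory into tensor memory; tcgen05_commit arms the
+        # barrier to complete once that MMA finishes, so wait on phase 0 before reading
+        # the accumulator back to registers.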
+        tcgen05_mma(smem_a, smem_b, acc_tmem, use_acc=False)
+        tcgen05_commit(mma_barrier)
+        mbarrier.wait(mma_barrier, phase=0)
+        mbarrier.invalidate(mma_barrier)
+        acc = acc_tmem.load(tmem_reg_layout)
+        acc = ttgl.convert_layout(acc, layout=mma_layout)
+    else:
+        acc = ttgl.zeros([M, N], dtype=acc_dtype, layout=mma_layout)
+        acc = hopper.warpgroup_mma(smem_a, smem_b, acc, is_async=ASYNC)
+
+        if ASYNC:
+            acc = hopper.warpgroup_mma_wait(num_outstanding=0, deps=[acc])

     ttgl.store(out + out_offs_m * N + out_offs_n, acc)

@@ -168,7 +192,7 @@ def test_warpgroup_mma(ASYNC):
     a = torch.randn((M, K), device="cuda", dtype=torch.float16)
     b = torch.randn((K, N), device="cuda", dtype=torch.float16)
     out = torch.zeros((M, N), device="cuda", dtype=torch.float16)
-    warpgroup_mma_kernel[(1, )](
+    mma_kernel[(1, )](
         a,
         b,
         out,
@@ -181,6 +205,7 @@ def test_warpgroup_mma(ASYNC):
         shared_layout_b,
         ttgl.float16,
         ASYNC,
+        False,
         num_warps=warps[0] * warps[1],
     )

@@ -189,19 +214,24 @@ def test_warpgroup_mma(ASYNC):
     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)


-@pytest.mark.xfail(not is_hopper(), reason="Requires Hopper", run=False)
+@pytest.mark.xfail(not (is_hopper() or is_blackwell()), reason="Requires Hopper or Blackwell", run=False)
 @pytest.mark.parametrize("bitwidth, transpose_a, transpose_b, acc_dtype",
                          [(bitwidth, transpose_a, transpose_b, acc_dtype)
                           for bitwidth in [8, 16, 32]
                           for (transpose_a, transpose_b) in product([False, True], repeat=2)
                           for acc_dtype in [torch.float16, torch.float32]
                           if bitwidth == 16 or (acc_dtype == torch.float32 and not transpose_a and transpose_b)])
 @pytest.mark.parametrize("warps", ([8, 1], [4, 2], [4, 1]))
-# Swizzling 0 does not map to a valid memory descriptor lol
-@pytest.mark.parametrize("swizzling_a, swizzling_b", product([32, 64, 128], repeat=2))
+@pytest.mark.parametrize("swizzling_a, swizzling_b", product([0, 32, 64, 128], repeat=2))
 @pytest.mark.parametrize("shape_m, shape_n, shape_k", [(1, 1, 1), (2, 4, 1), (2, 2, 4)])
-def test_warpgroup_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dtype, warps, swizzling_a, swizzling_b,
-                                     shape_m, shape_n, shape_k):
+def test_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dtype, warps, swizzling_a, swizzling_b, shape_m,
+                           shape_n, shape_k, fresh_knobs):
+
+    # FIXME: Workaround for a bug in PTXAS when the shared layout is transposed and the swizzling is 0
+    if bitwidth == 16 and ((transpose_a and swizzling_a == 0 and shape_m > 1) or
+                           (not transpose_b and swizzling_b == 0 and shape_n > 1)):
+        fresh_knobs.nvidia.disable_ptxas_opt = True
+    use_tcgen05 = is_blackwell()

     torch_dtype_map = {
         8: torch.float8_e4m3fn,
@@ -214,8 +244,7 @@ def test_warpgroup_mma_shared_inputs(bitwidth, transpose_a, transpose_b, acc_dty
     }

     # We'll choose a larger instr shape along N, but sure
-    instr_shape_k_map = {8: 32, 16: 16, 32: 8}
-    instr_shape = [16, 32, instr_shape_k_map[bitwidth]]
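+    # 256 // bitwidth reproduces the old per-dtype map: K = 32 for fp8, 16 for fp16, 8 for fp32.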
+    instr_shape = [16, 32, 256 // bitwidth]
     M = instr_shape[0] * warps[0]
     N = instr_shape[1] * warps[1]
     K = instr_shape[2]
@@ -239,7 +268,27 @@ def min_shape(swizzling, dim0, dim1, trans):
     K *= shape_k
     instr_shape[1] *= shape_n

-    shared_mem_accum = M * K * bitwidth // 8 + K * N * bitwidth // 8
+    if use_tcgen05:
+        M = 128
+
+    def get_shared_swizzling_zero(M, K, transpose):
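+        # Explicit linear layout for the unswizzled (swizzle-0) case: bases cover the
+        # contiguous 128-bit chunk along K, then the M dimension, then the remaining K,
+        # with coordinates swapped when the operand is transposed.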
+        # K-contig
+        if transpose:
+            K, M = M, K
+        bases = []
+        for i in range(int(math.log2(128 // bitwidth))):
+            bases.append([0, 1 << i])
+        for i in range(int(math.log2(M))):
+            bases.append([1 << i, 0])
+        for i in range(int(math.log2(K // (128 // bitwidth)))):
+            offset = int(math.log2(128 // bitwidth)) + i
+            bases.append([0, 1 << offset])
+        if transpose:
+            for i in range(len(bases)):
+                bases[i] = [bases[i][1], bases[i][0]]
+        return ttgl.SharedLinearLayout(bases)
+
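+    # Combined byte footprint of the A (M x K) and B (K x N) operand tiles.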
+    shared_mem_accum = (M + N) * K * bitwidth // 8
     if triton.runtime.driver.active.utils.get_device_properties(
             triton.runtime.driver.active.get_current_device())["max_shared_mem"] < shared_mem_accum:
         pytest.skip("Skipped due to insufficient shared memory on this GPU.")
@@ -248,11 +297,17 @@ def min_shape(swizzling, dim0, dim1, trans):
     gl_acc_dtype = acc_dtype_map[acc_dtype]
     out_dtype = torch.float32

-    block_layout = ttgl.BlockedLayout([1, 1], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0])
-    shared_layout_a = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_a, element_bitwidth=bitwidth, rank=2,
-                                             transposed=transpose_a)
-    shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2,
-                                             transposed=transpose_b)
+    block_layout = ttgl.BlockedLayout([1, 8], [1, THREADS_PER_WARP], warps_per_cta=warps, order=[1, 0])
+    if swizzling_a == 0:
+        shared_layout_a = get_shared_swizzling_zero(M, K, transpose_a)
+    else:
+        shared_layout_a = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_a, element_bitwidth=bitwidth, rank=2,
+                                                 transposed=transpose_a)
+    if swizzling_b == 0:
+        shared_layout_b = get_shared_swizzling_zero(K, N, transpose_b)
+    else:
+        shared_layout_b = ttgl.NVMMASharedLayout(swizzle_byte_width=swizzling_b, element_bitwidth=bitwidth, rank=2,
+                                                 transposed=transpose_b)
     mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=warps, instr_shape=instr_shape)

     torch.manual_seed(0)
@@ -271,7 +326,7 @@ def cast(x, dtype):
     b = cast(torch.randn((K, N), device="cuda", dtype=torch.float32), torch_dtype)
     out = torch.zeros((M, N), device="cuda", dtype=out_dtype)

-    warpgroup_mma_kernel[(1, )](
+    mma_kernel[(1, )](
         a,
         b,
         out,
@@ -284,6 +339,7 @@ def cast(x, dtype):
         shared_layout_b,
         gl_acc_dtype,
         False,
+        use_tcgen05,
         num_warps=warps[0] * warps[1],
     )

@@ -298,9 +354,9 @@ def cast(x, dtype):
     torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow_fp16_red

     if bitwidth == 8:
-        atol, rtol = 0.5, 0.5
+        atol, rtol = 5e-2, 5e-1
     elif bitwidth == 16:
-        atol, rtol = 3e-2, 1e-1
+        atol, rtol = 5e-2, 5e-1
     else:
         atol, rtol = 5e-4, 5e-3
     torch.testing.assert_close(out, ref, atol=atol, rtol=rtol)