@@ -34,7 +34,7 @@ def matmul_kernel(  #
         stride_cm, stride_cn,  #
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
         NUM_STAGES: tl.constexpr, SCALE_A: tl.constexpr = None, PRECISION: tl.constexpr = "ieee",
-        A_TRANS: tl.constexpr = False, EPILOGUE_SUBTILE: tl.constexpr = False):
+        A_TRANS: tl.constexpr = False, EPILOGUE_SUBTILE: tl.constexpr = False, dummy: tl.constexpr = 0):
     pid = tl.program_id(axis=0)
     num_pid_m = tl.cdiv(M, BLOCK_M)
     pid_m = pid % num_pid_m
@@ -97,8 +97,9 @@ def get_src_element_ty_size(dtype_str):
 @pytest.mark.parametrize("NUM_CTAS", [1, 2])
 @pytest.mark.parametrize("NUM_WARPS", [4, 8])
 @pytest.mark.parametrize("EPILOGUE_SUBTILE", [True, False])
+@pytest.mark.parametrize("LAYOUT_16x256", [True, False])
 def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, NUM_WARPS, NUM_CTAS, device,
-                       EPILOGUE_SUBTILE):
+                       EPILOGUE_SUBTILE, LAYOUT_16x256, monkeypatch):
     if NUM_CTAS > 1 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 9):
         pytest.xfail("Clusters requires nvidia compute capability >= 9")
     if is_hip() and ((BLOCK_K * BLOCK_M + BLOCK_K * BLOCK_N) * NUM_STAGES * get_src_element_ty_size(dtype_src_str)
@@ -118,6 +119,8 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
         pytest.skip("multi-CTAs is broken for mmav2")
     if EPILOGUE_SUBTILE and not is_xpu() and (is_hip() or NUM_CTAS > 1 or BLOCK_N >= 512):
         pytest.skip("creates convert layout too big to fit in smem")
+    if LAYOUT_16x256 and (not is_cuda() or torch.cuda.get_device_capability()[0] < 10):
+        pytest.xfail("skip forcing tmem layout on non blackwell targets.")
     M, N, K = 1024, 512, 256
     torch.manual_seed(42)
     precision = "tf32" if dtype_src_str == "tensorfloat32" else "ieee"
@@ -133,12 +136,16 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
     b = torch.randn(K, N, dtype=dtype_src, device=device)
     A = a
     B = b
+    # pass a dummy constexpr argument to force recompilation.
+    if LAYOUT_16x256:
+        monkeypatch.setenv("TRITON_PREFER_TMEM_16x256_LAYOUT", "1")
     dtype_dst = getattr(torch, dtype_dst_str)
     output = torch.empty((M, N), dtype=dtype_dst, device=device)
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
     k = matmul_kernel[grid](a, b, output, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), output.stride(0),
                             output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, PRECISION=precision,
-                            num_warps=NUM_WARPS, num_ctas=NUM_CTAS, EPILOGUE_SUBTILE=EPILOGUE_SUBTILE)
+                            num_warps=NUM_WARPS, num_ctas=NUM_CTAS, EPILOGUE_SUBTILE=EPILOGUE_SUBTILE,
+                            dummy=LAYOUT_16x256)
     ref_out = torch.matmul(A, B).to(torch.float32)
     output = output.to(torch.float32)
     if dtype_src_str == "float32":
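Note on the hunk above: the `dummy` constexpr and `monkeypatch.setenv` work together. Triton caches compiled kernels per constexpr combination, and flipping `TRITON_PREFER_TMEM_16x256_LAYOUT` alone apparently does not invalidate an already-cached binary, so the test threads the flag through an otherwise unused constexpr to force a fresh compile. Below is a minimal sketch of that pattern, not part of the change; `copy_kernel` and `run_with_layout_hint` are made-up stand-ins, and the env var has no effect on such a trivial kernel, it is only shown to illustrate the cache-key trick.

import os

import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(src, dst, n, BLOCK: tl.constexpr, dummy: tl.constexpr = 0):
    # `dummy` never influences the generated code; it only varies the
    # specialization (and therefore the compilation cache key).
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(dst + offs, tl.load(src + offs, mask=mask), mask=mask)


def run_with_layout_hint(prefer_16x256: bool):
    # Flip the knob, then pass it through as the dummy constexpr so Triton
    # builds a fresh specialization instead of reusing a binary that was
    # compiled before the environment variable changed.
    os.environ["TRITON_PREFER_TMEM_16x256_LAYOUT"] = "1" if prefer_16x256 else "0"
    x = torch.randn(1024, device="cuda")
    y = torch.empty_like(x)
    block = 256
    grid = (triton.cdiv(x.numel(), block),)
    return copy_kernel[grid](x, y, x.numel(), BLOCK=block, dummy=int(prefer_16x256))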
@@ -161,6 +168,13 @@ def test_simple_matmul(dtype_src_str, dtype_dst_str, BLOCK_M, BLOCK_N, BLOCK_K,
         ttgir = k.asm["ttgir"]
         count = ttgir.count("ttng.tc_gen5_mma")
         assert count == 2, "The TTGIR does not match the expected pattern."
+        ptx = k.asm["ptx"]
+        if LAYOUT_16x256:
+            assert "16x256b" in ptx, "PTX does not contain 16x256b"
+        else:
+            if "32x32b" not in ptx and "16x32b" not in ptx:
+                print(ptx)
+            assert ("32x32b" in ptx) or ("16x32b" in ptx), "PTX does not contain 32x32b or 16x32b"
 
 
 # persistent matmul with fused loops
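The new assertions in the last hunk rely on the launch handle returned by `matmul_kernel[grid](...)`: its `.asm` mapping exposes the compilation artifacts by stage name (`"ttgir"`, `"ptx"`, ...), so the test can grep the final PTX for the tensor-memory access shape that the layout knob is expected to produce. A hypothetical helper restating that check; the suffixes are the ones asserted in the diff and are assumed to appear in the Blackwell tensor-memory (tcgen05) instructions:

def tmem_access_shapes(kernel_handle):
    """Return which TMEM access-shape suffixes appear in the generated PTX."""
    ptx = kernel_handle.asm["ptx"]  # final PTX for this kernel specialization
    return {s for s in ("16x256b", "32x32b", "16x32b") if s in ptx}


# Usage against the handle `k` from the test body:
#   expected = {"16x256b"} if LAYOUT_16x256 else {"32x32b", "16x32b"}
#   assert tmem_access_shapes(k) & expected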