@@ -205,23 +205,17 @@ def get_cuda_autotune_config():
205205
206206
def get_hip_autotune_config():
    """Return the autotuning candidates used for the HIP backend.

    Each candidate is a ``triton.Config`` built from one entry of the
    tile-size table below, extended with ``matrix_instr_nonkdim=16``.
    Every candidate uses ``num_warps=8`` and ``num_stages=2``.

    Returns:
        A list of ``triton.Config`` objects, one per tile-size entry, in
        table order.
    """
    # Tile shapes to try. BLOCK_SIZE_K is fixed at 64 for every entry; only
    # the M/N tile extents and the GROUP_SIZE_M grouping factor vary.
    sizes = [
        {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
        {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4},
        {'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 6},
    ]
    # NOTE(review): 'matrix_instr_nonkdim' is presumably an AMD-specific
    # compiler option forwarded through the config kwargs -- confirm against
    # the Triton AMD backend documentation.
    # The dict union creates a fresh dict per config, so the table entries
    # above are never mutated.
    return [
        triton.Config(s | {'matrix_instr_nonkdim': 16}, num_warps=8, num_stages=2)
        for s in sizes
    ]
225219
226220
227221def get_autotune_config ():
@@ -274,6 +268,19 @@ def matmul_kernel(
274268 pid_m = first_pid_m + ((pid % num_pid_in_group ) % group_size_m )
275269 pid_n = (pid % num_pid_in_group ) // group_size_m
276270
271+ # -----------------------------------------------------------
272+ # Add integer bound assumptions: program ids are non-negative and
273+ # all strides are positive. These hints guide integer analysis in the
274+ # backend so it can optimize load/store offset address calculation.
275+ tl .assume (pid_m >= 0 )
276+ tl .assume (pid_n >= 0 )
277+ tl .assume (stride_am > 0 )
278+ tl .assume (stride_ak > 0 )
279+ tl .assume (stride_bn > 0 )
280+ tl .assume (stride_bk > 0 )
281+ tl .assume (stride_cm > 0 )
282+ tl .assume (stride_cn > 0 )
283+
277284 # ----------------------------------------------------------
278285 # Create pointers for the first blocks of A and B.
279286 # We will advance this pointer as we move in the K direction
0 commit comments