@@ -80,6 +80,27 @@ def _matmul_ogs(
8080 SWAP_XW : tl .constexpr = False ,
8181 IS_EPILOGUE_DEQUANT_MXFP8 : tl .constexpr = False ):
8282
83+ tl .assume (stride_y_k >= 0 )
84+ tl .assume (stride_y_z >= 0 )
85+ tl .assume (stride_y_m >= 0 )
86+ tl .assume (stride_y_n >= 0 )
87+ tl .assume (stride_x_z >= 0 )
88+ tl .assume (stride_x_m >= 0 )
89+ tl .assume (stride_x_k >= 0 )
90+ tl .assume (stride_w_e >= 0 )
91+ tl .assume (stride_w_k >= 0 )
92+ tl .assume (stride_w_n >= 0 )
93+ if stride_w_mx_e is not None :
94+ tl .assume (stride_w_mx_e >= 0 )
95+ if stride_w_mx_k is not None :
96+ tl .assume (stride_w_mx_k >= 0 )
97+ if stride_w_mx_n is not None :
98+ tl .assume (stride_w_mx_n >= 0 )
99+ tl .assume (stride_b_e >= 0 )
100+ tl .assume (batch_size >= 0 )
101+ tl .assume (grid_m >= 0 )
102+ tl .assume (grid_n >= 0 )
103+
83104 is_w_microscaled : tl .constexpr = WMxScale is not None
84105 MX_PACK_DIVISOR : tl .constexpr = MXFP_BLOCK_SIZE
85106 if is_w_microscaled :
@@ -116,7 +137,9 @@ def _matmul_ogs(
116137 HAS_FUSED_SCATTER : tl .constexpr = WriteBackIndx is not None
117138 index_type : tl .constexpr = tl .int64 if UPCAST_INDICES else tl .int32
118139
119- total_actual_tiles = batch_size * (grid_m - padding_m ) * grid_n * SPLIT_K
140+ unpadded_m = grid_m - padding_m
141+ tl .assume (unpadded_m >= 0 )
142+ total_actual_tiles = batch_size * unpadded_m * grid_n * SPLIT_K
120143 if padding_m > 0 and pid >= total_actual_tiles :
121144 tl .device_assert (batch_size == 0 )
122145 pid_mn = pid - total_actual_tiles
@@ -132,11 +155,11 @@ def _matmul_ogs(
132155 pid_emnk = pid
133156 if XCD_SWIZZLE != 1 :
134157 pid_emnk = xcd_swizzle (pid_emnk , total_actual_tiles , XCD_SWIZZLE )
135- pid_e = pid_emnk // (( grid_m - padding_m ) * grid_n * SPLIT_K )
136- pid_mnk = pid_emnk % (( grid_m - padding_m ) * grid_n * SPLIT_K )
158+ pid_e = pid_emnk // (unpadded_m * grid_n * SPLIT_K )
159+ pid_mnk = pid_emnk % (unpadded_m * grid_n * SPLIT_K )
137160 pid_k = pid_mnk % SPLIT_K
138161 pid_mn = pid_mnk // SPLIT_K
139- pid_m , pid_n = swizzle2d (pid_mn , ( grid_m - padding_m ) , grid_n , GROUP_M )
162+ pid_m , pid_n = swizzle2d (pid_mn , unpadded_m , grid_n , GROUP_M )
140163 # For split-k, advance to the output k slice
141164 if SPLIT_K > 1 :
142165 Y += pid_k .to ( index_type ) * stride_y_k
0 commit comments