@@ -280,7 +280,7 @@ def _finalize_matmul(
         for off_n in tl.range(tl.program_id(1) * OUT_BLOCK_N, outN, tl.num_programs(1) * OUT_BLOCK_N):
             offs_n = off_n + tl.arange(0, OUT_BLOCK_N)
             n_mask = offs_n < outN
-            tl.store(Out + row * outN + offs_n, tl.zeros([OUT_BLOCK_N], dtype=Out.dtype.element_ty), mask=n_mask)
+            tl.store(Out + row.to(tl.int64) * outN + offs_n, tl.zeros([OUT_BLOCK_N], dtype=Out.dtype.element_ty), mask=n_mask)
     else:
         for off_n in tl.range(tl.program_id(1) * BLOCK_N, N, tl.num_programs(1) * BLOCK_N, num_stages=STAGES):
             offs_n = off_n + tl.arange(0, BLOCK_N)
@@ -346,7 +346,7 @@ def _finalize_matmul(
                          pid=row * tl.num_programs(1) + tl.program_id(1))
             tl.static_assert(OUT_BLOCK_N % OUT_MX_SCALE_BLOCK_N == 0, "")
             tl.store(OutActualScale + row * stride_out_mx_m + offs_n_scale * stride_out_mx_n, acc_scale, mask=n_mask_scale)
-            tl.store(Out + row * outN + offs_n[None, :], acc, mask=n_mask[None, :])
+            tl.store(Out + row.to(tl.int64) * outN + offs_n[None, :], acc, mask=n_mask[None, :])
         else:
             out = float_to_flex(out, out_scale if OutExpectedScale is not None else None, None, OutChecksumScale,
                                 None, Out, flexpoint_saturate_inf)
@@ -355,7 +355,7 @@ def _finalize_matmul(
                              pid=row * tl.num_programs(1) + tl.program_id(1))
             offs_n = off_n // ACTIVATION_REDUCTION_N + tl.arange(0, OUT_BLOCK_N)
             n_mask = offs_n < outN
-            tl.store(Out + row * outN + offs_n, out, mask=n_mask)
+            tl.store(Out + row.to(tl.int64) * outN + offs_n, out, mask=n_mask)
 
     persisent_m = tl.num_programs(0) < MBound
     if not persisent_m and n_active_experts == 0:
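All three stores share the same fix: `row * outN` was previously computed in int32, so the flat element offset wrapped once it passed 2**31 - 1; casting `row` to `tl.int64` first promotes the whole offset expression to 64-bit before it is added to the pointer. Below is a minimal standalone sketch of the failure mode, not part of this PR; the kernel name `_store_row` and the tensor sizes are illustrative only:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _store_row(Out, row, outN, BLOCK_N: tl.constexpr):
    # Fill one row of a (rows, outN) tensor with ones.
    offs_n = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    n_mask = offs_n < outN
    vals = tl.full([BLOCK_N], 1.0, tl.float16)
    # Without the cast, `row * outN` is evaluated in int32 and wraps to a
    # negative offset once row * outN >= 2**31. Casting `row` promotes the
    # rest of the offset arithmetic to int64 via integer promotion.
    tl.store(Out + row.to(tl.int64) * outN + offs_n, vals, mask=n_mask)


# Needs a CUDA GPU with roughly 4.3 GB free: 2049 rows of 2**20 fp16 elements
# put the last row's starting offset (2048 * 2**20 == 2**31) past int32 range.
outN, rows = 1 << 20, 2049
out = torch.zeros(rows, outN, device="cuda", dtype=torch.float16)
_store_row[(triton.cdiv(outN, 1024),)](out, rows - 1, outN, BLOCK_N=1024)
assert bool(out[-1].eq(1.0).all())
```

Casting only `row` is sufficient: once one operand is int64, the multiply and the subsequent add are carried out in 64-bit, so the other operands can stay int32.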