Skip to content

Commit 05b2c18

Browse files
authored
[TRITON_KERNELS] cast index to int64 in finalize_matmul (#7794)
The product of `row * outN` could overflow in int32. To avoid this, we perform the offset multiplication in int64 instead
1 parent f33bcbd commit 05b2c18

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/_finalize_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def _finalize_matmul(
280280
for off_n in tl.range(tl.program_id(1) * OUT_BLOCK_N, outN, tl.num_programs(1) * OUT_BLOCK_N):
281281
offs_n = off_n + tl.arange(0, OUT_BLOCK_N)
282282
n_mask = offs_n < outN
283-
tl.store(Out + row * outN + offs_n, tl.zeros([OUT_BLOCK_N], dtype=Out.dtype.element_ty), mask=n_mask)
283+
tl.store(Out + row.to(tl.int64) * outN + offs_n, tl.zeros([OUT_BLOCK_N], dtype=Out.dtype.element_ty), mask=n_mask)
284284
else:
285285
for off_n in tl.range(tl.program_id(1) * BLOCK_N, N, tl.num_programs(1) * BLOCK_N, num_stages=STAGES):
286286
offs_n = off_n + tl.arange(0, BLOCK_N)
@@ -346,7 +346,7 @@ def _finalize_matmul(
346346
pid=row * tl.num_programs(1) + tl.program_id(1))
347347
tl.static_assert(OUT_BLOCK_N % OUT_MX_SCALE_BLOCK_N == 0, "")
348348
tl.store(OutActualScale + row * stride_out_mx_m + offs_n_scale * stride_out_mx_n, acc_scale, mask=n_mask_scale)
349-
tl.store(Out + row * outN + offs_n[None, :], acc, mask=n_mask[None, :])
349+
tl.store(Out + row.to(tl.int64) * outN + offs_n[None, :], acc, mask=n_mask[None, :])
350350
else:
351351
out = float_to_flex(out, out_scale if OutExpectedScale is not None else None, None, OutChecksumScale,
352352
None, Out, flexpoint_saturate_inf)
@@ -355,7 +355,7 @@ def _finalize_matmul(
355355
pid=row * tl.num_programs(1) + tl.program_id(1))
356356
offs_n = off_n // ACTIVATION_REDUCTION_N + tl.arange(0, OUT_BLOCK_N)
357357
n_mask = offs_n < outN
358-
tl.store(Out + row * outN + offs_n, out, mask=n_mask)
358+
tl.store(Out + row.to(tl.int64) * outN + offs_n, out, mask=n_mask)
359359

360360
persisent_m = tl.num_programs(0) < MBound
361361
if not persisent_m and n_active_experts == 0:

0 commit comments

Comments
 (0)