Skip to content

Commit e1fb6f6

Browse files
authored
[kernels] restore old behavior that output for tokens routed to zero experts should be zero-initialized (#7150)
triton-lang/triton#7140 introduced a subtle change in the semantics of `matmul_ogs`. We actually care that the output of rows that have scatter_indx==-1 be zero-initialized because some expert parallelism code may reduce them also found some missing mask in the AMD implementation, which most likely explains the test failure.
1 parent 88a2851 commit e1fb6f6

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/_finalize_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def _finalize_matmul(
291291
if src_idx != -1:
292292
As = A + src_idx.to(tl.int64) * stride_a_m + offs_n
293293
for ki in tl.static_range(K):
294-
acc += tl.load(As, mask=n_mask, other=0.0)
294+
acc += tl.load(As, mask=(src_idxs != -1)[:, None] & n_mask[None, :], other=0.0)
295295
As += stride_a_k
296296
else:
297297
As = A + src_idxs.to(tl.int64)[:, None] * stride_a_m + offs_n[None, :]

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def _compute_writeback_idx(
387387
is_src_active = (src_idxs != -1).to(tl.int32)
388388
num_src_active = tl.sum(is_src_active, axis=1)
389389

390-
need_finalize_scatter = mask_m & (num_src_active > 1)
390+
need_finalize_scatter = mask_m & (num_src_active != 1)
391391
finalize_scatter_count = tl.sum(need_finalize_scatter.to(tl.int32))
392392
if finalize_scatter_count == 0:
393393
return

0 commit comments

Comments
 (0)