Skip to content

Commit 7be5b8a

Browse files
authored
[triton_kernels][matmul] skip some unnecessary compute (#7140)
1 parent d0c65f9 commit 7be5b8a

File tree

3 files changed

+13 -6 lines changed

python/triton_kernels/tests/test_matmul.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -414,7 +414,15 @@ def round_x(x, idx):
414414
rdata, gindx, sindx, round_x=round_x, round_y=round_y, gammas=gs1_ref)
415415
scale = lambda val, scal: val if scal is None else val / scal
416416
if n_expt_shards > 1:
417-
if not do_scatter:
417+
if do_scatter:
418+
indx = sindx.dst_indx[sindx.dst_indx != -1]
419+
ref_y = ref_y[indx // n_expts_act, :]
420+
if act_is_float8:
421+
tri_y = tri_y.view(torch.int8)
422+
tri_y = tri_y[indx // n_expts_act, :]
423+
if act_is_float8:
424+
tri_y = tri_y.view(act_dtype)
425+
else:
418426
n_rows = rdata.expt_hist.sum()
419427
assert n_rows > 0
420428
ref_y = ref_y[:n_rows]

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -574,12 +574,12 @@ def init_allocation(x, w, precision_config, fused_activation, routing_data, gath
574574
def apply_allocation(allocation: MatmulAllocation, output):
575575
ret = dict()
576576
if output is None:
577-
output = torch.zeros(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
577+
output = torch.empty(allocation.output[0], device=allocation.device, dtype=allocation.output[1])
578578
else:
579579
assert output.shape == allocation.output[0]
580580
ret["output"] = output[None, :, :]
581581
ret["scratchpad"] = {
582-
k: torch.zeros(v[0], device=allocation.device, dtype=v[1])
582+
k: torch.empty(v[0], device=allocation.device, dtype=v[1])
583583
for k, v in allocation.scratchpads.items()
584584
}
585585
return ret
@@ -837,7 +837,6 @@ def matmul_ogs(x, w, bias,
837837
out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
838838
num_indx, precision_config, routing_data,
839839
postprocessing_features, memory, fused_postprocess_activation, epilogue)
840-
841840
# remove split-k
842841
out = out.squeeze(0)
843842
if not is_input_batched:

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -385,9 +385,9 @@ def _compute_writeback_idx(
385385
src_offs = offs_m[:, None] * N_EXPTS_ACT + tl.arange(0, N_EXPTS_ACT)[None, :]
386386
src_idxs = tl.load(ScatterSrcIndx + src_offs, mask=mask_m[:, None], other=-1)
387387
is_src_active = (src_idxs != -1).to(tl.int32)
388-
has_one_active = tl.sum(is_src_active, axis=1) == 1
388+
num_src_active = tl.sum(is_src_active, axis=1)
389389

390-
need_finalize_scatter = mask_m & (~has_one_active)
390+
need_finalize_scatter = mask_m & (num_src_active > 1)
391391
finalize_scatter_count = tl.sum(need_finalize_scatter.to(tl.int32))
392392
if finalize_scatter_count == 0:
393393
return

0 commit comments

Comments
 (0)