Commit abd3bb0

yongjik and Jokeren authored
Fix launch metadata (memory bandwidth) for matmul_ogs. (#6890)
The previous code wasn't correct if only one of GatherIndx or WriteBackIndx was given: it would assume that the full tensor was read or written on the side without the index.

Co-authored-by: Keren Zhou <[email protected]>
1 parent 975446e commit abd3bb0
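For context, a minimal sketch of the accounting this change aims for (toy shapes and a hypothetical helper name, not the kernel's actual arguments): each side's traffic is derived from the rows actually read or written on that side, instead of assuming the full tensor whenever that side has no index.

# Toy illustration (made-up shapes, hypothetical helper; not the kernel's real args).
def expected_bytes(n_x_rows_read, n_y_rows_written, k, n, elem_size, n_w_bytes):
    n_x_bytes = n_x_rows_read * k * elem_size      # only rows actually gathered from X
    n_y_bytes = n_y_rows_written * n * elem_size   # only rows actually written to Y
    return n_x_bytes + n_y_bytes + n_w_bytes

# E.g. 100 tokens routed, K=512, N=1024, fp16 activations, 4 MiB of expert weight bytes:
print(expected_bytes(100, 100, 512, 1024, 2, 4 << 20))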

File tree: 2 files changed (+80, -7 lines)

python/triton_kernels/tests/test_matmul.py

Lines changed: 56 additions & 2 deletions
@@ -2,6 +2,7 @@
 import pytest
 import torch
 from typing import Union
+import triton
 # routing utilities
 from triton_kernels.routing import routing
 # matmul utilities
@@ -243,6 +244,9 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         # Automatic padding not implemented for Hopper swizzle
         pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")

+    # launch metadata for batched / mx types may not work yet.
+    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
+
     torch.manual_seed(0)

     block_k = None
@@ -314,8 +318,48 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas

     if w_tri.shape[0] == 1:
         # Test the case when weight has dim 2, i.e., shape (K, N).
-        w_tri = w_tri.squeeze(0).detach().requires_grad_()
-        w_ref = w_ref.squeeze(0).detach().requires_grad_()
+        w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
+        w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
+
+    if test_launch_metadata:
+
+        def _clobber(t, used_mask):
+            # Fill the unread part of the tensor with garbage, to be sure that
+            # we don't actually read from the part.
+            if len(used_mask) == 1:
+                return
+            elif t.element_size() == 1:
+                t.view(torch.int8)[~used_mask] = 127
+            else:
+                t[~used_mask] = torch.inf
+
+        if rdata is not None:
+            n_tokens = rdata.expt_hist.sum().item()
+            used_expts = (rdata.expt_hist > 0)
+            _clobber(w_tri, used_expts)
+            n_w_bytes = used_expts.sum().item() * n * k * w_tri.element_size()
+        else:
+            n_tokens = m
+            n_w_bytes = w_tri.numel() * w_tri.element_size()
+
+        if gindx is not None:
+            used_x_rows = (gindx.dst_indx.view(-1, n_expts_act) != -1).any(dim=1)
+            _clobber(x_tri, used_x_rows)
+            n_x_bytes = used_x_rows.sum().item() * k * x_tri.element_size()
+        elif rdata is not None:
+            n_x_bytes = n_tokens * k * x_tri.element_size()
+        else:
+            n_x_bytes = x_tri.numel() * x_tri.element_size()
+
+        nbytes = None
+
+        def _hook(launch_metadata):
+            nonlocal nbytes
+            metadata = launch_metadata.get()
+            if "matmul_ogs" in metadata["name"]:
+                nbytes = metadata["bytes"]
+
+        triton.knobs.runtime.launch_enter_hook = _hook

     if mode == "batched":
         rdata, gindx, sindx = None, None, None
@@ -327,6 +371,16 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
     y_scale = flex.out_data.expected_scale if act_is_float8 else 1

+    if test_launch_metadata:
+        if gindx is not None:
+            n_y_bytes = (gindx.src_indx != -1).sum().item() * n * tri_y.element_size()
+        elif rdata is not None:
+            n_y_bytes = n_tokens * n * tri_y.element_size()
+        else:
+            n_y_bytes = tri_y.numel() * tri_y.element_size()
+        assert nbytes == n_x_bytes + n_y_bytes + n_w_bytes
+        triton.knobs.runtime.launch_enter_hook = None
+
     def round_x(x, idx):
         return x.to(act_dtype).to(torch.float32) if sep_gather else x
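The _clobber helper in the test relies on a simple trick: poison the rows the kernel is not supposed to touch, so any unintended read corrupts the output and fails the numeric check. A minimal standalone sketch of the idea (toy tensor and mask, not the test's data):

import torch

x = torch.randn(6, 4)
used_rows = torch.tensor([True, False, True, True, False, True])

# Poison the rows that should never be read.
x[~used_rows] = torch.inf

# A correct consumer only touches the used rows, so the result stays finite.
result = x[used_rows].sum()
assert torch.isfinite(result)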

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 24 additions & 5 deletions
@@ -1,3 +1,5 @@
+import torch
+
 import triton
 import triton.language as tl

@@ -87,10 +89,27 @@ def matmul_launch_metadata(grid, kernel, args):
     fM = M if M is not None else n_tokens
     fK = K if K is not None else n_tokens
     ret[f"flops{nbits}"] = 2.0 * fM * N * fK
+
     gindx = args.get("GatherIndx", None)
-    sindx = args.get("WriteBackIndx", None)
-    sskipped = 0. if sindx is None else (sindx == -1).sum() / sindx.shape[0]
-    gskipped = 0. if gindx is None else (gindx == -1).sum() / gindx.shape[0]
-    ret["bytes"] = int((1 - sskipped) * Y.numel() * Y.element_size() + (1 - gskipped) * X.numel() * X.element_size() +
-                       n_w_bytes)
+    # sindx = args.get("WriteBackIndx", None)
+    n_x_bytes = X.numel() * X.element_size()
+    n_y_bytes = Y.numel() * Y.element_size()
+    if hist is not None:
+        assert X.shape[0] == Y.shape[0] == 1, "batched mode not supported"
+        assert n_tokens is not None
+        n_expts_act = args["N_EXPTS_ACT"]
+
+        if gindx is not None:
+            # recreate inverse GatherIndx.
+            dst = torch.full_like(gindx, -1)
+            idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
+            mask = (gindx != -1)
+            dst[gindx[mask]] = idx[mask]
+            n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
+        else:
+            n_read_rows = n_tokens
+        n_x_bytes = n_read_rows * X.shape[-1] * X.element_size()
+        n_y_bytes = n_tokens * Y.shape[-1] * Y.element_size()
+    ret["bytes"] = int(n_x_bytes + n_y_bytes + n_w_bytes)
+
     return ret
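The index inversion above can be seen in isolation with a tiny made-up example (values and n_expts_act are hypothetical; this only illustrates the mechanics of recreating the inverse index and counting rows that are actually read):

import torch

# gindx[i] = source row of X read by gathered slot i, or -1 if the slot is unused.
gindx = torch.tensor([2, -1, 0, 2], dtype=torch.int32)

dst = torch.full_like(gindx, -1)
idx = torch.arange(len(gindx), dtype=torch.int32)
mask = gindx != -1
dst[gindx[mask]] = idx[mask]         # dst[src_row] = one slot that reads it, else -1
# dst == tensor([2, -1, 3, -1]): rows 0 and 2 are read, rows 1 and 3 are not.

n_expts_act = 2                      # hypothetical number of active experts per token
n_read_rows = (dst.view(-1, n_expts_act) != -1).any(dim=1).sum()
# Groups of n_expts_act consecutive entries count as read if any slot is valid.
print(n_read_rows)                   # tensor(2)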
