Skip to content

Commit 1ab9f65

Browse files
authored
[KERNELS] Add an option to avoid device sync on launch_metadata. (#7296)
1 parent 9182231 commit 1ab9f65

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,37 @@ def convert_dtype(dtype):
7878

7979

8080
def matmul_launch_metadata(grid, kernel, args):
81+
from ..proton_opts import launch_metadata_allow_sync
82+
8183
ret = dict()
8284
M, N, K = args["M"], args["N"], args["K"]
8385
Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]]
86+
tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
8487
hist = args["ExptHist"]
8588
if hist is not None:
86-
n_tokens = float(hist.sum())
87-
n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
89+
# If annotation is given, use that to generate name for profiling.
90+
if tokens_per_expt is not None:
91+
n_rows = f"{tokens_per_expt}*"
92+
elif launch_metadata_allow_sync():
93+
n_rows = int(hist.float().mean())
94+
else:
95+
n_rows = "unknown"
96+
97+
if launch_metadata_allow_sync():
98+
n_tokens = float(hist.sum())
99+
n_w_bytes = (W.numel() * W.element_size() // hist.numel()) * (hist > 0).sum()
100+
elif tokens_per_expt is not None:
101+
n_tokens = tokens_per_expt * args["N_EXPTS_TOT"]
102+
# This may not be totally correct (e.g., we might not be using all experts)
103+
# but it's better than nothing.
104+
n_w_bytes = W.numel() * W.element_size()
105+
else:
106+
n_tokens = None
107+
n_w_bytes = 0
88108

89109
# If annotation is given, use that to generate name for profiling.
90110
tokens_per_expt = args.get("TOKENS_PER_EXPT_FOR_ANNOTATION")
91-
n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else int(hist.float().mean())
111+
n_rows = f"{tokens_per_expt}*" if tokens_per_expt is not None else n_rows
92112
else:
93113
n_tokens = None
94114
n_w_bytes = W.numel() * W.element_size()
@@ -101,6 +121,10 @@ def matmul_launch_metadata(grid, kernel, args):
101121
ep_subtile = args["EPILOGUE_SUBTILE"]
102122
if ep_subtile is not None and ep_subtile > 1:
103123
ret["name"] += f" ep/{ep_subtile}"
124+
125+
if hist is not None and n_tokens is None:
126+
return ret # Don't fill metadata because we can't compute them properly.
127+
104128
fM = M if M is not None else n_tokens
105129
fK = K if K is not None else n_tokens
106130
ret[f"flops{nbits}"] = 2.0 * fM * N * fK
@@ -115,7 +139,7 @@ def matmul_launch_metadata(grid, kernel, args):
115139
assert n_tokens is not None
116140
n_expts_act = args["N_EXPTS_ACT"]
117141

118-
if gindx is not None:
142+
if (gindx is not None) and launch_metadata_allow_sync():
119143
# recreate inverse GatherIndx.
120144
dst = torch.full_like(gindx, -1)
121145
idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def _zero_masked_rows(
2929

3030

3131
_matmul_ogs_repr = make_matmul_repr("_matmul_ogs", [0, 1, 2])
32-
@triton.jit(repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
32+
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
33+
repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
3334
def _matmul_ogs(
3435
Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
3536
YExpectedScale, YActualScale, YChecksumScale,

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ def _load_writeback_idx_and_mask(WriteBackIndx, writeback_size, offs, mask):
9696

9797

9898
_matmul_ogs_repr = make_matmul_repr("_p_matmul_ogs", [0, 1, 2])
99-
@triton.jit(repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
99+
@triton.jit(do_not_specialize=["TOKENS_PER_EXPT_FOR_ANNOTATION"],
100+
repr=_matmul_ogs_repr, launch_metadata=matmul_launch_metadata)
100101
def _p_matmul_ogs(
101102
Y, Out, stride_y_k, stride_y_z, stride_y_m, stride_y_n,
102103
YExpectedScale, YActualScale, YChecksumScale,
python/triton_kernels/triton_kernels/proton_opts.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# proton options
2+
3+
import os
4+
5+
# Cached tri-state flag: None means "not resolved yet"; otherwise the value
# returned by launch_metadata_allow_sync().
_launch_metadata_allow_sync = None


def launch_metadata_allow_sync() -> bool:
    """Return whether launch-metadata code is allowed to synchronize.

    On the first call (while the cached module-level flag is still ``None``)
    the value is derived from the ``PROTON_LAUNCH_METADATA_NOSYNC``
    environment variable: syncing is allowed unless that variable is exactly
    ``"1"``. The result is cached, so later changes to the environment
    variable have no effect; use ``set_launch_metadata_allow_sync`` to
    override the cached value.
    """
    global _launch_metadata_allow_sync
    if _launch_metadata_allow_sync is None:
        # Only the literal "1" disables syncing; unset or any other value
        # keeps the default (sync allowed).
        _launch_metadata_allow_sync = os.getenv("PROTON_LAUNCH_METADATA_NOSYNC") != "1"
    return _launch_metadata_allow_sync
13+
14+
15+
def set_launch_metadata_allow_sync(allow_sync: bool) -> None:
    """Override the cached "allow sync" flag for launch metadata.

    The value is stored in the module-level ``_launch_metadata_allow_sync``
    flag and takes precedence over the ``PROTON_LAUNCH_METADATA_NOSYNC``
    environment variable, which is only consulted while the flag is ``None``.

    Args:
        allow_sync: Value that ``launch_metadata_allow_sync`` should return
            from now on.
    """
    global _launch_metadata_allow_sync
    _launch_metadata_allow_sync = allow_sync

0 commit comments

Comments
 (0)