
Commit 0f5eccc

aeng-openai and ptillet authored
[BENCH] fused swiglu activation fn (#6797)
Fuse the activation function into the epilogue of the matmul, largely for free. Reopens and rebases #6756.

Co-authored-by: Phil Tillet <[email protected]>
1 parent f5274d4 commit 0f5eccc
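
In short: the first matmul of the MLP now applies SwiGLU inside its epilogue instead of launching triton_kernels.swiglu.swiglu as a separate kernel afterwards, so the intermediate matmul output no longer makes an extra round trip through global memory. A minimal sketch of the call-site change, with argument values taken verbatim from the bench_mlp.py diff below (the trailing 2 is the activation's reduction_n, discussed after that diff):

    # Before: matmul followed by a standalone swiglu kernel.
    pcs = triton_kernels.swiglu.PrecisionConfig(limit=1.0)
    x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
    x = triton_kernels.swiglu.swiglu(x, 1.0, pcs, routing_data=rdata)

    # After: swiglu is evaluated in the matmul epilogue before the result is stored.
    act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.0, 1.0), 2)
    x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)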

File tree

10 files changed (+252, -85 lines changed)


python/triton_kernels/bench/bench_mlp.py

Lines changed: 3 additions & 4 deletions
@@ -7,7 +7,7 @@
 import triton_kernels
 import triton_kernels.swiglu
 from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, SwizzlingType
-from triton_kernels.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
+from triton_kernels.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
 from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import routing
 from triton_kernels.target_info import is_hip, get_cdna_version
@@ -143,7 +143,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
     w1, w1_flex, w1_mx = quantize(w1, w_dtype, dev, **opt1)
     w2, w2_flex, w2_mx = quantize(w2, w_dtype, dev, **opt2)
     pcg = PrecisionConfig(mx_ctx=wg_mx, flex_ctx=FlexCtx(rhs_data=wg_flex))
-    pcs = triton_kernels.swiglu.PrecisionConfig(limit=1.0)
+    act = FusedActivation(FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")), (1.0, 1.0), 2)
     pc1 = PrecisionConfig(mx_ctx=w1_mx, flex_ctx=FlexCtx(rhs_data=w1_flex))
     pc2 = PrecisionConfig(mx_ctx=w2_mx, flex_ctx=FlexCtx(rhs_data=w2_flex))

@@ -166,8 +166,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP,
         rdata, gather_indx, scatter_indx = routing(logits, n_expts_act, simulated_ep=EP)
     else:
         rdata, gather_indx, scatter_indx = None, None, None
-    x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1)
-    x = triton_kernels.swiglu.swiglu(x, 1.0, pcs, routing_data=rdata)
+    x = matmul_ogs(x, w1, b1, rdata, gather_indx=gather_indx, precision_config=pc1, fused_activation=act)
     x = matmul_ogs(x, w2, b2, rdata, scatter_indx=scatter_indx, precision_config=pc2)
     proton.finalize()

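FusedActivation (together with FnSpecs) is the new user-facing handle for the fused epilogue. A minimal sketch of how it is assembled, written with keyword arguments for clarity; the field comments describe how matmul_ogs.py (further down) consumes each field:

    from triton_kernels.matmul_ogs import FnSpecs, FusedActivation
    from triton_kernels.swiglu import swiglu_fn

    act = FusedActivation(
        specs=FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),  # cache name, JIT function, names of its extra scalar args
        fn_args=(1.0, 1.0),  # values bound positionally to ("alpha", "limit") at kernel launch
        reduction_n=2,       # swiglu folds 2 matmul output columns into 1 result column,
                             # so init_allocation sizes the output as N // reduction_n
    )
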
python/triton_kernels/tests/test_matmul.py

Lines changed: 72 additions & 2 deletions
@@ -7,9 +7,10 @@
 from triton_kernels.routing import routing
 # matmul utilities
 import triton_kernels.matmul_ogs_details.opt_flags as opt_flags
-from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx
+from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig, MicroscalingCtx, FusedActivation, FnSpecs
 from triton_kernels.matmul_ogs import can_use_persistent_tma
 from triton_kernels.matmul_ogs import matmul_ogs, matmul_ogs_torch
+from triton_kernels.swiglu import swiglu, swiglu_fn, PrecisionConfig as SwiGLUPrecisionConfig
 # numerics utilities
 from triton_kernels.numerics import InFlexData, OutFlexData
 from triton_kernels.numerics_details.mxfp import SwizzlingType, downcast_to_mxfp, upcast_from_mxfp
@@ -122,6 +123,13 @@ def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
     return torch.uint8 if dtype_str == "float4_e2m1" else getattr(torch, dtype_str)


+# Scope to ensure that the opt_flags_constraints are reset after the test
+@pytest.fixture
+def opt_flags_scope(request):
+    yield
+    opt_flags.reset_opt_flags_constraints()
+
+
 # ---------------
 # unit tests
 # ---------------
@@ -218,7 +226,7 @@ class Case:
 @pytest.mark.parametrize("is_persistent", [False, True])
 def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas, is_persistent, n_expts_tot,
             n_expts_act, n_expt_shards, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, epilogue_subtile,
-            device):
+            device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
     if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Float8 not tested on A100")
@@ -401,3 +409,65 @@ def round_x(x, idx):
     ref_y_scale = compute_actual_scale(ref_y, tri_y.dtype)
     assert (ref_y_scale -
             tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"
+
+
+@pytest.mark.parametrize("m, n, k, mode", [
+    (1200, 704, 608, "ragged"),
+    (800, 800, 400, "batched"),
+])
+@pytest.mark.parametrize("split_k", [1, 2])
+@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
+    (False, False, False),
+    (True, False, False),
+    (False, True, False),
+    (True, True, False),
+    (True, True, True),
+])
+@pytest.mark.parametrize("is_persistent, epilogue_subtile", [
+    (False, False),
+    (True, False),
+    (True, True),
+])
+@pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [
+    (1.1, 1.4),
+    (1.0, 1.2),
+    (0.7, 1.0),
+])
+def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile,
+                   swiglu_alpha, swiglu_limit, device, opt_flags_scope):
+    if fused_scatter and split_k > 1:
+        pytest.skip("fused scatter scratchpad not supported with split_k")
+    torch.manual_seed(0)
+    constraints = {
+        "is_persistent": is_persistent,
+        "epilogue_subtile": epilogue_subtile,
+        "fused_scatter": fused_scatter,
+        "split_k": split_k,
+    }
+    n_expts_tot, n_expts_act, n_expt_shards = 1, 1, 1
+    opt_flags.update_opt_flags_constraints(constraints)
+
+    weight_dtype, act_dtype = torch.float16, torch.float16
+    if mode == "ragged":
+        m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, n_expt_shards, do_gather, do_scatter,
+                                                   device=device)
+    else:
+        rdata = gindx = sindx = None
+
+    precision_opt = init_precision(act_dtype, False, False, n_expts_tot // n_expt_shards, device=device)
+    x, w, bias, _, _ = init_compute_data(m, n, k, gindx, sindx, n_expts_tot, n_expts_act, n_expt_shards, mode,
+                                         act_dtype, weight_dtype, False, requires_grad=False, device=device)
+
+    if is_persistent and not can_use_persistent_tma(x.view(1, x.shape[-2], x.shape[-1]),
+                                                    w.view(1, w.shape[-2], w.shape[-1]), gindx, precision_opt):
+        pytest.skip("persistent TMAs not supported for this test")
+
+    if mode == "batched":
+        rdata, gindx, sindx = None, None, None
+    a = swiglu(matmul_ogs(x, w, bias, rdata, gindx, sindx, precision_opt), swiglu_alpha,
+               precision_config=SwiGLUPrecisionConfig(swiglu_limit))
+    b = matmul_ogs(
+        x, w, bias, rdata, gindx, sindx, precision_opt,
+        fused_activation=FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (swiglu_alpha, swiglu_limit),
+                                         2))
+    assert_close(a, b)

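The new opt_flags_scope fixture exists because update_opt_flags_constraints mutates module-level state in opt_flags: without a teardown, the constraints pinned by one test (split_k, is_persistent, and so on) would leak into every test that runs afterwards in the same process. A rough sketch of the pattern the fixture wraps, using only the two opt_flags calls that appear in this diff:

    import triton_kernels.matmul_ogs_details.opt_flags as opt_flags

    opt_flags.update_opt_flags_constraints({"split_k": 2})  # global: pins the flag for subsequent launches
    try:
        ...  # the parametrized test body runs here
    finally:
        opt_flags.reset_opt_flags_constraints()  # what the fixture's teardown performs
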
python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 49 additions & 23 deletions
@@ -20,42 +20,62 @@


 @dataclass
-class EpilogueSpecs:
+class FnSpecs:
     name: str
     fn: "triton.runtime.jit.JITFunction"
     fn_arg_names: tuple[str]
     fn_arg_do_not_specialize: tuple[str] = tuple()

+    @staticmethod
+    def default():
+        return FnSpecs("dflt", None, tuple())
+
+
+@dataclass
+class FusedActivation:
+    specs: FnSpecs
+    fn_args: tuple[object]
+    reduction_n: int
+

 @dataclass
 class Epilogue:
-    specs: EpilogueSpecs
+    specs: FnSpecs
     fn_arg_values_matmul: tuple[object]
     fn_arg_values_finalize: tuple[object]
     is_expensive: bool = False


+EpilogueSpecs = FnSpecs  # TODO: remove this alias when callers are updated
+
 _kernels = dict()


-def get_kernels(epilogue: EpilogueSpecs):
+def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
     global _kernels
-    if epilogue.name in _kernels:
-        return _kernels[epilogue.name]
-    spec_constants = {"EPILOGUE_FN": epilogue.fn}
-    spec_tuples = {"epilogue_fn_args": epilogue.fn_arg_names}
-    do_not_specialize = epilogue.fn_arg_do_not_specialize
+    key = (fused_activation.name, epilogue.name)
+    if key in _kernels:
+        return _kernels[key]
+    spec_constants = {
+        "ACTIVATION_FN": fused_activation.fn,
+        "EPILOGUE_FN": epilogue.fn,
+    }
+    spec_tuples = {
+        "activation_fn_args": fused_activation.fn_arg_names,
+        "epilogue_fn_args": epilogue.fn_arg_names,
+    }
+    do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
     import types

-    module = types.ModuleType(f"matmul_ogs_{epilogue.name}")
+    module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
     sys.modules[module.__name__] = module
     module._finalize_matmul = specialize(_finalize_matmul, module, spec_constants, spec_tuples,
                                          do_not_specialize=do_not_specialize)
     module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples,
                                     do_not_specialize=do_not_specialize)
     module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples,
                                       do_not_specialize=do_not_specialize)
-    _kernels[epilogue.name] = module
+    _kernels[key] = module
     return module


@@ -254,8 +274,8 @@ def can_use_persistent_tma(x, w, gather_indx, precision_config):
         and mx_ctx.swizzle_value is None
     )

-def can_use_fused_scatter(scatter_indx):
-    return scatter_indx is not None
+def can_use_fused_scatter(scatter_indx, fused_activation):
+    return scatter_indx is not None and fused_activation.specs.fn is None

 # ---------------------
 # Preprocessing
@@ -341,7 +361,7 @@ def init_postprocessing_features(routing_data, scatter_indx, opt_flags):
     return PostprocessingFeatures(finalize)

 def apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_offs, num_indx, precision_config, routing_data,
-                                  postprocess_features, memory, epilogue):
+                                  postprocess_features, memory, fused_activation, epilogue):
     out = memory["output"]
     flex_ctx = precision_config.flex_ctx
     if postprocess_features.finalize:
@@ -407,14 +427,15 @@ def compute_grid(BLOCK_N, num_warps):
     grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
     STAGES = 1 if num_warps == 1 else min(triton.cdiv(triton.cdiv(N, BLOCK_N), grid[1]), 5)

-    kernels = get_kernels(epilogue.specs)
+    kernels = get_kernels(epilogue.specs, fused_activation.specs)
     kernels._finalize_matmul[grid](
         flex_ctx.out_data.reinterpret(out_scatter),
         *out_scatter_flex,
         flex_ctx.out_data.reinterpret(inp), inp.stride(0), inp.stride(2),
         inp_flex.expected_scale,
         scatter_src_indx, finalize_scatter_idxs,
         inp.shape[0], M, N, num_rows,
+        *fused_activation.fn_args, fused_activation.reduction_n,
         *epilogue.fn_arg_values_finalize,
         EXPT_PER_TOK=EXPT_PER_TOK,
         BLOCK_N=BLOCK_N,
@@ -443,7 +464,7 @@ class MatmulAllocation:
     output: tuple[tuple[int], torch.dtype]
     scratchpads: dict[str, tuple]

-def init_allocation(x, w, precision_config, routing_data, gather_indx, scatter_indx, opt_flags,
+def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags,
                     preprocessing_features, postprocessing_features):
     # ---- output ------
     N = precision_config.mx_ctx.get_packed_tensor_logical_shape(w)[-1]
@@ -462,7 +483,7 @@ def init_allocation(x, w, precision_config, routing_data, gather_indx, scatter_i
     else:
         Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act  # compressed number of rows
         y_rows = Mc
-    y_shape = (x.shape[0], y_rows, N)
+    y_shape = (x.shape[0], y_rows, N // fused_activation.reduction_n)
     out_dtype = precision_config.out_dtype or x.dtype
     output = (y_shape, out_dtype)
     # ---- scratchpad -----#
@@ -500,6 +521,7 @@ def matmul_ogs(x, w, bias,
                gammas: torch.Tensor | None = None,
                out_alpha: float | None = None,
                y: torch.Tensor | None = None,
+               fused_activation: FusedActivation | None = None,
                epilogue: Epilogue | None = None,
                ):
     """
@@ -516,9 +538,10 @@ def matmul_ogs(x, w, bias,
     assert w.ndim == 3 and w.shape[0] == x.shape[0]
     if precision_config is None:
         precision_config = PrecisionConfig()
+    if fused_activation is None:
+        fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
     if epilogue is None:
-        epilogue_specs = EpilogueSpecs("dflt", None, tuple(), tuple())
-        epilogue = Epilogue(epilogue_specs, tuple(), tuple(), False)
+        epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
     if w.ndim == 2:
         w = w.view(1, w.shape[-2], w.shape[-1])
     if x.ndim == 2:
@@ -540,7 +563,7 @@ def matmul_ogs(x, w, bias,
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                                M, N, K, routing_data,
                                can_use_persistent_tma(x, w, gather_indx, precision_config),
-                               can_use_fused_scatter(scatter_indx),
+                               can_use_fused_scatter(scatter_indx, fused_activation),
                                epilogue.is_expensive,
                                )
     # compute grid size
@@ -551,25 +574,27 @@ def matmul_ogs(x, w, bias,
     grid_n = triton.cdiv(N, opt_flags.block_n)
     assert n_expts_tot == routing_data.n_expts_tot
     assert grid_m > 0
-    assert x.dtype == w.dtype or mx_ctx.weight_scale is not None
     # determine necessary pre/post processing
     preprocessing_features = init_preprocessing_features(w, precision_config, opt_flags)
     postprocessing_features = init_postprocessing_features(routing_data, scatter_indx, opt_flags)
     # allocate output/scratchpad memory
-    allocation = init_allocation(x, w, precision_config, routing_data, gather_indx, scatter_indx, opt_flags,
+    allocation = init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags,
                                  preprocessing_features, postprocessing_features)
     memory = apply_allocation(allocation, y)
     # TMA descriptors require a global memory allocation
     if opt_flags.is_persistent:
         triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
     # Intermediate tensors and postprocess kernels for each situation
     out0, out0_flex = memory["output"], precision_config.flex_ctx.out_data
+    fused_postprocess_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
     if postprocessing_features.finalize:
         if opt_flags.fused_scatter:
             out0 = memory["output"]
         else:
             out0 = memory["scratchpad"]["matmul"]
         out0_flex = OutFlexData() if out0.dtype == torch.float32 else precision_config.flex_ctx.out_data
+
+        fused_activation, fused_postprocess_activation = fused_postprocess_activation, fused_activation
     # pre-processing
     x, w, swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs, expt_data = apply_preprocessing_features(
         x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features
@@ -584,7 +609,7 @@ def matmul_ogs(x, w, bias,
     flex = precision_config.flex_ctx
     bias_stride = None if bias is None else bias.stride(0)
     num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
-    kernels = get_kernels(epilogue.specs)
+    kernels = get_kernels(epilogue.specs, fused_activation.specs)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
         flex.out_data.reinterpret(memory["output"]),
         flex.out_data.reinterpret(out0), *out0.stride(),
@@ -606,6 +631,7 @@ def matmul_ogs(x, w, bias,
         expt_data.hist, expt_data.offs, expt_data.offs_sum, expt_data.blocks,
         batch_size, grid_m, grid_n,
         out_alpha,
+        *fused_activation.fn_args, fused_activation.reduction_n,
         *epilogue.fn_arg_values_matmul,
         routing_data.n_expts_tot, routing_data.n_expts_act,
         precision_config.max_num_imprecise_acc,
@@ -635,7 +661,7 @@ def matmul_ogs(x, w, bias,
     # post-processing
     out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_data.offs,
                                         num_indx, precision_config, routing_data,
-                                        postprocessing_features, memory, epilogue)
+                                        postprocessing_features, memory, fused_postprocess_activation, epilogue)

     # remove split-k
     out = out.squeeze(0)

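Two behavioural notes on the matmul_ogs.py changes: specialized kernel modules are now cached per (activation name, epilogue name) pair instead of per epilogue alone, and when the scatter-finalize postprocessing path is taken the activation is swapped out of the matmul kernel and applied in the finalize kernel instead (the fused_postprocess_activation swap), which is also why can_use_fused_scatter now rejects the fused-scatter path whenever an activation is present. A minimal sketch of the caching behaviour, assuming only the names introduced in this diff:

    from triton_kernels.matmul_ogs import FnSpecs, get_kernels
    from triton_kernels.swiglu import swiglu_fn

    swiglu_specs = FnSpecs("swiglu", swiglu_fn, ("alpha", "limit"))

    k_default = get_kernels()                              # cached under the key ("dflt", "dflt")
    k_swiglu = get_kernels(fused_activation=swiglu_specs)  # cached under the key ("swiglu", "dflt")
    assert get_kernels(fused_activation=swiglu_specs) is k_swiglu  # second call hits the cache
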