Skip to content

Commit ade4d3a

Browse files
authored
[KERNELS] decouple matmul_ogs.Epilogue specifications from arguments; fix NUM_SMS value (#6878)
1 parent d457677 commit ade4d3a

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,25 @@
1919

2020

2121
@dataclass
22-
class Epilogue:
22+
class EpilogueSpecs:
2323
name: str
2424
fn: "triton.runtime.jit.JITFunction"
2525
fn_arg_names: tuple[str]
26+
fn_arg_do_not_specialize: tuple[str] = tuple()
27+
28+
29+
@dataclass
30+
class Epilogue:
31+
specs: EpilogueSpecs
2632
fn_arg_values_matmul: tuple[object]
2733
fn_arg_values_finalize: tuple[object]
28-
fn_arg_do_not_specialize: tuple[str] = tuple()
2934
is_expensive: bool = False
3035

3136

3237
_kernels = dict()
3338

3439

35-
def get_kernels(epilogue: Epilogue):
40+
def get_kernels(epilogue: EpilogueSpecs):
3641
global _kernels
3742
if epilogue.name in _kernels:
3843
return _kernels[epilogue.name]
@@ -375,7 +380,7 @@ def compute_grid(BLOCK_N, num_warps):
375380
grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
376381
STAGES = 1 if num_warps == 1 else min(triton.cdiv(triton.cdiv(N, BLOCK_N), grid[1]), 5)
377382

378-
kernels = get_kernels(epilogue)
383+
kernels = get_kernels(epilogue.specs)
379384
kernels._finalize_matmul[grid](
380385
flex_ctx.out_data.reinterpret(out_scatter),
381386
*out_scatter_flex,
@@ -485,7 +490,8 @@ def matmul_ogs(x, w, bias,
485490
if precision_config is None:
486491
precision_config = PrecisionConfig()
487492
if epilogue is None:
488-
epilogue = Epilogue("dflt", None, tuple(), tuple(), tuple(), False)
493+
epilogue_specs = EpilogueSpecs("dflt", None, tuple(), tuple())
494+
epilogue = Epilogue(epilogue_specs, tuple(), tuple(), False)
489495
if w.ndim == 2:
490496
w = w.view(1, w.shape[-2], w.shape[-1])
491497
if x.ndim == 2:
@@ -550,7 +556,7 @@ def matmul_ogs(x, w, bias,
550556
flex = precision_config.flex_ctx
551557
bias_stride = None if bias is None else bias.stride(0)
552558
num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
553-
kernels = get_kernels(epilogue)
559+
kernels = get_kernels(epilogue.specs)
554560
(kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
555561
flex.out_data.reinterpret(memory["output"]),
556562
flex.out_data.reinterpret(out0), *out0.stride(),
@@ -595,7 +601,7 @@ def matmul_ogs(x, w, bias,
595601
UPCAST_INDICES=should_upcast_indices(x, w, out0),
596602
DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0,
597603
SWAP_XW=swap_xw,
598-
NUM_SMS = n_cta,
604+
NUM_SMS = n_cta if opt_flags.is_persistent else 0,
599605
**opt_flags.target_kernel_kwargs)
600606
# post-processing
601607
out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_data.offs,

0 commit comments

Comments
 (0)