|
19 | 19 |
|
20 | 20 |
|
21 | 21 | @dataclass
|
22 |
| -class Epilogue: |
| 22 | +class EpilogueSpecs: |
23 | 23 | name: str
|
24 | 24 | fn: "triton.runtime.jit.JITFunction"
|
25 | 25 | fn_arg_names: tuple[str]
|
| 26 | + fn_arg_do_not_specialize: tuple[str] = tuple() |
| 27 | + |
| 28 | + |
| 29 | +@dataclass |
| 30 | +class Epilogue: |
| 31 | + specs: EpilogueSpecs |
26 | 32 | fn_arg_values_matmul: tuple[object]
|
27 | 33 | fn_arg_values_finalize: tuple[object]
|
28 |
| - fn_arg_do_not_specialize: tuple[str] = tuple() |
29 | 34 | is_expensive: bool = False
|
30 | 35 |
|
31 | 36 |
|
32 | 37 | _kernels = dict()
|
33 | 38 |
|
34 | 39 |
|
35 |
| -def get_kernels(epilogue: Epilogue): |
| 40 | +def get_kernels(epilogue: EpilogueSpecs): |
36 | 41 | global _kernels
|
37 | 42 | if epilogue.name in _kernels:
|
38 | 43 | return _kernels[epilogue.name]
|
@@ -375,7 +380,7 @@ def compute_grid(BLOCK_N, num_warps):
|
375 | 380 | grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
|
376 | 381 | STAGES = 1 if num_warps == 1 else min(triton.cdiv(triton.cdiv(N, BLOCK_N), grid[1]), 5)
|
377 | 382 |
|
378 |
| - kernels = get_kernels(epilogue) |
| 383 | + kernels = get_kernels(epilogue.specs) |
379 | 384 | kernels._finalize_matmul[grid](
|
380 | 385 | flex_ctx.out_data.reinterpret(out_scatter),
|
381 | 386 | *out_scatter_flex,
|
@@ -485,7 +490,8 @@ def matmul_ogs(x, w, bias,
|
485 | 490 | if precision_config is None:
|
486 | 491 | precision_config = PrecisionConfig()
|
487 | 492 | if epilogue is None:
|
488 |
| - epilogue = Epilogue("dflt", None, tuple(), tuple(), tuple(), False) |
| 493 | + epilogue_specs = EpilogueSpecs("dflt", None, tuple(), tuple()) |
| 494 | + epilogue = Epilogue(epilogue_specs, tuple(), tuple(), False) |
489 | 495 | if w.ndim == 2:
|
490 | 496 | w = w.view(1, w.shape[-2], w.shape[-1])
|
491 | 497 | if x.ndim == 2:
|
@@ -550,7 +556,7 @@ def matmul_ogs(x, w, bias,
|
550 | 556 | flex = precision_config.flex_ctx
|
551 | 557 | bias_stride = None if bias is None else bias.stride(0)
|
552 | 558 | num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
|
553 |
| - kernels = get_kernels(epilogue) |
| 559 | + kernels = get_kernels(epilogue.specs) |
554 | 560 | (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
|
555 | 561 | flex.out_data.reinterpret(memory["output"]),
|
556 | 562 | flex.out_data.reinterpret(out0), *out0.stride(),
|
@@ -595,7 +601,7 @@ def matmul_ogs(x, w, bias,
|
595 | 601 | UPCAST_INDICES=should_upcast_indices(x, w, out0),
|
596 | 602 | DISABLE_Y_TMA=out0.stride(-2) * out0.dtype.itemsize % 16 != 0,
|
597 | 603 | SWAP_XW=swap_xw,
|
598 |
| - NUM_SMS = n_cta, |
| 604 | + NUM_SMS = n_cta if opt_flags.is_persistent else 0, |
599 | 605 | **opt_flags.target_kernel_kwargs)
|
600 | 606 | # post-processing
|
601 | 607 | out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_data.offs,
|
|
0 commit comments