
Commit 24863d6 (parent: d141ab8)

[KERNELS] Enable 4-way epilogue subtiling for _p_matmul_ogs. (#7056)

* This uses PR #7044.
* Also streamlined "expensive epilogue" handling so that we can specify how expensive it is.

6 files changed: 63 additions, 36 deletions

python/triton_kernels/tests/test_matmul.py

Lines changed: 7 additions & 6 deletions
@@ -148,7 +148,7 @@ class Case:
     n_expt_shards: int = 1
     split_k: int = 1
     hbm_swizzling: bool = False
-    epilogue_subtile: Union[bool, None] = None
+    epilogue_subtile: Union[int, None] = None


 @pytest.mark.parametrize(
@@ -171,8 +171,9 @@ class Case:
         Case(300, 400, 400, "ragged", "float16", "float16"),
         Case(300, 400, 400, "ragged", "float8_e5m2", "float8_e5m2"),
         Case(1000, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 3, 1),
-        Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=False),
-        Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=True),
+        Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=1),
+        Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=2),
+        Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, epilogue_subtile=4),
         Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2),
         Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 2, n_expt_shards=2),
         Case(600, 400, 400, "ragged", "float8_e5m2", "float8_e5m2", 4, 1, n_expt_shards=2),
@@ -424,9 +425,9 @@ def round_x(x, idx):
     (True, True, True),
 ])
 @pytest.mark.parametrize("is_persistent, epilogue_subtile", [
-    (False, False),
-    (True, False),
-    (True, True),
+    (False, None),
+    (True, 1),
+    (True, 4),
 ])
 @pytest.mark.parametrize("swiglu_alpha, swiglu_limit", [
     (1.1, 1.4),

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 2 additions & 2 deletions
@@ -43,7 +43,7 @@ class Epilogue:
     specs: FnSpecs
     fn_arg_values_matmul: tuple[object]
     fn_arg_values_finalize: tuple[object]
-    is_expensive: bool = False
+    effective_itemsize: float | None = None


 EpilogueSpecs = FnSpecs  # TODO: remove this alias when callers are updated
@@ -564,7 +564,7 @@ def matmul_ogs(x, w, bias,
         M, N, K, routing_data,
         can_use_persistent_tma(x, w, gather_indx, precision_config),
         can_use_fused_scatter(scatter_indx, fused_activation),
-        epilogue.is_expensive,
+        epilogue.effective_itemsize,
     )
     # compute grid size
     if not is_input_batched:
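
The dataclass change replaces a coarse boolean with a quantitative hint. A minimal sketch of how the new field is meant to be read downstream (the helper name acc_itemsize is illustrative; the real consumer is compute_num_stages in opt_flags_nvidia.py, shown later in this diff):

    from dataclasses import dataclass

    @dataclass
    class Epilogue:  # mirror of the dataclass above, trimmed to the new field
        effective_itemsize: float | None = None  # bytes per accumulator element

    def acc_itemsize(epilogue_effective_itemsize, out_itemsize):
        # Matches `epilogue_effective_itemsize or out_dtype.itemsize` in
        # compute_num_stages: None falls back to the output element size.
        return epilogue_effective_itemsize or out_itemsize

    assert acc_itemsize(None, 2) == 2    # plain epilogue: accumulator ~ out dtype
    assert acc_itemsize(4.0, 2) == 4.0   # expensive epilogue: budget 4 bytes/elt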

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 4 additions & 1 deletion
@@ -85,7 +85,10 @@ def matmul_launch_metadata(grid, kernel, args):
     batch_repr = ""
     if "batch_size" in args and args["batch_size"] > 1:
         batch_repr = repr("B", args["batch_size"]) + ", "
-    ret["name"] = f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}]"
+    ret["name"] = f"{kernel.name} [{batch_repr}{repr('M', M)}, {repr('N', N)}, {repr('K', K)}] stg{kernel.num_stages}"
+    ep_subtile = args["EPILOGUE_SUBTILE"]
+    if ep_subtile is not None and ep_subtile > 1:
+        ret["name"] += f" ep/{ep_subtile}"
     fM = M if M is not None else n_tokens
     fK = K if K is not None else n_tokens
     ret[f"flops{nbits}"] = 2.0 * fM * N * fK

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 25 additions & 7 deletions
@@ -199,7 +199,11 @@ def _p_matmul_ogs(
     HAS_FUSED_SCATTER: tl.constexpr = WriteBackIndx is not None
     index_type: tl.constexpr = tl.int64

-    EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // 2 if EPILOGUE_SUBTILE else BLOCK_N
+    if EPILOGUE_SUBTILE is None:
+        SUBTILE_FACTOR: tl.constexpr = 1
+    else:
+        SUBTILE_FACTOR: tl.constexpr = EPILOGUE_SUBTILE
+    EPILOGUE_BLOCK_N: tl.constexpr = BLOCK_N // SUBTILE_FACTOR
     OUT_BLOCK_N: tl.constexpr = EPILOGUE_BLOCK_N // ACTIVATION_REDUCTION_N
     yN = N // ACTIVATION_REDUCTION_N

@@ -500,12 +504,26 @@ def _p_matmul_ogs(
     else:
         w_scale = load_scale(WScale)

-    if EPILOGUE_SUBTILE:
-        accs = tl.split(tl.permute(tl.reshape(acc, (BLOCK_M, 2, EPILOGUE_BLOCK_N)), (0, 2, 1)))
-        biases = tl.split(tl.permute(tl.reshape(bias, (2, EPILOGUE_BLOCK_N)), (1, 0)))
-    else:
-        accs = (acc,)
-        biases = (bias,)
+    accs = (acc,)
+    biases = (bias,)
+
+    if SUBTILE_FACTOR >= 2:
+        acc0, acc1 = acc.reshape(BLOCK_M, 2, BLOCK_N // 2).permute(0, 2, 1).split()
+        accs = (acc0, acc1)
+        bias0, bias1 = bias.reshape(2, BLOCK_N // 2).permute(1, 0).split()
+        biases = (bias0, bias1)
+
+    if SUBTILE_FACTOR >= 4:
+        acc00, acc01 = acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+        acc10, acc11 = acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).permute(0, 2, 1).split()
+        accs = (acc00, acc01, acc10, acc11)
+        bias00, bias01 = bias0.reshape(2, BLOCK_N // 4).permute(1, 0).split()
+        bias10, bias11 = bias1.reshape(2, BLOCK_N // 4).permute(1, 0).split()
+        biases = (bias00, bias01, bias10, bias11)
+
+    tl.static_assert(EPILOGUE_BLOCK_N == BLOCK_N // SUBTILE_FACTOR)
+    tl.static_assert(len(accs) == SUBTILE_FACTOR)
+    tl.static_assert(len(biases) == SUBTILE_FACTOR)

     for a_i in tl.static_range(len(accs)):
         acc_tile = accs[a_i]
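
The reshape/permute/split chain is easiest to sanity-check outside the kernel. A minimal NumPy sketch (emulating tl.split, which detaches a trailing axis of size 2, via slicing; the tile sizes are illustrative) showing that the 4-way split yields four contiguous BLOCK_N // 4 column slices of the accumulator:

    import numpy as np

    BLOCK_M, BLOCK_N = 4, 8  # illustrative sizes
    acc = np.arange(BLOCK_M * BLOCK_N).reshape(BLOCK_M, BLOCK_N)

    def split(t):
        # emulate tl.split: detach the trailing axis, which must have size 2
        return t[..., 0], t[..., 1]

    # 2-way: each half covers a contiguous BLOCK_N // 2 column range
    acc0, acc1 = split(acc.reshape(BLOCK_M, 2, BLOCK_N // 2).transpose(0, 2, 1))
    assert (acc0 == acc[:, :BLOCK_N // 2]).all()
    assert (acc1 == acc[:, BLOCK_N // 2:]).all()

    # 4-way: recurse once per half; the quarters stay contiguous in N
    acc00, acc01 = split(acc0.reshape(BLOCK_M, 2, BLOCK_N // 4).transpose(0, 2, 1))
    acc10, acc11 = split(acc1.reshape(BLOCK_M, 2, BLOCK_N // 4).transpose(0, 2, 1))
    for i, sub in enumerate((acc00, acc01, acc10, acc11)):
        assert (sub == acc[:, i * BLOCK_N // 4:(i + 1) * BLOCK_N // 4]).all()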

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 20 additions & 14 deletions
@@ -20,7 +20,7 @@ class OptFlags:
     split_k: int
     fused_scatter: bool
     is_persistent: bool
-    epilogue_subtile: bool
+    epilogue_subtile: int | None
     arch: str
     target_kernel_kwargs: dict

@@ -43,7 +43,7 @@ def make_default_opt_flags_amd(
     can_use_persistent_tma,
     can_use_fused_scatter,
     enforce_bitwise_invariance,
-    has_expensive_epilogue,
+    epilogue_effective_itemsize,
     constraints,
 ):
     constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
@@ -106,7 +106,7 @@ def make_default_opt_flags_amd(
     if constraints.get("epilogue_subtile", None) is not None:
         epilogue_subtile = constraints["epilogue_subtile"]
     else:
-        epilogue_subtile = False
+        epilogue_subtile = None
     # AMD-specific
     target_kernel_kwargs = {"waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
     ret = OptFlags(
@@ -142,10 +142,10 @@ def make_default_opt_flags_nvidia(
     can_use_persistent_tma,
     can_use_fused_scatter,
     enforce_bitwise_invariance,
-    has_expensive_epilogue,
+    epilogue_effective_itemsize,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile"]
+    constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None:
@@ -175,7 +175,7 @@ def make_default_opt_flags_nvidia(
     if constraints.get("is_persistent", None) is not None:
         is_persistent = constraints["is_persistent"]
     else:
-        has_simple_epilogue = precision_config.max_num_imprecise_acc is None and not has_expensive_epilogue
+        has_simple_epilogue = precision_config.max_num_imprecise_acc is None
         is_persistent = supports_persistent and has_simple_epilogue and (tiles_per_sm >= 2.0 or lhs_dtype.itemsize <= 1) and out_dtype.itemsize < 4
     # block k
     if constraints.get("block_k", None) is not None:
@@ -204,14 +204,20 @@ def make_default_opt_flags_nvidia(
         lhs_dtype,
         rhs_dtype,
     )
+
     if constraints.get("epilogue_subtile", None) is not None:
-        epilogue_subtile = constraints["epilogue_subtile"]
+        subtiles_to_check = [constraints["epilogue_subtile"]]
     else:
-        n1 = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, False, has_expensive_epilogue)
-        n2 = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, True, has_expensive_epilogue)
-        epilogue_subtile = n2 > n1  # enable epilogue_subtile if it increases the number of stages
-    # num_stages
-    num_stages = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, epilogue_subtile, has_expensive_epilogue)
+        subtiles_to_check = [1, 2, 4]
+    num_stages = -1
+    for ep in subtiles_to_check:
+        ns = opt_flags_nvidia.compute_num_stages(*compute_num_stages_args, ep, epilogue_effective_itemsize)
+        if ns > num_stages:
+            epilogue_subtile, num_stages = ep, ns
+    assert num_stages >= 1
+    if constraints.get("num_stages", None):
+        num_stages = constraints["num_stages"]
+
     # fused scatter scratchpad
     if constraints.get("fused_scatter", None) is not None:
         fused_scatter = constraints["fused_scatter"]
@@ -273,7 +279,7 @@ def make_opt_flags(
     routing_data,
     can_use_persistent_tma,
     can_use_fused_scatter,
-    has_expensive_epilogue,
+    epilogue_effective_itemsize,
 ):
     microscaling_ctx = precision_config.mx_ctx
     enforce_bitwise_invariance = precision_config.enforce_bitwise_invariance
@@ -282,7 +288,7 @@ def make_opt_flags(
         return _opt_flags
     args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, microscaling_ctx, m, n, k,
             routing_data, can_use_persistent_tma, can_use_fused_scatter,
-            enforce_bitwise_invariance, has_expensive_epilogue, _opt_flags_constraints]
+            enforce_bitwise_invariance, epilogue_effective_itemsize, _opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
     if backend == "hip":
         return make_default_opt_flags_amd(*args)
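
In short, the flag is now chosen by brute force: try each candidate subtile factor and keep the one that lets the software pipeliner fit the most stages. A minimal sketch of that search (stage_count below is a stand-in for opt_flags_nvidia.compute_num_stages):

    def pick_epilogue_subtile(stage_count, constrained=None):
        # `stage_count(factor)` returns how many pipeline stages fit in SMEM
        # for a given epilogue subtile factor; a constraint pins the candidates.
        candidates = [constrained] if constrained is not None else [1, 2, 4]
        best_factor, best_stages = None, -1
        for factor in candidates:
            stages = stage_count(factor)
            if stages > best_stages:
                best_factor, best_stages = factor, stages
        assert best_stages >= 1
        return best_factor, best_stages

    # e.g. with a toy model where each doubling of the factor frees one stage:
    print(pick_epilogue_subtile(lambda f: {1: 3, 2: 4, 4: 5}[f]))  # -> (4, 5)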

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_nvidia.py

Lines changed: 5 additions & 6 deletions
@@ -69,7 +69,7 @@ def compute_num_stages(
     lhs_dtype,
     rhs_dtype,
     epilogue_subtile,
-    has_expensive_epilogue,
+    epilogue_effective_itemsize,
 ):
     if precision_config.max_num_imprecise_acc is not None:
         return 3
@@ -88,19 +88,18 @@ def compute_num_stages(
     if is_persistent:
         # Per-stage wait barrier
         stage_size += 8
-        acc_size = out_dtype.itemsize
         if target_info.cuda_capability_geq(10, 0):
-            acc_size = 4 if has_expensive_epilogue else out_dtype.itemsize
+            acc_size = epilogue_effective_itemsize or out_dtype.itemsize
         else:
             acc_size = out_dtype.itemsize
-        if target_info.cuda_capability_geq(10, 0) and epilogue_subtile and not has_expensive_epilogue:
-            acc_block_n = block_n // 2
+        if target_info.cuda_capability_geq(10, 0) and epilogue_subtile is not None:
+            acc_block_n = block_n // epilogue_subtile
         else:
             acc_block_n = block_n
         # pipelined TMA store local to global, or
         # pipelined layout conversion before store of the accumulator
         # note: layout conversion has some padding
-        smem_capacity -= (block_m + 4) * acc_block_n * acc_size
+        smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
     if microscaling_ctx.weight_scale is not None:
         # mx scales
         stage_size += block_n * (block_k // 32)
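
The effect of the subtile factor on stage count is pure arithmetic: a larger factor shrinks the accumulator staging buffer, leaving more SMEM for pipeline stages. An illustrative-only calculation (the capacity, tile sizes, and per-stage cost model below are assumed round numbers, not the kernel's real accounting):

    def max_stages(smem_capacity, block_m, block_n, block_k,
                   acc_size, lhs_itemsize, rhs_itemsize, subtile_factor):
        # accumulator staging buffer, as above: (block_m + 4) rows of the
        # (possibly subtiled) N extent, acc_size bytes per element
        smem_capacity -= int((block_m + 4) * (block_n // subtile_factor) * acc_size)
        # simplified per-stage cost: one A tile plus one B tile
        stage_size = block_m * block_k * lhs_itemsize + block_k * block_n * rhs_itemsize
        return smem_capacity // stage_size

    # 228 KiB SMEM, 128x128 tile, block_k=64, fp32 accumulator, fp8 operands:
    print(max_stages(228 * 1024, 128, 128, 64, 4, 1, 1, subtile_factor=1))  # 10
    print(max_stages(228 * 1024, 128, 128, 64, 4, 1, 1, subtile_factor=4))  # 13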
