
Commit 318fa9c

[triton_kernels] forbid use of split_k > 1 with fused scatter (#8618)
The API doesn't accept a scale for the intermediate tensor produced between split_k and the fused scatter, so this mode should be disabled for now. It will be re-enabled after expert aggregation is moved out of the matmul_ogs API.
1 parent 14fd9cb commit 318fa9c
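
For context, the rule this commit enforces can be sketched as follows. This is a hedged illustration, not the library API: pick_split_k, its arguments, and the heuristic_default fallback are hypothetical stand-ins that mirror the can_use_split_k gate and the InapplicableConstraint check added in opt_flags.py below.

class InapplicableConstraint(Exception):
    pass

def pick_split_k(constraints, scatter_indx=None, x_has_mx=False, w_has_mx=False, heuristic_default=1):
    # split_k > 1 produces an intermediate tensor that carries no scale, so it
    # is only usable when there is no fused scatter and no mx-scaled operand.
    can_use_split_k = scatter_indx is None and not x_has_mx and not w_has_mx
    requested = constraints.get("split_k")
    if requested is not None and requested > 1 and not can_use_split_k:
        raise InapplicableConstraint("cannot enforce `split_k=True` constraint")
    if requested is not None:
        return requested
    # the real code derives its default from a launch-grid heuristic; any such
    # heuristic only applies when split-k is usable at all
    return heuristic_default if can_use_split_k else 1

For example, pick_split_k({"split_k": 2}, scatter_indx=object()) now raises InapplicableConstraint instead of silently producing an unscaled split-k intermediate.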

File tree: 4 files changed, +40 -61 lines

python/triton_kernels/tests/test_matmul.py

Lines changed: 18 additions & 26 deletions
@@ -303,27 +303,27 @@ class Case:
     ],
 )
 @pytest.mark.parametrize("block_m", [16, 128])
-@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter, inner_expt_opt", [
-    (False, False, False, None),
-    (True, False, False, None),
-    (False, True, False, None),
-    (False, True, True, None),
-    (True, True, False, None),
-    (True, True, True, None),
-    (False, False, False, "pad_w"),
-    (False, False, False, "pad_x"),
+@pytest.mark.parametrize("do_gather, do_scatter, inner_expt_opt", [
+    (False, False, None),
+    (True, False, None),
+    (False, True, None),
+    (False, True, None),
+    (True, True, None),
+    (True, True, None),
+    (False, False, "pad_w"),
+    (False, False, "pad_x"),
 ])
 @pytest.mark.parametrize("has_y_gammas", [False, True])
 @pytest.mark.parametrize("is_persistent", [False, True])
-def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+def test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
             n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
     # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to
     # the frame that called pytest.skip, including all the tensors, leading to OOM.
     skip_message = None
     try:
-        _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+        _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
                  n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
                  x_transpose, w_transpose, y_transpose,
                  device, opt_flags_scope)
@@ -333,7 +333,7 @@ def test_op(
     if skip_message is not None:
         pytest.skip(skip_message)
 
-def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+def _test_op(m, n, k, split_k, do_gather, do_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
              n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
              x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
@@ -362,9 +362,6 @@ def _test_op(
     if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
         pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")
 
-    if fused_scatter and split_k is not None and split_k > 1:
-        pytest.skip("fused scatter scratchpad not supported with split_k")
-
     if hbm_swizzling:
         if is_hip():
             if not is_hip_cdna4():
@@ -413,7 +410,6 @@ def _test_op(
         "block_m": block_m,
         "block_k": block_k,
         "split_k": split_k,
-        "fused_scatter": fused_scatter,
         "is_persistent": is_persistent,
         "epilogue_subtile": epilogue_subtile,
     }
@@ -726,12 +722,11 @@ def test_set_idle_sms():
     (800, 800, 400, "batched"),
 ])
 @pytest.mark.parametrize("split_k", [1, 2])
-@pytest.mark.parametrize("do_gather, do_scatter, fused_scatter", [
-    (False, False, False),
-    (True, False, False),
-    (False, True, False),
-    (True, True, False),
-    (True, True, True),
+@pytest.mark.parametrize("do_gather, do_scatter", [
+    (False, False),
+    (True, False),
+    (False, True),
+    (True, True),
 ])
 @pytest.mark.parametrize("is_persistent, epilogue_subtile", [
     (False, None),
@@ -743,16 +738,13 @@ def test_set_idle_sms():
     (1.0, 1.2),
     (0.7, 1.0),
 ])
-def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter, is_persistent, epilogue_subtile,
+def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, is_persistent, epilogue_subtile,
                    swiglu_alpha, swiglu_limit, device, opt_flags_scope):
-    if fused_scatter and split_k > 1:
-        pytest.skip("fused scatter scratchpad not supported with split_k")
     torch.manual_seed(0)
     constraints = {
         "is_persistent": is_persistent,
         "epilogue_subtile": epilogue_subtile,
         "split_k": split_k,
-        "fused_scatter": fused_scatter,
     }
     n_expts_tot, n_expts_act = 1, 1
     opt_flags.update_opt_flags_constraints(constraints)

python/triton_kernels/tests/test_matmul_details/test_opt_flags_split_k.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def get_flags(split_k, max_mn):
         k,
         None,
         False,
-        False,
+        True,
         False,
         0,
         False,

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 8 additions & 7 deletions
@@ -419,9 +419,10 @@ def matmul_ogs(x, w, bias,
     has_gather_tma = has_gather and target_info.has_tma_gather()
     # hopper w/ mxfp4 doesn't support TMA
     can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+    can_use_split_k = scatter_indx is None and not x_has_mx and not w_has_mx
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                                batch_size, M, N, w.shape[-2], routing_data,
-                               can_use_tma, scatter_indx is not None, epilogue.effective_itemsize,
+                               can_use_tma, can_use_split_k, epilogue.effective_itemsize,
                                x_transpose, y_acc_in is not None,
                                inner_routing_data.block_k if inner_routing_data is not None else None,
                                )
@@ -618,21 +619,21 @@ def matmul_ogs(x, w, bias,
         **fused_comm_kwargs,
         **opt_flags.target_kernel_kwargs)
 
+    assert not (opt_flags.split_k > 1 and scatter_indx is not None)
     out_final_mx_scale = None
     if opt_flags.split_k > 1:
         assert not out_matmul_has_mx
-        has_scatter = scatter_indx is not None
         postprocess_fn1 = ReducePostprocessFn(specs=reduce_fused_activation.specs, fn_args=reduce_fused_activation.fn_args)
         postprocess_fn2 = None if has_scatter else ReducePostprocessFn(specs=epilogue.specs, fn_args=epilogue.fn_arg_values_finalize)
         y, y_mx_scale = reduce(
             x = out_matmul.view(out_matmul.shape[0], -1, out_matmul.shape[-1]),
             dim = 0,
             # output data/metadata
-            y = None if has_scatter else memory["output"].view(-1, memory["output"].shape[-1]),
-            y_dtype = out_matmul.dtype if has_scatter else memory["output"].dtype,
-            y_flex = OutFlexData() if has_scatter else precision_config.flex_ctx.out_data,
-            y_flex_saturate_inf = None if has_scatter else precision_config.flexpoint_saturate_inf,
-            y_has_mx = scatter_indx is None and precision_config.out_scale is not None,
+            y = memory["output"].view(-1, memory["output"].shape[-1]),
+            y_dtype = memory["output"].dtype,
+            y_flex = precision_config.flex_ctx.out_data,
+            y_flex_saturate_inf = precision_config.flexpoint_saturate_inf,
+            y_has_mx = precision_config.out_scale is not None,
             # fused functions
             postprocess_fn1 = postprocess_fn1,
             postprocess_fn2 = postprocess_fn2,

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 13 additions & 27 deletions
@@ -22,7 +22,6 @@ class OptFlags:
     w_cache_modifier: str
     split_k: int
     is_persistent: bool
-    fused_scatter: bool
     idle_sms: int
     epilogue_subtile: int | None
     arch: str
@@ -56,14 +55,14 @@ def make_default_opt_flags_amd(
     k,
     routing_data,
     can_use_persistent_tma,
-    can_use_fused_scatter,
+    can_use_split_k,
     enforce_bitwise_invariance,
     epilogue_effective_itemsize,
     x_transpose,
     has_y_acc_in,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "max_allowable_mn"]
+    constraints_supported = ["block_m", "block_n", "block_k", "split_k", "is_persistent", "epilogue_subtile", "max_allowable_mn"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None:
@@ -102,13 +101,12 @@
     )
     is_persistent = constraints.get("is_persistent", False)
     # split_k:
+    split_k = 1
     if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
         split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
     elif constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
-    elif is_persistent or enforce_bitwise_invariance:
-        split_k = 1
-    else:
+    elif can_use_split_k and not enforce_bitwise_invariance:
         grid_size = grid_m * ((n + block_n - 1) // block_n)
         n_cu = torch.cuda.get_device_properties(0).multi_processor_count
         split_k = max(1, n_cu // grid_size)
@@ -156,7 +154,6 @@ def replace_with_valid_constraint(k: str, v):
         w_cache_modifier=w_cache_modifier,
         split_k=split_k,
         is_persistent=is_persistent,
-        fused_scatter=constraints.get('fused_scatter', False),
         idle_sms=0,
         epilogue_subtile=epilogue_subtile,
         arch=None,
@@ -177,14 +174,14 @@ def make_default_opt_flags_nvidia(
     k,
     routing_data,
     can_use_persistent_tma,
-    can_use_fused_scatter,
+    can_use_split_k,
     enforce_bitwise_invariance,
     epilogue_effective_itemsize,
     x_transpose,
     has_y_acc_in,
     constraints,
 ):
-    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"]
+    constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms", "max_allowable_mn"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
     if routing_data is None or batch_size > 1:
@@ -236,26 +233,21 @@
     if constraints.get("block_k", None) is not None:
         block_k = constraints["block_k"]
     # split_k
+    split_k = 1
     if constraints.get("max_allowable_mn", 0) > 0 and constraints.get("split_k") is not None:
         split_k = max_allowable_mn(constraints["max_allowable_mn"], m, n, constraints.get("split_k"))
     elif constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
-    elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
-        split_k = 1
-    else:
+    elif can_use_split_k and not enforce_bitwise_invariance:
         estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, batch_size, m, n, block_m, block_n)
         split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
-    if split_k > 1:
-        # With split_k, results are written in f32. Use that for the following computations.
-        out_dtype = torch.float32
     compute_num_stages_args = (
         precision_config,
         is_persistent,
-
         block_m,
         block_n,
         block_k,
-        out_dtype,
+        torch.float32 if split_k > 1 else out_dtype,
         lhs_dtype,
         rhs_dtype,
         x_transpose,
@@ -276,11 +268,6 @@
     if constraints.get("num_stages", None):
         num_stages = constraints["num_stages"]
     assert num_stages >= 1
-    # fused scatter scratchpad
-    if constraints.get("fused_scatter", None) is not None:
-        fused_scatter = constraints["fused_scatter"]
-    else:
-        fused_scatter = can_use_fused_scatter and split_k == 1
     # Handshake with the HBM swizzling
     num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, is_persistent, precision_config)
     ret = OptFlags(
@@ -289,7 +276,6 @@
         block_k=block_k,
         num_warps=num_warps,
         num_stages=num_stages,
-        fused_scatter=fused_scatter,
         group_m=group_m,
         xcd_swizzle=xcd_swizzle,
         w_cache_modifier=None,
@@ -343,16 +329,16 @@
     k,
     routing_data,
     can_use_persistent_tma,
-    can_use_fused_scatter,
+    can_use_split_k,
     epilogue_effective_itemsize,
     x_transpose,
     has_y_acc_in,
     block_k,
 ):
     if _opt_flags_constraints.get("is_persistent", False) and not can_use_persistent_tma:
         raise InapplicableConstraint("cannot enforce `is_persistent=True` constraint")
-    if _opt_flags_constraints.get("fused_scatter", False) and not can_use_fused_scatter:
-        raise InapplicableConstraint("cannot enforce `fused_scatter=True` constraint")
+    if _opt_flags_constraints.get("split_k") is not None and _opt_flags_constraints.get("split_k") > 1 and not can_use_split_k:
+        raise InapplicableConstraint("cannot enforce `split_k=True` constraint")
     if _opt_flags_constraints.get("max_allowable_mn"):
         if not _opt_flags_constraints.get("split_k"):
             raise InapplicableConstraint("split_k also needs to be provided with max_allowable_mn")
@@ -366,7 +352,7 @@
         opt_flags_constraints = opt_flags_constraints.copy()
         opt_flags_constraints.update(block_k=block_k, split_k=1)
     args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k,
-            routing_data, can_use_persistent_tma, can_use_fused_scatter,
+            routing_data, can_use_persistent_tma, can_use_split_k,
            enforce_bitwise_invariance, epilogue_effective_itemsize, x_transpose, has_y_acc_in,
            opt_flags_constraints]
     backend = triton.runtime.driver.active.get_current_target().backend
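
Taken together with the test changes above, callers should stop passing a fused_scatter constraint. A hedged usage sketch mirroring the updated test_fused_act (the import path is assumed from the file layout shown above; tensors, routing, and the matmul call itself are omitted):

from triton_kernels.matmul_ogs_details import opt_flags

# "fused_scatter" is no longer a supported constraint key. A split_k > 1
# constraint is only honored when the matmul has no fused scatter and no
# mx-scaled operands; otherwise make_opt_flags raises InapplicableConstraint.
constraints = {
    "is_persistent": False,
    "epilogue_subtile": None,
    "split_k": 2,
}
opt_flags.update_opt_flags_constraints(constraints)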
