
Commit eb7cdba

[KERNELS] Fix and enable batched matmul with split-k. (#8327)
1 parent bc22e6e commit eb7cdba

4 files changed: +25 −16 lines

python/triton_kernels/tests/test_matmul.py

Lines changed: 6 additions & 5 deletions
@@ -114,7 +114,7 @@ def _apply_padding_and_fill_unused_part_with_nan(t, is_padded):
 # ---------------


-def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_expts_tot=1, expt_is_inner=False, device="cuda"):
+def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, mode, n_expts_tot=1, expt_is_inner=False, device="cuda"):
     weight_use_flexpoint = weight_dtype.itemsize == 1 and not weight_mxfp
     # flexpoint
     make_tensor = lambda val0, val1: torch.tensor([val0, val1] * (n_expts_tot // 2) +
@@ -133,8 +133,8 @@ def init_precision(out_dtype, act_use_flexpoint, weight_dtype, weight_mxfp, n_ex
         ) if weight_use_flexpoint else InFlexData(),
         out_data=OutFlexData(
             dtype=out_dtype,
-            expected_scale=make(4.00, 5.00, expt_is_inner),
-            actual_scale=make(0, 0, expt_is_inner),
+            expected_scale=make(4.00, 5.00, mode == "batched" or expt_is_inner),
+            actual_scale=make(0, 0, mode == "batched" or expt_is_inner),
             checksum_scale=None,
         ) if act_use_flexpoint else OutFlexData(),
     )
@@ -233,6 +233,7 @@ class Case:
     Case(1000, 700, 700, "ragged", "float16", "float16", 8, 2, split_k=9),
     Case(16, 16, 1000, "batched", "float16", "float16", 5, 1, split_k=None),
     Case(16, 16, 1000, "batched", "float8_e5m2", "float8_e5m2", 5, 1, split_k=None),
+    Case(16, 16, 2048, "batched", "float8_e5m2", "float8_e5m2", 6, 1, split_k=5),
     # mx types:
     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1),
     Case(16, 256, 256, "plain", "bfloat16", "mxfloat4_e2m1", 1, 1, hbm_swizzling=True),
@@ -412,7 +413,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     weight_dtype = dtype_str_to_torch(weight_dtype_str)
     act_dtype = dtype_str_to_torch(act_dtype_str)
     precision_opt = init_precision(act_dtype, act_is_float8, weight_dtype, weight_mxfp,
-                                   n_expts_tot, expt_is_inner, device=device)
+                                   mode, n_expts_tot, expt_is_inner, device=device)
     # precision_opt.x_pad_trans_requires_flexpoint = False
     if mode == "ragged":
         m, rdata, gindx, sindx = init_routing_data(m, n_expts_tot, n_expts_act, do_gather, do_scatter,
@@ -667,7 +668,7 @@ def test_fused_act(m, n, k, mode, split_k, do_gather, do_scatter, fused_scatter,
     else:
         rdata = gindx = sindx = None

-    precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, n_expts_tot, device=device)
+    precision_opt = init_precision(act_dtype, str(act_dtype).startswith("torch.float8"), weight_dtype, False, mode, n_expts_tot, device=device)
     x, w, bias, _, _ = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mode,
                                          act_dtype, weight_dtype, False, requires_grad=False, device=device)
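For context on what the new split_k=5 batched case exercises: split-k partitions the K dimension across programs, accumulates partial products in float32, and then reduces over the split axis. A minimal PyTorch sketch of the equivalent reference computation (illustrative only, not the test's actual reference path; shapes are assumptions):

import torch

def batched_split_k_reference(x, w, split_k):
    # x: (B, M, K), w: (B, K, N) -- assumed layouts for illustration
    B, M, K = x.shape
    N = w.shape[-1]
    chunk = (K + split_k - 1) // split_k
    partial = torch.zeros(split_k, B, M, N, dtype=torch.float32, device=x.device)
    for s in range(split_k):
        lo, hi = s * chunk, min(K, (s + 1) * chunk)
        if lo < hi:
            # each split handles one K-slice; accumulation happens in float32
            partial[s] = x[:, :, lo:hi].float() @ w[:, lo:hi, :].float()
    # the final reduction over the split axis is what reduce_grouped performs
    return partial.sum(dim=0)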

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 10 additions & 4 deletions
@@ -244,8 +244,9 @@ def init_allocation(x, w, precision_config, fused_activation,
     scratchpad = dict()
     if opt_flags.split_k > 1 or (scatter_indx is not None and not opt_flags.fused_scatter):
         scratch_out_dtype = torch.float32 if opt_flags.split_k > 1 else out_dtype
-        scratchpad["matmul"] = ((opt_flags.split_k, 1, M, N), scratch_out_dtype)
+        scratchpad["matmul"] = ((opt_flags.split_k, batch_dim, M, N), scratch_out_dtype)
     if "matmul" in scratchpad and precision_config.out_scale is not None:
+        assert batch_dim == 1, "batch_dim > 1 not supported yet"
         scratchpad["mx_out_scale"] = ((opt_flags.split_k, 1, M, triton.cdiv(N, MXFP_BLOCK_SIZE)), torch.uint8)
     return MatmulAllocation(x.device, output, scratchpad)
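The allocation change above gives the split-k scratch a real batch dimension. A small sketch of the resulting layout (sizes are illustrative, not taken from the library):

import torch

split_k, batch_dim, M, N = 5, 6, 16, 16
# float32 scratch holding one partial result per K-split and per batch
scratch = torch.empty(split_k, batch_dim, M, N, dtype=torch.float32)
# reducing over the split axis yields the batched output shape (batch_dim, M, N)
out = scratch.sum(dim=0)
assert out.shape == (batch_dim, M, N)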

@@ -323,11 +324,14 @@ def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_m
     Returns
     - The input tensor `x` (modified in place).
     """
+    M = x.shape[2]  # Only used for per-batch flex scale.
     if indx is None and x.shape[0] == 1:
         return x.squeeze(0), None
     if indx is not None:
         num_groups = indx.shape[0]
     else:
+        # Handle batched matmul (K, B, M, N) by pretending it to be (K, 1, B*M, N).
+        x = x.view(x.shape[0], 1, x.shape[1] * x.shape[2], x.shape[3])
        num_groups = x.shape[-2]
     if x_flex is None:
         x_flex = InFlexData()
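The view trick added above folds the batch dimension into the row dimension so the existing grouped-reduction path is reused unchanged; summing over the split axis is unaffected by the reshape. A standalone check of that equivalence (shapes are illustrative):

import torch

K_split, B, M, N = 5, 6, 16, 16
x = torch.randn(K_split, B, M, N)
# pretend (K, B, M, N) is (K, 1, B*M, N), reduce, then restore the batch shape
flat = x.view(K_split, 1, B * M, N)
reduced = flat.sum(dim=0).view(B, M, N)
assert torch.allclose(reduced, x.sum(dim=0))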
@@ -351,8 +355,10 @@ def reduce_grouped(x: torch.Tensor, indx: torch.Tensor, out: torch.Tensor, out_m
         x_flex.reinterpret(x), x.stride(0), x.stride(2), x.stride(3), #
         x_expected_scale, # scalar input scale
         out_flex.reinterpret(out), out.stride(1), out.stride(2), #
-        out_expected_scale, out_actual_scale, out_checksum_scale, indx, #
-        x.shape[0], x.shape[-1], #
+        out_expected_scale, out_actual_scale, out_checksum_scale,
+        out_flex is not None and out_flex.is_per_batch,
+        indx,
+        x.shape[0], M, x.shape[-1], #
         x_mx_scale, stride_mxb, stride_mxs, #
         out_mx_scale, stride_omxs, #
         *fused_activation.fn_args, fused_activation.reduction_n,
@@ -629,7 +635,7 @@ def matmul_ogs(x, w, bias,
         precision_config.allow_tf32,
         precision_config.flexpoint_saturate_inf,
         flex.rhs_data.is_per_batch,
-        flex.out_data.is_per_batch,
+        out_matmul_flex.is_per_batch,
         flex.acc_data.is_per_batch,
         opt_flags.block_m,
         opt_flags.block_n,

python/triton_kernels/triton_kernels/matmul_ogs_details/_reduce_grouped.py

Lines changed: 7 additions & 1 deletion
@@ -9,7 +9,7 @@ def _reduce_grouped(X, stride_xb: tl.uint64, stride_xm: tl.uint64, stride_xn, #
        XScale, # input scalar flex scale
        Out, stride_om: tl.uint64, stride_on, # output tensor
        OutExpectedScale, OutActualScale, OutChecksumScale, # output scalar flex scales
-       InIndx, B, N, #
+       PER_BATCH_OUT_SCALE: tl.constexpr, InIndx, B, M, N, #
        XMxScale, stride_mxb: tl.uint64,
        stride_mxs: tl.uint64, # optional per-32-col output MXFP scales (uint8)
        OutMxScale, stride_omxs: tl.uint64, # optional per-32-col output MXFP scales (uint8)
@@ -42,6 +42,12 @@ def _reduce_grouped(X, stride_xb: tl.uint64, stride_xm: tl.uint64, stride_xn, #
         XScalePtrs = XMxScale + tl.arange(0, BLOCK_N // 32) * stride_xn
     if HAS_OUT_MX_SCALE:
         OutScalePtrs = OutMxScale + tl.arange(0, BLOCK_N_OUT // 32) * stride_on
+    if PER_BATCH_OUT_SCALE:
+        out_batch_idx = pid_t // M
+        OutExpectedScale += out_batch_idx
+        OutActualScale += out_batch_idx
+        if OutChecksumScale is not None:
+            OutChecksumScale += out_batch_idx
     x_scale = load_scale(XScale)
     for n_curr in tl.range(0, N, BLOCK_N, num_stages=4):
         acc = tl.zeros([BLOCK_N_OUT], dtype=tl.float32)
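A plain-Python illustration of the indexing added above: once the batched tensor is viewed as B*M rows, a program's row id recovers its batch as pid_t // M, and that index offsets the per-batch output flex-scale pointers. The list below stands in for the scale buffers; names and values are illustrative:

M, B = 16, 6
out_expected_scale = [1.0 + b for b in range(B)]  # stand-in for a per-batch scale buffer
for pid_t in range(B * M):
    out_batch_idx = pid_t // M              # same arithmetic as the kernel
    scale = out_expected_scale[out_batch_idx]
    assert scale == 1.0 + out_batch_idx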

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 2 additions & 6 deletions
@@ -90,9 +90,7 @@ def make_default_opt_flags_amd(
     )
     is_persistent = constraints.get("is_persistent", False)
     # split_k:
-    if batch_size > 1:
-        split_k = 1 # currently not supported
-    elif constraints.get("split_k", None) is not None:
+    if constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
     elif is_persistent or enforce_bitwise_invariance:
         split_k = 1
@@ -222,9 +220,7 @@ def make_default_opt_flags_nvidia(
         # TODO: swizzle the HBM layout of the weights instead
         block_n, block_k = block_k, block_n
     # split_k
-    if batch_size > 1:
-        split_k = 1 # currently not supported
-    elif constraints.get("split_k", None) is not None:
+    if constraints.get("split_k", None) is not None:
         split_k = constraints["split_k"]
     elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
         split_k = 1
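With the batch_size > 1 early-out removed in both backends, split_k resolution follows the chain shown above. A hedged sketch of that order (the helper is illustrative, not the library's API, and it simplifies the NVIDIA branch's extra scale-related conditions into the persistent/bitwise case):

def resolve_split_k(constraints, is_persistent, enforce_bitwise_invariance, heuristic_split_k):
    # an explicit constraint now applies to batched matmuls as well
    if constraints.get("split_k", None) is not None:
        return constraints["split_k"]
    # persistent kernels and bitwise-invariant runs keep a single K partition
    if is_persistent or enforce_bitwise_invariance:
        return 1
    # otherwise fall back to whatever heuristic the backend computes
    return heuristic_split_k

# e.g. resolve_split_k({"split_k": 5}, False, False, 4) returns 5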
