Commit f1872ed

[KERNELS] Improve block sizes for batched matmul_ogs with small m/n/k. (triton-lang#7897)
(Previously, block sizes could be much bigger than m/n/k.) Example perf difference:

```
H100:
B=500000 M=8 N=8 K=8
  >> torch.float16      0.850 ms -> 0.388 ms
  >> torch.bfloat16     0.828 ms -> 0.354 ms
  >> torch.float8_e5m2  0.829 ms -> 0.373 ms
B=500000 M=16 N=16 K=16
  >> torch.float16      0.791 ms -> 0.381 ms
  >> torch.bfloat16     0.790 ms -> 0.382 ms
  >> torch.float8_e5m2  0.779 ms -> 0.366 ms

GB200:
B=500000 M=8 N=8 K=8
  >> torch.float16      0.676 ms -> 0.314 ms
  >> torch.bfloat16     0.652 ms -> 0.297 ms
  >> torch.float8_e5m2  0.659 ms -> 0.294 ms
B=500000 M=16 N=16 K=16
  >> torch.float16      0.622 ms -> 0.305 ms
  >> torch.bfloat16     0.606 ms -> 0.306 ms
  >> torch.float8_e5m2  0.616 ms -> 0.296 ms
```
1 parent 139ce65 commit f1872ed
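
For context, a rough sketch of the kind of batched call these figures describe. The timing harness below is illustrative and assumed (it is not the benchmark used for the numbers above); it presumes the `triton_kernels` package is installed and a CUDA device is available.

```python
import torch
from triton_kernels.matmul_ogs import matmul_ogs

# Many tiny matmuls, as in the commit message: B=500000, M=N=K=8.
B, M, N, K = 500_000, 8, 8, 8
x = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
w = torch.randn(B, K, N, device="cuda", dtype=torch.float16)

matmul_ogs(x, w, None)  # warm-up call (bias=None)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
y = matmul_ogs(x, w, None)  # y has shape (B, M, N)
end.record()
torch.cuda.synchronize()
print(f"{start.elapsed_time(end):.3f} ms")
```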

File tree

7 files changed: +97 −28 lines changed


Makefile

Lines changed: 1 addition & 1 deletion

```diff
@@ -35,7 +35,7 @@ test-unit: all
     --ignore=language/test_subprocess.py --ignore=test_debug.py
     $(PYTEST) -s -n $(NUM_PROCS) python/test/unit/language/test_subprocess.py
     $(PYTEST) -s -n $(NUM_PROCS) python/test/unit/test_debug.py --forked
-    $(PYTEST) -s -n 8 python/triton_kernels/tests/
+    $(PYTEST) -s -n 6 python/triton_kernels/tests/
     TRITON_DISABLE_LINE_INFO=0 $(PYTEST) -s python/test/unit/language/test_line_info.py
 # Run attention separately to avoid out of gpu memory
     $(PYTEST) -vs python/tutorials/06-fused-attention.py
```

python/triton_kernels/tests/test_matmul.py

Lines changed: 54 additions & 1 deletion

```diff
@@ -1,6 +1,7 @@
 # isort: off
 # fmt: off
 from dataclasses import dataclass, fields, replace
+import itertools
 import pytest
 import torch
 from typing import Union
@@ -517,14 +518,66 @@ def round_x(x, idx):
         tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"
 
 
+# Test that we don't use unsupported block sizes.
+@pytest.mark.parametrize("m", [8, 16, 32, 64, 128])
+@pytest.mark.parametrize("n", [8, 16, 32, 64, 128])
+@pytest.mark.parametrize("k", [8, 16, 32, 64, 128])
+def test_small_batch_matmul(m, n, k):
+    if is_hip():
+        pytest.skip("Not fully tested on AMD")
+
+    if m * n * k > 16384:
+        pytest.skip()
+
+    BATCH_SIZE = 10000
+
+    def _make_tensor(shape, dtype, trans):
+        if trans:
+            shape = (shape[0], shape[2], shape[1])
+        t = alloc_rand(shape, "cuda", dtype)
+        return t.transpose(1, 2) if trans else t
+
+    for x_transpose, w_transpose, bias, dtype in itertools.product(
+        (False, True),
+        (False, True),
+        (False, True),
+        (torch.float16, torch.bfloat16, torch.float8_e5m2),
+    ):
+        if (
+            torch.cuda.get_device_capability()[0] < 10
+            and dtype is torch.float8_e5m2
+            and (not w_transpose)
+        ):
+            continue  # Not supported
+
+        x = _make_tensor((BATCH_SIZE, m, k), dtype, x_transpose)
+        w = _make_tensor((BATCH_SIZE, k, n), dtype, w_transpose)
+        bias = _make_tensor((BATCH_SIZE, n), torch.float32, False) if bias else None
+        tri_y = matmul_ogs(x, w, bias)
+
+        # ref_y = matmul_ogs_torch(x.float(), w.float(), bias)
+
+        # This is faster than matmul_ogs_torch.
+        ref_y = torch.bmm(x.float(), w.float())
+        if bias is not None:
+            ref_y += bias[:, None, :]
+
+        assert_close(
+            ref_y,
+            tri_y,
+            maxtol=4e-1 if dtype is torch.float8_e5m2 else None,
+            rmstol=4e-2 if dtype is torch.float8_e5m2 else None,
+        )
+
+
 def test_set_idle_sms():
     if not is_cuda():
         pytest.skip("Only supported on CUDA")
     from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
     num_idle_sms = 24
     matmul_ogs_set_idle_sms(num_idle_sms)
     flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
-                           1024, 1024, 1024, None, True, False, 1)
+                           1, 1024, 1024, 1024, None, True, False, 1)
     assert flags.idle_sms == num_idle_sms
 
 
```

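A side note on the `_make_tensor` helper in the new test: allocating the swapped shape and then calling `.transpose(1, 2)` yields a view with the requested logical shape but transposed strides, which is what exercises the transposed-input code paths. A minimal standalone illustration (shapes are arbitrary):

```python
import torch

B, M, K = 4, 8, 16
t = torch.empty((B, K, M)).transpose(1, 2)
print(t.shape)            # torch.Size([4, 8, 16]) -- logical (B, M, K)
print(t.stride())         # (128, 1, 8) -- memory is laid out as (B, K, M)
print(t.is_contiguous())  # False
```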
python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 3 additions & 3 deletions

```diff
@@ -444,7 +444,7 @@ def matmul_ogs(x, w, bias,
     can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = scatter_indx is not None and fused_activation.specs.fn is None
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
-                               M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
+                               batch_size, M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
     )
     if w_scale is not None and opt_flags.is_persistent and not target_info.has_native_mxfp():
         raise NotImplementedError("Must use non-persistent kernel for simulated MXFP")
@@ -631,10 +631,10 @@ def matmul_ogs_torch(x, w, bias,
         assert routing_data is None, "routing not supported in batched mode"
         assert w.ndim == 3 and w.shape[0] == x.shape[0]
     if round_x is None:
-        round_x = lambda x: x
+        round_x = lambda x, idx: x
     if round_y is None:
         round_y = lambda x: x
-    if bias.ndim == 1:
+    if bias is not None and bias.ndim == 1:
         bias = bias.view(1, *bias.shape)
     if w.ndim == 2:
         w = w.view(1, *w.shape)
```

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -96,7 +96,8 @@ def _matmul_ogs(
         tl.assume(stride_w_mx_k >= 0)
     if stride_w_mx_n is not None:
         tl.assume(stride_w_mx_n >= 0)
-    tl.assume(stride_b_e >= 0)
+    if B is not None:
+        tl.assume(stride_b_e >= 0)
     tl.assume(batch_size >= 0)
     tl.assume(grid_m >= 0)
     tl.assume(grid_n >= 0)
```

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 14 additions & 7 deletions

```diff
@@ -36,6 +36,7 @@ def make_default_opt_flags_amd(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -134,6 +135,7 @@ def make_default_opt_flags_nvidia(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -147,7 +149,7 @@ def make_default_opt_flags_nvidia(
     constraints_supported = ["block_m", "block_k", "split_k", "fused_scatter", "is_persistent", "epilogue_subtile", "num_stages", "idle_sms"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
-    if routing_data is None:
+    if routing_data is None or batch_size > 1:
         tokens_per_expt = m
     elif routing_data.expected_tokens_per_expt is None:
         tokens_per_expt = max(1, m // routing_data.n_expts_tot)
@@ -165,11 +167,11 @@ def make_default_opt_flags_nvidia(
         block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
     # block n
     arch = None
-    block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
+    block_n, block_n_tma = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
     # is_persistent
-    grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
+    grid_size_tma = opt_flags_nvidia.compute_grid_size(routing_data, batch_size, m, n, block_m, block_n_tma)
     n_sms = torch.cuda.get_device_properties(0).multi_processor_count
-    tiles_per_sm = grid_size / n_sms
+    tiles_per_sm = grid_size_tma / n_sms
     supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
     if constraints.get("is_persistent", None) is not None:
         is_persistent = constraints["is_persistent"]
@@ -179,6 +181,10 @@ def make_default_opt_flags_nvidia(
     # TEMP CHANGE
     if precision_config.act_scale is not None or precision_config.out_scale is not None:
         is_persistent = False
+    # TMA is slower for batched matmuls with small m/n/k.
+    if m * n * k < 131072:
+        is_persistent = False
+    block_n = block_n_tma if is_persistent else block_n
     # block k
     if constraints.get("block_k", None) is not None:
         block_k = constraints["block_k"]
@@ -190,7 +196,7 @@ def make_default_opt_flags_nvidia(
     elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
         split_k = 1
     else:
-        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
+        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, batch_size, m, n, block_m, block_n)
         split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
     if split_k > 1:
         # With split_k, results are written in f32. Use that for the following computations.
@@ -225,7 +231,7 @@ def make_default_opt_flags_nvidia(
     else:
         fused_scatter = can_use_fused_scatter and split_k == 1
     # Handshake with the HBM swizzling
-    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
+    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, is_persistent, precision_config)
     ret = OptFlags(
         block_m=block_m,
         block_n=block_n,
@@ -276,6 +282,7 @@ def make_opt_flags(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -290,7 +297,7 @@ def make_opt_flags(
     if _opt_flags is not None:
         assert not _opt_flags_constraints
         return _opt_flags
-    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
+    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k,
            routing_data, can_use_persistent_tma, can_use_fused_scatter,
            enforce_bitwise_invariance, epilogue_effective_itemsize,
            _opt_flags_constraints]
```

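A minimal standalone sketch of the small-problem heuristic introduced here. The 131072 threshold and the block_n arithmetic are taken from this diff and from `compute_block_n` in the file below; the rest of the persistence decision is simplified, so treat this as an approximation rather than the actual selection logic.

```python
import triton

def pick_block_n_and_persistence(m: int, n: int, k: int, supports_persistent: bool):
    # Default branch of compute_block_n: the non-TMA path may shrink block_n
    # down to 8, while the TMA path keeps a minimum of 16.
    target = min(128, triton.next_power_of_2(n))
    block_n, block_n_tma = max(8, target), max(16, target)

    # TMA (persistent) kernels are slower for batched matmuls with small m/n/k.
    is_persistent = supports_persistent and m * n * k >= 131072
    return (block_n_tma if is_persistent else block_n), is_persistent

print(pick_block_n_and_persistence(8, 8, 8, True))        # (8, False)
print(pick_block_n_and_persistence(128, 128, 128, True))  # (128, True)
```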
python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 11 additions & 9 deletions

```diff
@@ -6,24 +6,25 @@
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 
 
-def compute_grid_size(routing_data, m, n, block_m, block_n):
-    if routing_data is not None:
+def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n):
+    if routing_data is not None and batch_size == 1:
         grid_m = routing_data.n_blocks(m, block_m)
     else:
         grid_m = triton.cdiv(m, block_m)
     grid_n = (n + block_n - 1) // block_n
-    return grid_m * grid_n
+    return batch_size * grid_m * grid_n
 
 
 def compute_block_n(n: int, arch, precision_config):
     # block_n:
     layout = get_layout(precision_config.weight_scale)
     if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4:
-        return 128
+        return 128, 128
     elif precision_config.max_num_imprecise_acc is None and n > 128:
-        return 256
+        return 256, 256
     else:
-        return max(16, min(128, triton.next_power_of_2(n)))
+        target = min(128, triton.next_power_of_2(n))
+        return max(8, target), max(16, target)
 
 
 def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config):
@@ -35,7 +36,8 @@ def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_d
     if rhs_width == 4 and not has_native_mxfp:
         block_k = 128
     elif k is not None:
-        block_k = max(32, min(triton.next_power_of_2(k), block_k))
+        min_block_k = 32 if is_persistent or lhs_width != 16 or rhs_width != 16 else 16
+        block_k = max(min_block_k, min(triton.next_power_of_2(k), block_k))
     has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
     if has_native_mxfp and is_persistent and has_mx_weight_scale:
         block_k = min(block_k, 128)
@@ -54,11 +56,11 @@ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
     return split_k
 
 
-def compute_num_warps(block_m, block_n, precision_config):
+def compute_num_warps(block_m, block_n, is_persistent: bool, precision_config):
     layout = get_layout(precision_config.weight_scale)
     if isinstance(layout, HopperMXScaleLayout):
         return layout.num_warps
-    return max(block_m * block_n // 4096, 4)
+    return max(block_m * block_n // 4096, 4 if is_persistent else 1)
 
 
 def compute_num_stages(
```

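A small worked example of what the updated helpers imply for a tiny batched problem. The block sizes used below (block_m=16, block_n=8) are what the defaults in opt_flags.py would pick for M=N=K=8; the arithmetic simply mirrors `compute_grid_size` and `compute_num_warps` above.

```python
import triton

# B=10000 batched matmuls with M=N=K=8.
batch_size, m, n = 10_000, 8, 8
block_m, block_n = 16, 8

# compute_grid_size now scales by batch_size: one tile per batch entry here.
grid = batch_size * triton.cdiv(m, block_m) * triton.cdiv(n, block_n)
print(grid)  # 10000

# compute_num_warps now allows a single warp for tiny non-persistent tiles
# (previously the floor was 4 warps regardless of tile size).
print(max(block_m * block_n // 4096, 1))  # 1
```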
python/triton_kernels/triton_kernels/routing.py

Lines changed: 12 additions & 6 deletions

```diff
@@ -333,12 +333,18 @@ def compute_expt_data_torch(hist, n_expts_tot, n_gates):
         token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m]))
         token_offs_pad[block_m] = token_offs_pad[block_m].int()
         # compute data required to drive ragged batch matmul
-        block_pid_map[block_m] = -torch.ones(max_n_tiles, device=device)
-        for e in range(n_expts_tot):
-            offset = token_offs_pad[block_m][e]
-            for b in range(n_tiles[e]):
-                block_pid_map[block_m][offset + b] = (b << 16) + e
-        block_pid_map[block_m] = block_pid_map[block_m].int()
+        block_pid_map[block_m] = -torch.ones(max_n_tiles, dtype=torch.int32, device=device)
+
+        # for e in range(n_expts_tot):
+        #     offset = token_offs_pad[block_m][e]
+        #     for b in range(n_tiles[e]):
+        #         block_pid_map[block_m][offset + b] = (b << 16) + e
+
+        col = torch.arange(max_n_tiles, device=device)
+        map_vals = torch.arange(n_expts_tot, device=device)[:, None] + (col << 16)[None, :]
+        map_idxs = token_offs_pad[block_m][:-1, None] + col[None, :]
+        mask = col[None, :] < n_tiles[:, None]
+        block_pid_map[block_m].index_put_((map_idxs[mask], ), map_vals.int()[mask])
     return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
 
 
```

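The vectorized `block_pid_map` construction above is equivalent to the double loop it replaces (kept as a comment in the diff). A self-contained check on small, made-up inputs; tensor names mirror the diff, but the values are hypothetical:

```python
import torch

# Made-up example: 3 experts with 2, 0, and 3 tiles respectively.
n_tiles = torch.tensor([2, 0, 3])
token_offs_pad = torch.tensor([0, 2, 2, 5])  # prefix sum of n_tiles with a leading 0
max_n_tiles = int(n_tiles.sum())
n_expts_tot = n_tiles.numel()

# Reference: the original Python double loop.
ref = -torch.ones(max_n_tiles, dtype=torch.int32)
for e in range(n_expts_tot):
    offset = token_offs_pad[e]
    for b in range(n_tiles[e]):
        ref[offset + b] = (b << 16) + e

# Vectorized construction from the diff.
out = -torch.ones(max_n_tiles, dtype=torch.int32)
col = torch.arange(max_n_tiles)
map_vals = torch.arange(n_expts_tot)[:, None] + (col << 16)[None, :]
map_idxs = token_offs_pad[:-1, None] + col[None, :]
mask = col[None, :] < n_tiles[:, None]
out.index_put_((map_idxs[mask],), map_vals.int()[mask])

assert torch.equal(ref, out)
print(out)  # tensor([     0,  65536,      2,  65538, 131074], dtype=torch.int32)
```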