
Commit 9801a7a

Reapply "[KERNELS] Improve block sizes for batched matmul_ogs with small m/n/k (#7897)" (#8084)
This reverts commit 0a2e3a3. (Verified that this is still faster on GB200 on top of recent fixes.)
1 parent 1c03c46 commit 9801a7a
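
The change threads the batch size through the kernel's flag selection so that batched matmuls with small m/n/k get smaller, non-TMA block sizes. A minimal usage sketch of the batched entry point, with shapes chosen to fall in the small-m/n/k regime this commit targets; the import path is assumed from the repository layout and a CUDA device is assumed available:

# Illustrative sketch only (not part of the commit).
import torch
from triton_kernels.matmul_ogs import matmul_ogs

batch, m, n, k = 10000, 16, 16, 16          # m * n * k = 4096, well below 131072
x = torch.randn(batch, m, k, device="cuda", dtype=torch.float16)
w = torch.randn(batch, k, n, device="cuda", dtype=torch.float16)
y = matmul_ogs(x, w, None)                  # bias=None; opt flags now see batch_size=batch
# The result has shape (batch, m, n), matching torch.bmm(x, w) as in the new test below.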

File tree

4 files changed: +80 −18 lines


python/triton_kernels/tests/test_matmul.py

Lines changed: 54 additions & 1 deletion
@@ -1,6 +1,7 @@
 # isort: off
 # fmt: off
 from dataclasses import dataclass, fields, replace
+import itertools
 import pytest
 import torch
 from typing import Union
@@ -470,14 +471,66 @@ def round_x(x, idx):
         tri_y_scale).abs() < 1e-10, f"ref_y_scale: {ref_y_scale}, tri_y_scale: {tri_y_scale.item()}"


+# Test that we don't use unsupported block sizes.
+@pytest.mark.parametrize("m", [8, 16, 32, 64, 128])
+@pytest.mark.parametrize("n", [8, 16, 32, 64, 128])
+@pytest.mark.parametrize("k", [8, 16, 32, 64, 128])
+def test_small_batch_matmul(m, n, k):
+    if is_hip():
+        pytest.skip("Not fully tested on AMD")
+
+    if m * n * k > 16384:
+        pytest.skip()
+
+    BATCH_SIZE = 10000
+
+    def _make_tensor(shape, dtype, trans):
+        if trans:
+            shape = (shape[0], shape[2], shape[1])
+        t = alloc_rand(shape, "cuda", dtype)
+        return t.transpose(1, 2) if trans else t
+
+    for x_transpose, w_transpose, bias, dtype in itertools.product(
+        (False, True),
+        (False, True),
+        (False, True),
+        (torch.float16, torch.bfloat16, torch.float8_e5m2),
+    ):
+        if (
+            torch.cuda.get_device_capability()[0] < 10
+            and dtype is torch.float8_e5m2
+            and (not w_transpose)
+        ):
+            continue  # Not supported
+
+        x = _make_tensor((BATCH_SIZE, m, k), dtype, x_transpose)
+        w = _make_tensor((BATCH_SIZE, k, n), dtype, w_transpose)
+        bias = _make_tensor((BATCH_SIZE, n), torch.float32, False) if bias else None
+        tri_y = matmul_ogs(x, w, bias)
+
+        # ref_y = matmul_ogs_torch(x.float(), w.float(), bias)
+        # This is faster than matmul_ogs_torch.
+        ref_y = torch.bmm(x.float(), w.float())
+        if bias is not None:
+            ref_y += bias[:, None, :]
+
+        assert_close(
+            ref_y,
+            tri_y,
+            maxtol=4e-1 if dtype is torch.float8_e5m2 else None,
+            rmstol=4e-2 if dtype is torch.float8_e5m2 else None,
+        )
+
+
 def test_set_idle_sms():
     if not is_cuda():
         pytest.skip("Only supported on CUDA")
     from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags
     num_idle_sms = 24
     matmul_ogs_set_idle_sms(num_idle_sms)
     flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(), \
-        1024, 1024, 1024, None, True, False, 1)
+        1, 1024, 1024, 1024, None, True, False, 1)
     assert flags.idle_sms == num_idle_sms

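For reference, the parametrization above only exercises small problems: the early pytest.skip() drops any (m, n, k) combination whose product exceeds 16384. A small standalone sketch (plain Python, no test harness) that enumerates which combinations actually run:

# Sketch: count the (m, n, k) parametrizations kept by the skip condition
# m * n * k > 16384 in test_small_batch_matmul.
import itertools

sizes = [8, 16, 32, 64, 128]
kept = [(m, n, k) for m, n, k in itertools.product(sizes, repeat=3) if m * n * k <= 16384]
print(f"{len(kept)} of {len(sizes) ** 3} combinations run")  # (8, 8, 8) runs; (128, 128, 128) is skipped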

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 1 addition & 1 deletion
@@ -368,7 +368,7 @@ def matmul_ogs(x, w, bias,
     can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
-        M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
+        batch_size, M, N, K, routing_data, can_use_tma, can_use_fused_scatter, epilogue.effective_itemsize,
     )
     if not can_use_fused_scatter and opt_flags.fused_scatter:
         raise InapplicableConstraint("Fused scatter is not supported")
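
With this change, every caller of make_opt_flags passes the batch size ahead of M, N, K. A minimal sketch of the updated call, mirroring the test_set_idle_sms update above; import paths are assumed and the trailing arguments are copied verbatim from the test:

# Sketch of the updated make_opt_flags call order: batch_size (here 1) now
# comes right after the precision config and before M, N, K.
import torch
from triton_kernels.matmul_ogs import PrecisionConfig
from triton_kernels.matmul_ogs_details.opt_flags import make_opt_flags

flags = make_opt_flags(torch.float32, torch.float32, torch.float32, PrecisionConfig(),
                       1, 1024, 1024, 1024, None, True, False, 1)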

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 14 additions & 7 deletions
@@ -35,6 +35,7 @@ def make_default_opt_flags_amd(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -133,6 +134,7 @@ def make_default_opt_flags_nvidia(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -146,7 +148,7 @@ def make_default_opt_flags_nvidia(
     constraints_supported = ["block_m", "block_k", "split_k", "is_persistent", "fused_scatter", "epilogue_subtile", "num_stages", "idle_sms"]
     assert not any([c not in constraints_supported for c in constraints]), constraints.keys()
     # tokens per expert
-    if routing_data is None:
+    if routing_data is None or batch_size > 1:
         tokens_per_expt = m
     elif routing_data.expected_tokens_per_expt is None:
         tokens_per_expt = max(1, m // routing_data.n_expts_tot)
@@ -164,11 +166,11 @@ def make_default_opt_flags_nvidia(
     block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
     # block n
     arch = None
-    block_n = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
+    block_n, block_n_tma = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
     # is_persistent
-    grid_size = opt_flags_nvidia.compute_grid_size(routing_data, m, n, block_m, block_n)
+    grid_size_tma = opt_flags_nvidia.compute_grid_size(routing_data, batch_size, m, n, block_m, block_n_tma)
     n_sms = torch.cuda.get_device_properties(0).multi_processor_count
-    tiles_per_sm = grid_size / n_sms
+    tiles_per_sm = grid_size_tma / n_sms
     supports_persistent = can_use_persistent_tma and (arch is None or int(arch[2:-1]) >= 9)
     if constraints.get("is_persistent", None) is not None:
         is_persistent = constraints["is_persistent"]
@@ -178,6 +180,10 @@ def make_default_opt_flags_nvidia(
     # TEMP CHANGE
     if precision_config.act_scale is not None or precision_config.out_scale is not None:
         is_persistent = False
+    # TMA is slower for batched matmuls with small m/n/k.
+    if m * n * k < 131072:
+        is_persistent = False
+    block_n = block_n_tma if is_persistent else block_n
     # block k
     if constraints.get("block_k", None) is not None:
         block_k = constraints["block_k"]
@@ -189,7 +195,7 @@ def make_default_opt_flags_nvidia(
     elif is_persistent or enforce_bitwise_invariance or precision_config.act_scale is not None or precision_config.out_scale is not None:
         split_k = 1
     else:
-        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, m, n, block_m, block_n)
+        estimated_actual_grid_size = opt_flags_nvidia.compute_grid_size(None, batch_size, m, n, block_m, block_n)
         split_k = opt_flags_nvidia.compute_split_k(block_k, k, estimated_actual_grid_size)
     if split_k > 1:
         # With split_k, results are written in f32. Use that for the following computations.
@@ -224,7 +230,7 @@ def make_default_opt_flags_nvidia(
     else:
         fused_scatter = can_use_fused_scatter and split_k == 1
     # Handshake with the HBM swizzling
-    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, precision_config)
+    num_warps = opt_flags_nvidia.compute_num_warps(block_m, block_n, is_persistent, precision_config)
     ret = OptFlags(
         block_m=block_m,
         block_n=block_n,
@@ -275,6 +281,7 @@ def make_opt_flags(
     lhs_dtype,
     rhs_dtype,
     precision_config,
+    batch_size,
     m,
     n,
     k,
@@ -291,7 +298,7 @@ def make_opt_flags(
     if _opt_flags is not None:
         assert not _opt_flags_constraints
         return _opt_flags
-    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, m, n, k,
+    args = [out_dtype, lhs_dtype, rhs_dtype, precision_config, batch_size, m, n, k,
             routing_data, can_use_persistent_tma, can_use_fused_scatter,
             enforce_bitwise_invariance, epilogue_effective_itemsize,
             _opt_flags_constraints]
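
The core of the NVIDIA flag changes is a size-based opt-out of the persistent TMA path: when m * n * k falls below 131072, persistence is disabled and the smaller non-TMA block_n (which may now be as low as 8) is kept. A standalone sketch of that decision, not the library code itself; names mirror the diff but the logic is simplified:

# Hedged sketch of the new is_persistent / block_n selection.
def pick_persistent_and_block_n(m, n, k, block_n, block_n_tma, supports_persistent):
    is_persistent = supports_persistent
    # TMA is slower for batched matmuls with small m/n/k.
    if m * n * k < 131072:
        is_persistent = False
    # The TMA-friendly block_n is only used when the persistent path is taken.
    return is_persistent, (block_n_tma if is_persistent else block_n)

# Example: m = n = k = 16 gives 4096 < 131072, so the non-persistent path
# with block_n as small as 8 is chosen.
print(pick_persistent_and_block_n(16, 16, 16, 8, 16, True))  # (False, 8)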

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 11 additions & 9 deletions
@@ -6,24 +6,25 @@
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE


-def compute_grid_size(routing_data, m, n, block_m, block_n):
-    if routing_data is not None:
+def compute_grid_size(routing_data, batch_size, m, n, block_m, block_n):
+    if routing_data is not None and batch_size == 1:
         grid_m = routing_data.n_blocks(m, block_m)
     else:
         grid_m = triton.cdiv(m, block_m)
     grid_n = (n + block_n - 1) // block_n
-    return grid_m * grid_n
+    return batch_size * grid_m * grid_n


 def compute_block_n(n: int, arch, precision_config):
     # block_n:
     layout = get_layout(precision_config.weight_scale)
     if isinstance(layout, HopperMXScaleLayout) and layout.num_warps == 4:
-        return 128
+        return 128, 128
     elif precision_config.max_num_imprecise_acc is None and n > 128:
-        return 256
+        return 256, 256
     else:
-        return max(16, min(128, triton.next_power_of_2(n)))
+        target = min(128, triton.next_power_of_2(n))
+        return max(8, target), max(16, target)


 def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config):
@@ -35,7 +36,8 @@ def compute_block_k(m: int, k: int | None, is_persistent: bool, lhs_dtype, rhs_dtype, precision_config):
     if rhs_width == 4 and not has_native_mxfp:
         block_k = 128
     elif k is not None:
-        block_k = max(32, min(triton.next_power_of_2(k), block_k))
+        min_block_k = 32 if is_persistent or lhs_width != 16 or rhs_width != 16 else 16
+        block_k = max(min_block_k, min(triton.next_power_of_2(k), block_k))
     has_mx_weight_scale = precision_config is not None and precision_config.weight_scale is not None
     if has_native_mxfp and is_persistent and has_mx_weight_scale:
         block_k = min(block_k, 128)
@@ -54,11 +56,11 @@ def compute_split_k(block_k: int, k: int | None, grid_size: int) -> int:
     return split_k


-def compute_num_warps(block_m, block_n, precision_config):
+def compute_num_warps(block_m, block_n, is_persistent: bool, precision_config):
     layout = get_layout(precision_config.weight_scale)
     if isinstance(layout, HopperMXScaleLayout):
         return layout.num_warps
-    return max(block_m * block_n // 4096, 4)
+    return max(block_m * block_n // 4096, 4 if is_persistent else 1)


 def compute_num_stages(
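
Concretely, compute_block_n now returns a (non-TMA, TMA) pair so small n can map to a block of 8 on the non-persistent path, and compute_grid_size multiplies the tile count by the batch size. A hedged worked example of the fallback branches (simplified; MX layouts, max_num_imprecise_acc, and routing are omitted):

# Simplified re-statement of the two helpers for small n, used only to show
# the resulting numbers; the real functions handle more cases.
import triton

def block_n_pair(n):
    target = min(128, triton.next_power_of_2(n))
    return max(8, target), max(16, target)          # (non-TMA block_n, TMA block_n)

def grid_size(batch_size, m, n, block_m, block_n):
    # Without routing data (or with batch_size > 1), the grid is one m x n
    # tile grid per batch element.
    return batch_size * triton.cdiv(m, block_m) * ((n + block_n - 1) // block_n)

print(block_n_pair(8))                    # (8, 16)
print(grid_size(10000, 16, 8, 16, 8))     # 10000 * 1 * 1 = 10000 tiles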
