Skip to content

Commit 55613a7

Browse files
authored
[KERNELS] tuning for small batch MoE (#8206)
Small-batch MoE should ideally have a block_m that is not too large, but is still >= the number of tokens per expert in the large majority of cases. This minimizes how often weights are loaded in duplicate. It also improves mxfp4 in bandwidth-bound cases by loading a full cache line.
1 parent de2ba39 commit 55613a7

File tree

1 file changed

+11
-1
lines changed
  • python/triton_kernels/triton_kernels/matmul_ogs_details

1 file changed

+11
-1
lines changed

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,10 @@
11
# isort: off
22
# fmt: off
33
from dataclasses import dataclass
4+
45
import triton
56
from triton_kernels.target_info import get_cdna_version
7+
from triton_kernels.tensor import FP4
68
import torch
79
from .opt_flags_details import opt_flags_amd, opt_flags_nvidia
810
from triton_kernels.tensor import bitwidth
@@ -186,7 +188,11 @@ def make_default_opt_flags_nvidia(
186188
elif enforce_bitwise_invariance:
187189
block_m = 128
188190
else:
189-
block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
191+
if tokens_per_expt <= 64 and routing_data is not None and routing_data.expt_hist is not None:
192+
# Ragged and likely memory bound; set the block size higher to minimize loading weights more than once.
193+
block_m = max(16, min(triton.next_power_of_2(2 * tokens_per_expt), 64))
194+
else:
195+
block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
190196
# block n
191197
arch = None
192198
block_n, block_n_tma = opt_flags_nvidia.compute_block_n(n, arch, precision_config)
@@ -209,6 +215,10 @@ def make_default_opt_flags_nvidia(
209215
block_k = constraints["block_k"]
210216
else:
211217
block_k = opt_flags_nvidia.compute_block_k(m, k, is_persistent, lhs_dtype, rhs_dtype, precision_config)
218+
if block_n == 256 and block_k == 128 and block_m <= 64 and is_persistent and rhs_dtype == FP4 and k >= 4096 and tokens_per_expt > 1:
219+
# Swap block_n and block_k for mxfp4 weights so that block_k is a full cacheline, so long as K is sufficiently large.
220+
# TODO: swizzle the HBM layout of the weights instead
221+
block_n, block_k = block_k, block_n
212222
# split_k
213223
if batch_size > 1:
214224
split_k = 1 # currently not supported

0 commit comments

Comments
 (0)