
Commit ba63c59

[KERNELS][TUTORIAL][BLACKWELL] Use optimized TMA layout for block scale factors (#7123)
Improves Triton MoE kernel performance for the mxfp4 block-scaled workload with a better layout choice for TMA. This gives up to a ~30% performance boost for some shapes in the roofline sweep. Importantly, with this patch using TMA for scale factors is always faster than traditional loads, so their use for scale factors (SF) can be removed. Tutorial 10 is updated to reflect this.

* Use TMA for block scales even when performing W @ X
* Use host TMA descriptors for X, W, and SF in the persistent kernel when possible
* Use a 2x256 shape for scale-factor TMAs
* Update tutorial 10 to use the faster 5D 2x256xu8 TMA for block scale factors
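For reference, the core of the change is how the swizzled scale-factor tensor is exposed to TMA. A minimal sketch, lifted from the `create_block_scale_descriptor` helper added below; the `num_expt_x_ncol` argument here stands in for the `(n_expts_tot or batch_size) * ceil(N / 128)` computation that the full helper performs:

```python
import torch
from triton.tools.tensor_descriptor import TensorDescriptor


def block_scale_descriptor_2x256(mx_tensor: torch.Tensor, block_k: int, block_n: int, K: int,
                                 mx_scale_stride_k: int, mx_scale_stride_n: int,
                                 num_expt_x_ncol: int) -> TensorDescriptor:
    # 32 elements share one scale factor, so the K extent shrinks by 32x.
    MX_PACK_DIVISOR = 32
    MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR
    PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR
    # 5D view whose trailing 2x256 (uint8) tile matches the layout the TMA
    # engine transfers efficiently, instead of a flat 3D scale tensor.
    return TensorDescriptor(
        base=mx_tensor,
        shape=[1, num_expt_x_ncol, (PackedK + 3) // 4, 2, 256],
        strides=[num_expt_x_ncol * mx_scale_stride_n, mx_scale_stride_n, mx_scale_stride_k, 256, 1],
        block_shape=[1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256])
```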
Parent: ff57a4d

7 files changed (+349, -162 lines)

python/triton_kernels/bench/bench_mlp.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@ def quantize(w, dtype, dev, **opt):
     elif dtype == "fp8":
         fp8e4_dtype = torch.float8_e4m3fn if get_cdna_version() != 3 \
             else torch.float8_e4m3fnuz
-        wq = w.to(fp8e4_dtype).transpose(-1, -2).contiguous().transpose(-1, -2)
+        wq = w.to(fp8e4_dtype)
         return wq, InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)), \
             MicroscalingCtx()
     else:

python/triton_kernels/tests/test_matmul.py

Lines changed: 3 additions & 0 deletions
@@ -207,6 +207,8 @@ class Case:
     Case(300, 400, 832, "ragged", "float8_e5m2", "mxfloat4_e2m1", 8, 4, hbm_swizzling=True),
     Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4),
     Case(300, 400, 400, "batched", "float8_e5m2", "mxfloat8_e4m3fn", 32, 4, hbm_swizzling=True),
+    Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
+    Case(256, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=False),
     # AMD
     Case(300, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz"),
     Case(1000, 400, 400, "ragged", "float8_e4m3fnuz", "float8_e4m3fnuz", 3, 1),
@@ -386,6 +388,7 @@ def _hook(launch_metadata):
     if mode == "batched":
         rdata, gindx, sindx = None, None, None
     flex = precision_opt.flex_ctx
+
     # triton
     tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt, gammas=gs1_ref)
     # If split_k > 1, then the intermediate tensor is fp32.

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 183 additions & 6 deletions
@@ -1,21 +1,24 @@
 from dataclasses import dataclass
 import itertools
+import math
 import sys
 import torch
 import triton
 # utilities
 from triton_kernels import target_info
 from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from triton_kernels.routing import ExptData, GatherIndx, RoutingData, ScatterIndx
+from triton.tools.tensor_descriptor import TensorDescriptor
 # details
 from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx
 from .matmul_ogs_details._matmul_ogs import _matmul_ogs
 from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
 from .matmul_ogs_details._finalize_matmul import _finalize_matmul
-from .matmul_ogs_details.opt_flags import make_opt_flags
+from .matmul_ogs_details.opt_flags import make_opt_flags, OptFlags
 from .matmul_ogs_details.fast_contiguous import fast_contiguous
 from .numerics_details.mxfp import SwizzlingType
 from .specialize import specialize
+from typing import Tuple, Optional
 
 
 @dataclass
@@ -95,6 +98,84 @@ def should_upcast_indices(*args):
     return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)
 
 
+class TensorDescriptorBuilder:
+    """Builder for creating different types of tensor descriptors"""
+
+    @staticmethod
+    def create_basic_descriptor(tensor: torch.Tensor, block_shape: Tuple[int, ...],
+                                transpose: bool = False) -> TensorDescriptor:
+        """Create a basic tensor descriptor with optional transpose"""
+        if transpose:
+            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
+            tensor = tensor.permute(0, 2, 1)
+        return TensorDescriptor.from_tensor(tensor, block_shape=block_shape)
+
+    @staticmethod
+    def create_weight_descriptor(w_tensor: torch.Tensor, block_k: int, block_n: int,
+                                 transpose: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for weight matrix"""
+        # Two e2m1 packed in a uint8 or a single fp8
+        W_PACK_DIVISOR = 2 if w_tensor.dtype == torch.uint8 else 1
+        PACKED_BLOCK_K_W = block_k // W_PACK_DIVISOR
+        return TensorDescriptorBuilder.create_basic_descriptor(w_tensor, block_shape=[1, PACKED_BLOCK_K_W, block_n],
+                                                               transpose=transpose)
+
+    @staticmethod
+    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, K: int, N: int,
+                                      mx_scale_stride_k: int, mx_scale_stride_n: int, n_expts_tot: int, batch_size: int,
+                                      expt_data: Optional[ExptData], swizzle_mx: bool,
+                                      transpose: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for block scale factors"""
+        MX_PACK_DIVISOR = 32
+        MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR
+        PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR
+
+        if swizzle_mx:
+            num_expt_x_ncol = (n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else
+                               batch_size) * ((N + 127) // 128)
+            return TensorDescriptor(
+                base=mx_tensor, shape=[1, num_expt_x_ncol, (PackedK + 3) // 4, 2, 256],
+                strides=[num_expt_x_ncol * mx_scale_stride_n, mx_scale_stride_n, mx_scale_stride_k, 256,
+                         1], block_shape=[1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256])
+        else:
+            # Non-optimal SF layout, expect slow transfers
+            # from global to shmem and from shmem to tmem
+            return TensorDescriptorBuilder.create_basic_descriptor(mx_tensor,
+                                                                   block_shape=[1, MX_SCALE_BLOCK_K,
+                                                                                block_n], transpose=transpose)
+
+    @staticmethod
+    def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int,
+                                       block_k: int) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X via TMA gather"""
+        x_desc = x_tensor.squeeze()
+        assert x_desc.ndim == 2, "TMA gather descriptor requires 2D input"
+        INT_MAX = 2147483647
+        return TensorDescriptor(base=x_desc, shape=[INT_MAX, K], strides=[x_stride_1, x_stride_2],
+                                block_shape=[1, block_k])
+
+    @staticmethod
+    def create_input_descriptor_load(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_m: int,
+                                     block_k: int) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X via TMA"""
+        x_desc = x_tensor.squeeze()
+        assert x_desc.ndim in [2, 3], "LHS input TMA descriptor builder expects 2D or 3D input"
+        return TensorDescriptor(base=x_desc, shape=[x_desc.shape[0], K], strides=[x_stride_1, x_stride_2],
+                                block_shape=[block_m, block_k])
+
+    @staticmethod
+    def create_input_descriptor(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_k: int,
+                                block_m: int, use_gather_tma: bool, use_load_tma: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X based on TMA usage"""
+        if use_gather_tma:
+            return TensorDescriptorBuilder.create_input_descriptor_gather(x_tensor, K, x_stride_1, x_stride_2, block_k)
+        elif use_load_tma:
+            return TensorDescriptorBuilder.create_input_descriptor_load(x_tensor, K, x_stride_1, x_stride_2, block_m,
+                                                                        block_k)
+        else:
+            return x_tensor
+
+
 # ---------------------
 # Numerics
 # ---------------------
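As a usage sketch of the builder above (the tensor shape and block sizes here are made-up examples, not values from this commit): for an mxfp4 weight stored as uint8, two e2m1 values are packed per byte, so the K extent of the TMA block is halved relative to the logical BLOCK_K.

```python
import torch
# Hypothetical shapes: one expert, K = 256 (packed into 128 bytes), N = 128.
w_mxfp4 = torch.zeros((1, 256 // 2, 128), dtype=torch.uint8, device="cuda")
desc = TensorDescriptorBuilder.create_weight_descriptor(w_mxfp4, block_k=128, block_n=64, transpose=False)
# The resulting descriptor block is [1, 64, 64]: BLOCK_K // 2 because of the e2m1 packing.
```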
@@ -490,7 +571,6 @@ def init_allocation(x, w, precision_config, fused_activation, routing_data, gath
     scratchpad["matmul"] = ((opt_flags.split_k, x.shape[0], M, N), dtype)
     return MatmulAllocation(x.device, output, scratchpad)
 
-
 def apply_allocation(allocation: MatmulAllocation, output):
     ret = dict()
     if output is None:
@@ -504,10 +584,82 @@ def apply_allocation(allocation: MatmulAllocation, output):
     }
     return ret
 
+
 # -----------------------------------------------------------------------------
 # Triton Implementation
 # -----------------------------------------------------------------------------
 
+def _create_tma_descriptors(
+    x: torch.Tensor,
+    x_tensor: torch.Tensor,
+    w_tensor: torch.Tensor,
+    mx_tensor: Optional[torch.Tensor],
+    routing_data: RoutingData,
+    mx_ctx: MicroscalingCtx,
+    expt_data: ExptData,
+    opt_flags: OptFlags,
+    batch_size: int,
+    K: int,
+    N: int,
+    mx_scale_stride_k: int,
+    mx_scale_stride_n: int,
+    USE_GATHER_TMA: bool,
+    X_USE_LOAD_TMA: bool,
+    w_transpose: bool,
+    mx_transpose: bool,
+) -> Tuple[bool, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Create and cache TMA descriptors for tensors."""
+    use_host_tma_descriptors = opt_flags.is_persistent and target_info.cuda_capability_geq(10, 0)
+
+    x_desc, w_desc = [None] * 2
+    descriptors = []
+    # The dense case currently uses on device descriptor updates
+    # so we bail out on using host descriptors in that case
+    if (use_host_tma_descriptors):
+        if USE_GATHER_TMA or X_USE_LOAD_TMA:
+            x_desc = TensorDescriptorBuilder.create_input_descriptor(
+                x_tensor, K, x.stride(1), x.stride(2),
+                opt_flags.block_k, opt_flags.block_m,
+                USE_GATHER_TMA, X_USE_LOAD_TMA
+            )
+        descriptors.append(x_desc)
+        if (expt_data is not None and len(expt_data.block_pid_map) > 0):
+            w_desc = TensorDescriptorBuilder.create_weight_descriptor(
+                w_tensor, opt_flags.block_k, opt_flags.block_n, w_transpose
+            )
+            is_microscaled_format = (mx_ctx.weight_scale is not None) and (w_tensor.dtype == torch.uint8)
+            if is_microscaled_format:
+                # Pad the inner shape to 128 for mxfp4 weights
+                # for mixed precision fp8 x mxfp4 compute
+                pad = 128
+                dim_to_pad = -1 if w_transpose else -2
+                old_size = w_desc.shape[dim_to_pad]
+                padded_size = math.ceil(old_size / pad) * pad
+                if padded_size != old_size:
+                    w_desc.shape = list(w_desc.shape)
+                    w_desc.shape[dim_to_pad] = padded_size
+        descriptors.append(w_desc)
+        # Optional MX scale descriptor
+        descriptors.append(None)
+        if mx_tensor is not None:
+            descriptors[-1] = TensorDescriptorBuilder.create_block_scale_descriptor(
+                mx_tensor, opt_flags.block_k, opt_flags.block_n, K, N,
+                mx_scale_stride_k, mx_scale_stride_n, routing_data.n_expts_tot,
+                batch_size,
+                expt_data, mx_ctx.swizzle_scale, mx_transpose
+            )
+
+    # TODO: Currently all or none, instead should support a mixture
+    # of host and device descriptors
+    if None in descriptors or len(descriptors) == 0:
+        descriptors = [x_tensor, w_tensor, mx_tensor]
+        use_host_tma_descriptors = False
+    if opt_flags.is_persistent:
+        opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
+
+    return use_host_tma_descriptors, *descriptors
+
+
 def matmul_ogs(x, w, bias,
                routing_data: RoutingData | None = None,
                gather_indx: GatherIndx | None = None,
@@ -601,22 +753,47 @@ def matmul_ogs(x, w, bias,
     flex = precision_config.flex_ctx
     bias_stride = None if bias is None else bias.stride(0)
     num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
+
     kernels = get_kernels(epilogue.specs, fused_activation.specs)
     expt_data = routing_data.expt_data
     block_m = opt_flags.block_m
     expt_hist = None if expt_data is None else expt_data.hist
     expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
     expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
     expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
+
+    HAS_TMA_GS = target_info.cuda_capability_geq(10, 0)
+    USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
+    X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
+    _, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
+        x=x,
+        x_tensor=flex.lhs_data.reinterpret(x),
+        w_tensor=flex.rhs_data.reinterpret(w),
+        mx_tensor=mx_ctx.weight_scale,
+        routing_data=routing_data,
+        mx_ctx=mx_ctx,
+        expt_data=expt_data,
+        opt_flags=opt_flags,
+        batch_size=batch_size,
+        K=K,
+        N=N,
+        mx_scale_stride_k=mx_scale_stride_k,
+        mx_scale_stride_n=mx_scale_stride_n,
+        USE_GATHER_TMA=USE_GATHER_TMA,
+        X_USE_LOAD_TMA=X_USE_LOAD_TMA,
+        w_transpose=w.stride(2) != 1,
+        mx_transpose=mx_scale_stride_n != 1,
+    )
+
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
         flex.out_data.reinterpret(memory["output"]),
         flex.out_data.reinterpret(out0), *out0.stride(),
         *out0_flex,
-        flex.lhs_data.reinterpret(x), x.stride(0), x.stride(1), x.stride(2),
+        x_tensor, x.stride(0), x.stride(1), x.stride(2),
         flex.lhs_data.scale,
-        flex.rhs_data.reinterpret(w), w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
+        w_tensor, w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
         flex.rhs_data.scale,
-        mx_ctx.weight_scale, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
+        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
         bias, bias_stride,
         x.shape[1],
         x.shape[1] if routing_data.expt_hist is None else None,

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 14 additions & 4 deletions
@@ -2,6 +2,7 @@
 
 import triton
 import triton.language as tl
+from triton.tools.tensor_descriptor import TensorDescriptor
 
 
 # -----------------------------------------------------------------------------
@@ -48,8 +49,16 @@ def matmul_repr(specialization):
     constants = specialization.constants
     reorder = lambda L: [L[i] for i in order]
     layout = lambda stride: "N" if stride in constants else "T"
-    convert_dtype = lambda dtype: "mxfp4" if "u8" in dtype else dtype
-    dtypes = "x".join([convert_dtype(f"{signature[i][1:]}") for i in reorder(["Y", "X", "W"])])
+
+    def convert_dtype(dtype):
+        if "tensordesc" in dtype:
+            return dtype.split("<")[1].split("[")[0]
+        elif "u8" in dtype:
+            return "mxfp4"
+        else:
+            return dtype[1:]
+
+    dtypes = "x".join([convert_dtype(f"{signature[i]}") for i in reorder(["Y", "X", "W"])])
     layouts = "".join([f"{layout(i)}" for i in reorder(["stride_y_n", "stride_x_k", "stride_w_n"])])
     blocks = "x".join([f"{constants[i]}" for i in ["BLOCK_M", "BLOCK_N", "BLOCK_K", "SPLIT_K"]])
     # mode = []
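To see what the new `convert_dtype` produces for descriptor arguments, here is a small sketch; the signature strings below are hypothetical examples of the pointer (`*dtype`) and `tensordesc<dtype[shape]>` forms, not values taken from this diff:

```python
def convert_dtype(dtype):
    if "tensordesc" in dtype:
        return dtype.split("<")[1].split("[")[0]
    elif "u8" in dtype:
        return "mxfp4"
    else:
        return dtype[1:]

print(convert_dtype("tensordesc<fp8e5[64, 128]>"))  # -> fp8e5 (dtype pulled out of the descriptor type)
print(convert_dtype("*u8"))                          # -> mxfp4 (packed mxfp4 weights)
print(convert_dtype("*fp16"))                        # -> fp16 (leading pointer marker stripped)
```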
@@ -68,7 +77,7 @@ def matmul_repr(specialization):
 def matmul_launch_metadata(grid, kernel, args):
     ret = dict()
     M, N, K = args["M"], args["N"], args["K"]
-    Y, X, W = args["Y"], args["X"], args["W"]
+    Y, X, W = [t.base if isinstance(t, TensorDescriptor) else t for t in [args["Y"], args["X"], args["W"]]]
     hist = args["ExptHist"]
     if hist is not None:
         n_tokens = float(hist.sum())
@@ -98,7 +107,8 @@ def matmul_launch_metadata(grid, kernel, args):
     n_x_bytes = X.numel() * X.element_size()
     n_y_bytes = Y.numel() * Y.element_size()
     if hist is not None:
-        assert X.shape[0] == Y.shape[0] == 1, "batched mode not supported"
+        if not isinstance(args["X"], TensorDescriptor):
+            assert X.shape[0] == Y.shape[0] == 1, "batched mode not supported"
         assert n_tokens is not None
         n_expts_act = args["N_EXPTS_ACT"]
