Commit 9326a2d

[kernels] use more host TMA for X, W, Mx in persistent matmul (#7182)
Host TMA is now used for X when it is loaded (not gathered), and is always used for W and the Mx scales. Gather TMA stays device-side, since it performs better than host TMA descriptors for that path. Also includes fixes in the epilogue for subtiling of the bias tensor: chunks of it are loaded separately instead of loading all of it and then splitting it, which reduces register spilling.
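
For orientation, a minimal sketch of the descriptor choice this commit lands on. It is not the committed code: the helper name pick_x_argument and the block sizes are illustrative, while TensorDescriptor.from_tensor, gather_indx, and the gather-stays-on-device rule come from the diff below.

    from triton.tools.tensor_descriptor import TensorDescriptor

    def pick_x_argument(x, gather_indx, block_m, block_k):
        # Gathered X stays a plain tensor: the kernel builds its gather TMA
        # descriptor on the device, which performs better than a host
        # descriptor for the gather path.
        if gather_indx is not None:
            return x
        # Loaded (non-gathered) X gets a host-side TMA descriptor, as do W
        # and the Mx scales (2D X shown here for brevity).
        return TensorDescriptor.from_tensor(x, block_shape=[block_m, block_k])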
1 parent d748303 commit 9326a2d

4 files changed: +121 -246 lines changed


python/triton/tools/tensor_descriptor.py

Lines changed: 4 additions & 1 deletion
@@ -1,6 +1,8 @@
 from dataclasses import dataclass
 from typing import List, Any
 from triton._utils import validate_block_shape
+from torch._subclasses.fake_tensor import FakeTensor
+from torch._subclasses.functional_tensor import FunctionalTensor
 
 
 @dataclass
@@ -16,7 +18,8 @@ def __post_init__(self):
         assert len(self.block_shape) == rank, f"rank mismatch: {self}"
         assert rank > 0, "rank must not be zero"
         assert rank <= 5, "rank cannot be more than 5"
-        assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
+        if not isinstance(self.base, (FakeTensor, FunctionalTensor)):
+            assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
         validate_block_shape(self.block_shape)
         elem_bytes = self.base.dtype.itemsize
         for stride in self.strides[:-1]:
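
The relaxed assert above matters because FakeTensor and FunctionalTensor are the tensor subclasses PyTorch uses while tracing (for example under torch.compile): they carry shape, stride, and dtype metadata but no real storage, so a data_ptr() alignment check is not meaningful there and may raise. A rough, illustrative sketch, assuming the usual torch._subclasses.fake_tensor.FakeTensorMode entry point:

    import torch
    from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

    with FakeTensorMode():
        x = torch.empty(128, 128, dtype=torch.float16)

    # x is a FakeTensor: metadata only, no backing storage. __post_init__ now
    # skips the 16-byte data_ptr() alignment assert for bases like this one.
    assert isinstance(x, FakeTensor)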

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 74 additions & 114 deletions
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
 import itertools
-import math
 import sys
 import torch
 import triton
@@ -121,20 +120,19 @@ def create_weight_descriptor(w_tensor: torch.Tensor, block_k: int, block_n: int,
                                 transpose=transpose)
 
     @staticmethod
-    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, K: int, N: int,
-                                      mx_scale_stride_k: int, mx_scale_stride_n: int, n_expts_tot: int, batch_size: int,
-                                      expt_data: Optional[ExptData], swizzle_mx: bool,
-                                      transpose: bool) -> TensorDescriptor:
+    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, B: int, K: int, N: int,
+                                      mx_scale_stride_k: int, mx_scale_stride_n: int, swizzle_mx: bool,
+                                      transpose: Optional[bool]) -> TensorDescriptor:
         """Create a tensor descriptor for block scale factors"""
         MX_PACK_DIVISOR = 32
         MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR
         PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR
 
         if swizzle_mx:
-            num_expt_x_ncol = (n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else
-                               batch_size) * ((N + 127) // 128)
+            assert transpose is None
+            num_expt_x_ncol = B * triton.cdiv(N, 128)
             return TensorDescriptor(
-                base=mx_tensor, shape=[1, num_expt_x_ncol, (PackedK + 3) // 4, 2, 256],
+                base=mx_tensor, shape=[1, num_expt_x_ncol, triton.cdiv(PackedK, 4), 2, 256],
                 strides=[num_expt_x_ncol * mx_scale_stride_n, mx_scale_stride_n, mx_scale_stride_k, 256,
                          1], block_shape=[1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256])
         else:
151149
return x.view(*new_shape)
152150

153151
@staticmethod
154-
def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int,
155-
block_k: int) -> TensorDescriptor:
156-
"""Create a tensor descriptor for input matrix X via TMA gather"""
157-
x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
158-
assert x_desc.ndim == 2, "TMA gather descriptor requires 2D input"
159-
INT_MAX = 2147483647
160-
return TensorDescriptor(base=x_desc, shape=[INT_MAX, K], strides=[x_stride_1, x_stride_2],
161-
block_shape=[1, block_k])
162-
163-
@staticmethod
164-
def create_input_descriptor_load(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_m: int,
165-
block_k: int) -> TensorDescriptor:
166-
"""Create a tensor descriptor for input matrix X via TMA"""
167-
x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
168-
assert x_desc.ndim in [2, 3], "LHS input TMA descriptor builder expects 2D or 3D input"
169-
return TensorDescriptor(base=x_desc, shape=[x_desc.shape[0], K], strides=[x_stride_1, x_stride_2],
170-
block_shape=[block_m, block_k])
171-
172-
@staticmethod
173-
def create_input_descriptor(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_k: int,
174-
block_m: int, use_gather_tma: bool, use_load_tma: bool) -> TensorDescriptor:
175-
"""Create a tensor descriptor for input matrix X based on TMA usage"""
176-
if use_gather_tma:
177-
return TensorDescriptorBuilder.create_input_descriptor_gather(x_tensor, K, x_stride_1, x_stride_2, block_k)
178-
elif use_load_tma:
179-
return TensorDescriptorBuilder.create_input_descriptor_load(x_tensor, K, x_stride_1, x_stride_2, block_m,
180-
block_k)
181-
else:
182-
return x_tensor
152+
def create_descriptor(x_tensor: torch.Tensor, block_m: int, block_k: int) -> TensorDescriptor:
153+
"""Create a tensor descriptor for matrix X via TMA"""
154+
x_tensor = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
155+
assert x_tensor.ndim in [2, 3], "TMA descriptor builder expects 2D or 3D input"
156+
block_shape = [1] * (x_tensor.ndim - 2) + [block_m, block_k]
157+
return TensorDescriptor.from_tensor(x_tensor, block_shape=block_shape)
183158

184159

185160
# ---------------------
@@ -590,66 +565,53 @@ def _create_tma_descriptors(
590565
mx_ctx: MicroscalingCtx,
591566
expt_data: ExptData,
592567
opt_flags: OptFlags,
593-
batch_size: int,
568+
B: int,
594569
K: int,
595570
N: int,
596571
mx_scale_stride_k: int,
597572
mx_scale_stride_n: int,
598-
USE_GATHER_TMA: bool,
599-
X_USE_LOAD_TMA: bool,
600-
w_transpose: bool,
601-
mx_transpose: bool,
573+
HAS_GATHER: bool,
602574
) -> Tuple[bool, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
603575
"""Create and cache TMA descriptors for tensors."""
604-
use_host_tma_descriptors = opt_flags.is_persistent and target_info.cuda_capability_geq(10, 0)
605-
606-
x_desc, w_desc = [None] * 2
607-
descriptors = []
608-
# The dense case currently uses on device descriptor updates
609-
# so we bail out on using host descriptors in that case
610-
if (use_host_tma_descriptors):
611-
if USE_GATHER_TMA or X_USE_LOAD_TMA:
612-
x_desc = TensorDescriptorBuilder.create_input_descriptor(
613-
x, K, x.stride(1), x.stride(2),
614-
opt_flags.block_k, opt_flags.block_m,
615-
USE_GATHER_TMA, X_USE_LOAD_TMA
616-
)
617-
descriptors.append(x_desc)
618-
if (expt_data is not None and len(expt_data.block_pid_map) > 0):
619-
w_desc = TensorDescriptorBuilder.create_weight_descriptor(
620-
w, opt_flags.block_k, opt_flags.block_n, w_transpose
621-
)
622-
is_microscaled_format = (mx_ctx.weight_scale is not None) and (w.dtype == torch.uint8)
623-
if is_microscaled_format:
624-
# Pad the inner shape to 128 for mxfp4 weights
625-
# for mixed precision fp8 x mxfp4 compute
626-
pad = 128
627-
dim_to_pad = -1
628-
old_size = w_desc.shape[dim_to_pad]
629-
padded_size = math.ceil(old_size / pad) * pad
630-
if padded_size != old_size:
631-
w_desc.shape = list(w_desc.shape)
632-
w_desc.shape[dim_to_pad] = padded_size
633-
descriptors.append(w_desc)
634-
# Optional MX scale descriptor
635-
descriptors.append(None)
636-
if mx_tensor is not None:
637-
descriptors[-1] = TensorDescriptorBuilder.create_block_scale_descriptor(
638-
mx_tensor, opt_flags.block_k, opt_flags.block_n, K, N,
639-
mx_scale_stride_k, mx_scale_stride_n, routing_data.n_expts_tot,
640-
batch_size,
641-
expt_data, mx_ctx.swizzle_scale, mx_transpose
642-
)
643576

644-
# TODO: Currently all or none, instead should support a mixture
645-
# of host and device descriptors
646-
if None in descriptors or len(descriptors) == 0:
647-
descriptors = [x, w, mx_tensor]
648-
use_host_tma_descriptors = False
649-
if opt_flags.is_persistent:
650-
opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
577+
x_tensor_or_desc, mx_desc_and_transpose = x, (None, False)
651578

652-
return use_host_tma_descriptors, *descriptors
579+
if not HAS_GATHER:
580+
x_tensor_or_desc = TensorDescriptorBuilder.create_descriptor(x, opt_flags.block_m, opt_flags.block_k)
581+
582+
w_transpose = w.stride(2) != 1
583+
w_desc = TensorDescriptorBuilder.create_weight_descriptor(
584+
w, opt_flags.block_k, opt_flags.block_n, w_transpose
585+
)
586+
w_desc_and_transpose = (w_desc, w_transpose)
587+
588+
is_microscaled_format = mx_ctx.weight_scale is not None and w.dtype == torch.uint8
589+
if is_microscaled_format:
590+
# Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses
591+
# CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
592+
# This technically makes the shape masking incorrect, but it's fine because:
593+
# - When the N dim is padded, the scales will be masked to 0.
594+
# - When the K dim is padded, the activations we perform tl.dot with will be masked to 0.
595+
# Note: the scales can't be relied on for zeroing in this case, because they apply to groups
596+
# of 32 elements in the K dimension.
597+
pad = 128
598+
dim_to_pad = -1
599+
old_size = w_desc.shape[dim_to_pad]
600+
padded_size = triton.cdiv(old_size, pad) * pad
601+
if padded_size != old_size:
602+
w_desc.shape = list(w_desc.shape)
603+
w_desc.shape[dim_to_pad] = padded_size
604+
605+
if mx_tensor is not None:
606+
mx_transpose = mx_scale_stride_n != 1 if mx_ctx.swizzle_scale is None else None
607+
mx_desc = TensorDescriptorBuilder.create_block_scale_descriptor(
608+
mx_tensor, opt_flags.block_k, opt_flags.block_n,
609+
routing_data.n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else B, K, N,
610+
mx_scale_stride_k, mx_scale_stride_n, mx_ctx.swizzle_scale, mx_transpose
611+
)
612+
mx_desc_and_transpose = (mx_desc, mx_transpose)
613+
614+
return x_tensor_or_desc, w_desc_and_transpose, mx_desc_and_transpose
653615

654616

655617
def matmul_ogs(x, w, bias,
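
To make the mxfp4 padding above concrete (the size is hypothetical, not from the commit): an innermost descriptor dimension of 2880 packed values is rounded up to the next multiple of 128, and the comment in the diff explains why the extra region is safe to expose to TMA.

    import triton

    pad = 128
    old_size = 2880                                  # hypothetical innermost dim of the mxfp4 weight descriptor
    padded_size = triton.cdiv(old_size, pad) * pad   # cdiv(2880, 128) = 23 -> 23 * 128 = 2944
    assert padded_size == 2944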
@@ -754,41 +716,39 @@ def matmul_ogs(x, w, bias,
     expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
     expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
 
-    HAS_TMA_GS = target_info.cuda_capability_geq(10, 0)
-    USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
-    X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
-    _, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
-        x=x, w=w,
-        mx_tensor=mx_ctx.weight_scale,
-        routing_data=routing_data,
-        mx_ctx=mx_ctx,
-        expt_data=expt_data,
-        opt_flags=opt_flags,
-        batch_size=batch_size,
-        K=K,
-        N=N,
-        mx_scale_stride_k=mx_scale_stride_k,
-        mx_scale_stride_n=mx_scale_stride_n,
-        USE_GATHER_TMA=USE_GATHER_TMA,
-        X_USE_LOAD_TMA=X_USE_LOAD_TMA,
-        w_transpose=w.stride(2) != 1,
-        mx_transpose=mx_scale_stride_n != 1,
-    )
+    if opt_flags.is_persistent:
+        x_tensor, w_tensor_and_transpose, mx_tensor_and_tranpose = _create_tma_descriptors(
+            x=x, w=w, mx_tensor=mx_ctx.weight_scale,
+            routing_data=routing_data,
+            mx_ctx=mx_ctx,
+            expt_data=expt_data,
+            opt_flags=opt_flags,
+            B=batch_size,
+            K=K,
+            N=N,
+            mx_scale_stride_k=mx_scale_stride_k,
+            mx_scale_stride_n=mx_scale_stride_n,
+            HAS_GATHER=gather_indx is not None,
+        )
+        w_tensor, w_tma_transpose = w_tensor_and_transpose
+        mx_tensor, mx_tma_transpose = mx_tensor_and_tranpose
+    else:
+        x_tensor = x
+        w_tensor, w_tma_transpose = w, False
+        mx_tensor, mx_tma_transpose = mx_ctx.weight_scale, False
     if isinstance(x_tensor, torch.Tensor):
         x_tensor = flex.lhs_data.reinterpret(x)
     if isinstance(w_tensor, torch.Tensor):
         w_tensor = flex.rhs_data.reinterpret(w)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
         flex.out_data.reinterpret(memory["output"]),
-        flex.out_data.reinterpret(out0), *out0.stride(),
-        *out0_flex,
+        flex.out_data.reinterpret(out0), *out0.stride(), *out0_flex,
         x_tensor, x.stride(0), x.stride(1), x.stride(2),
         flex.lhs_data.scale,
-        w_tensor, w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
+        w_tensor, w.stride(0), w.stride(1), w.stride(2), w_tma_transpose,
         flex.rhs_data.scale,
-        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
+        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_tma_transpose,
         bias, bias_stride,
-        x.shape[1],
         x.shape[1] if routing_data.expt_hist is None else None,
         N, K,
         betas, gammas,

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def _matmul_ogs(
                 WScale,
                 MxScale, stride_mx_e, stride_mx_k, stride_mx_n, MX_TRANSPOSE: tl.constexpr,
                 B, stride_b_e, # Bias
-                NRows, M, N, K, # shapes
+                M, N, K, # shapes
                 # expt data
                 Betas, Gammas,
                 GatherIndx,