Commit eacb681
[release/3.4] Cherry-pick triton-lang#7182 (triton-lang#7437)
We are from NVIDIA and have been testing the MoE kernels internally. We saw a strange illegal-memory-access issue on release/3.4.x that is gone on main. After bisecting, we found that triton-lang#7182 fixes it. We think it's important to get this change into the release branch so that we can make good use of the 3.4 release once it's out.

Co-authored-by: aeng-openai <[email protected]>

1 parent: 6e1dafa

4 files changed: +121 −246 lines

python/triton/tools/tensor_descriptor.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -1,6 +1,8 @@
 from dataclasses import dataclass
 from typing import List, Any
 from triton._utils import validate_block_shape
+from torch._subclasses.fake_tensor import FakeTensor
+from torch._subclasses.functional_tensor import FunctionalTensor


 @dataclass
@@ -16,7 +18,8 @@ def __post_init__(self):
         assert len(self.block_shape) == rank, f"rank mismatch: {self}"
         assert rank > 0, "rank must not be zero"
         assert rank <= 5, "rank cannot be more than 5"
-        assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
+        if not isinstance(self.base, (FakeTensor, FunctionalTensor)):
+            assert self.base.data_ptr() % 16 == 0, "base must be 16-byte aligned"
         validate_block_shape(self.block_shape)
         elem_bytes = self.base.dtype.itemsize
         for stride in self.strides[:-1]:
```
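For context on the guard above: under tracing (e.g. torch.compile), `__post_init__` can run on FakeTensor/FunctionalTensor instances, which carry shape and dtype metadata but no real storage, so `data_ptr()` is not a meaningful device address there. A minimal sketch of that situation (not from the commit; the shape is illustrative):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

with FakeTensorMode():
    t = torch.empty(64, 64)  # factory calls under FakeTensorMode yield FakeTensors

print(isinstance(t, FakeTensor))  # True
# `t` has metadata but no backing storage, so `t.data_ptr() % 16 == 0`
# is not a meaningful alignment check here; the patched __post_init__
# therefore skips the assert for fake/functional tensors.
```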

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 74 additions & 114 deletions
```diff
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
 import itertools
-import math
 import sys
 import torch
 import triton
```
```diff
@@ -121,20 +120,19 @@ def create_weight_descriptor(w_tensor: torch.Tensor, block_k: int, block_n: int,
                                 transpose=transpose)

     @staticmethod
-    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, K: int, N: int,
-                                      mx_scale_stride_k: int, mx_scale_stride_n: int, n_expts_tot: int, batch_size: int,
-                                      expt_data: Optional[ExptData], swizzle_mx: bool,
-                                      transpose: bool) -> TensorDescriptor:
+    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, B: int, K: int, N: int,
+                                      mx_scale_stride_k: int, mx_scale_stride_n: int, swizzle_mx: bool,
+                                      transpose: Optional[bool]) -> TensorDescriptor:
         """Create a tensor descriptor for block scale factors"""
         MX_PACK_DIVISOR = 32
         MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR
         PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR

         if swizzle_mx:
-            num_expt_x_ncol = (n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else
-                               batch_size) * ((N + 127) // 128)
+            assert transpose is None
+            num_expt_x_ncol = B * triton.cdiv(N, 128)
             return TensorDescriptor(
-                base=mx_tensor, shape=[1, num_expt_x_ncol, (PackedK + 3) // 4, 2, 256],
+                base=mx_tensor, shape=[1, num_expt_x_ncol, triton.cdiv(PackedK, 4), 2, 256],
                 strides=[num_expt_x_ncol * mx_scale_stride_n, mx_scale_stride_n, mx_scale_stride_k, 256,
                          1], block_shape=[1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256])
         else:
```
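To make the swizzled-scale shape above concrete: `triton.cdiv(a, b)` is integer ceiling division, `(a + b - 1) // b`, replacing the handwritten forms in the old code. A worked example with illustrative sizes (none of these numbers come from the commit):

```python
import triton

# Illustrative sizes, not from the commit:
K, N, B = 8192, 4096, 8
block_k, block_n = 128, 256

MX_PACK_DIVISOR = 32
MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR            # 4 scale blocks per block_k tile
PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR   # 256 scale groups along K
num_expt_x_ncol = B * triton.cdiv(N, 128)                # 8 * 32 = 256

shape = [1, num_expt_x_ncol, triton.cdiv(PackedK, 4), 2, 256]     # [1, 256, 64, 2, 256]
block_shape = [1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256]  # [1, 2, 1, 2, 256]
```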
```diff
@@ -151,35 +149,12 @@ def squeeze_after_dim(x, dim=2):
         return x.view(*new_shape)

     @staticmethod
-    def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int,
-                                       block_k: int) -> TensorDescriptor:
-        """Create a tensor descriptor for input matrix X via TMA gather"""
-        x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
-        assert x_desc.ndim == 2, "TMA gather descriptor requires 2D input"
-        INT_MAX = 2147483647
-        return TensorDescriptor(base=x_desc, shape=[INT_MAX, K], strides=[x_stride_1, x_stride_2],
-                                block_shape=[1, block_k])
-
-    @staticmethod
-    def create_input_descriptor_load(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_m: int,
-                                     block_k: int) -> TensorDescriptor:
-        """Create a tensor descriptor for input matrix X via TMA"""
-        x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
-        assert x_desc.ndim in [2, 3], "LHS input TMA descriptor builder expects 2D or 3D input"
-        return TensorDescriptor(base=x_desc, shape=[x_desc.shape[0], K], strides=[x_stride_1, x_stride_2],
-                                block_shape=[block_m, block_k])
-
-    @staticmethod
-    def create_input_descriptor(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_k: int,
-                                block_m: int, use_gather_tma: bool, use_load_tma: bool) -> TensorDescriptor:
-        """Create a tensor descriptor for input matrix X based on TMA usage"""
-        if use_gather_tma:
-            return TensorDescriptorBuilder.create_input_descriptor_gather(x_tensor, K, x_stride_1, x_stride_2, block_k)
-        elif use_load_tma:
-            return TensorDescriptorBuilder.create_input_descriptor_load(x_tensor, K, x_stride_1, x_stride_2, block_m,
-                                                                        block_k)
-        else:
-            return x_tensor
+    def create_descriptor(x_tensor: torch.Tensor, block_m: int, block_k: int) -> TensorDescriptor:
+        """Create a tensor descriptor for matrix X via TMA"""
+        x_tensor = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
+        assert x_tensor.ndim in [2, 3], "TMA descriptor builder expects 2D or 3D input"
+        block_shape = [1] * (x_tensor.ndim - 2) + [block_m, block_k]
+        return TensorDescriptor.from_tensor(x_tensor, block_shape=block_shape)


 # ---------------------
```
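The new `create_descriptor` derives the block shape from the rank of the (squeezed) input: any leading batch-like dimension gets a block extent of 1, and only the trailing `[block_m, block_k]` tile is nontrivial. A quick sketch of that arithmetic (the tile sizes are illustrative):

```python
block_m, block_k = 128, 64

for ndim in (2, 3):
    # 2D (M, K) input -> [128, 64]; 3D (B, M, K) input -> [1, 128, 64]
    block_shape = [1] * (ndim - 2) + [block_m, block_k]
    print(ndim, block_shape)
```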
```diff
@@ -590,66 +565,53 @@ def _create_tma_descriptors(
     mx_ctx: MicroscalingCtx,
     expt_data: ExptData,
     opt_flags: OptFlags,
-    batch_size: int,
+    B: int,
     K: int,
     N: int,
     mx_scale_stride_k: int,
     mx_scale_stride_n: int,
-    USE_GATHER_TMA: bool,
-    X_USE_LOAD_TMA: bool,
-    w_transpose: bool,
-    mx_transpose: bool,
+    HAS_GATHER: bool,
 ) -> Tuple[bool, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """Create and cache TMA descriptors for tensors."""
-    use_host_tma_descriptors = opt_flags.is_persistent and target_info.cuda_capability_geq(10, 0)
-
-    x_desc, w_desc = [None] * 2
-    descriptors = []
-    # The dense case currently uses on device descriptor updates
-    # so we bail out on using host descriptors in that case
-    if (use_host_tma_descriptors):
-        if USE_GATHER_TMA or X_USE_LOAD_TMA:
-            x_desc = TensorDescriptorBuilder.create_input_descriptor(
-                x, K, x.stride(1), x.stride(2),
-                opt_flags.block_k, opt_flags.block_m,
-                USE_GATHER_TMA, X_USE_LOAD_TMA
-            )
-            descriptors.append(x_desc)
-        if (expt_data is not None and len(expt_data.block_pid_map) > 0):
-            w_desc = TensorDescriptorBuilder.create_weight_descriptor(
-                w, opt_flags.block_k, opt_flags.block_n, w_transpose
-            )
-            is_microscaled_format = (mx_ctx.weight_scale is not None) and (w.dtype == torch.uint8)
-            if is_microscaled_format:
-                # Pad the inner shape to 128 for mxfp4 weights
-                # for mixed precision fp8 x mxfp4 compute
-                pad = 128
-                dim_to_pad = -1
-                old_size = w_desc.shape[dim_to_pad]
-                padded_size = math.ceil(old_size / pad) * pad
-                if padded_size != old_size:
-                    w_desc.shape = list(w_desc.shape)
-                    w_desc.shape[dim_to_pad] = padded_size
-            descriptors.append(w_desc)
-        # Optional MX scale descriptor
-        descriptors.append(None)
-        if mx_tensor is not None:
-            descriptors[-1] = TensorDescriptorBuilder.create_block_scale_descriptor(
-                mx_tensor, opt_flags.block_k, opt_flags.block_n, K, N,
-                mx_scale_stride_k, mx_scale_stride_n, routing_data.n_expts_tot,
-                batch_size,
-                expt_data, mx_ctx.swizzle_scale, mx_transpose
-            )

-    # TODO: Currently all or none, instead should support a mixture
-    # of host and device descriptors
-    if None in descriptors or len(descriptors) == 0:
-        descriptors = [x, w, mx_tensor]
-        use_host_tma_descriptors = False
-    if opt_flags.is_persistent:
-        opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
+    x_tensor_or_desc, mx_desc_and_transpose = x, (None, False)

-    return use_host_tma_descriptors, *descriptors
+    if not HAS_GATHER:
+        x_tensor_or_desc = TensorDescriptorBuilder.create_descriptor(x, opt_flags.block_m, opt_flags.block_k)
+
+    w_transpose = w.stride(2) != 1
+    w_desc = TensorDescriptorBuilder.create_weight_descriptor(
+        w, opt_flags.block_k, opt_flags.block_n, w_transpose
+    )
+    w_desc_and_transpose = (w_desc, w_transpose)
+
+    is_microscaled_format = mx_ctx.weight_scale is not None and w.dtype == torch.uint8
+    if is_microscaled_format:
+        # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses
+        # CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
+        # This technically makes the shape masking incorrect, but it's fine because:
+        # - When the N dim is padded, the scales will be masked to 0.
+        # - When the K dim is padded, the activations we perform tl.dot with will be masked to 0.
+        # Note: the scales can't be relied on for zeroing in this case, because they apply to groups
+        # of 32 elements in the K dimension.
+        pad = 128
+        dim_to_pad = -1
+        old_size = w_desc.shape[dim_to_pad]
+        padded_size = triton.cdiv(old_size, pad) * pad
+        if padded_size != old_size:
+            w_desc.shape = list(w_desc.shape)
+            w_desc.shape[dim_to_pad] = padded_size
+
+    if mx_tensor is not None:
+        mx_transpose = mx_scale_stride_n != 1 if mx_ctx.swizzle_scale is None else None
+        mx_desc = TensorDescriptorBuilder.create_block_scale_descriptor(
+            mx_tensor, opt_flags.block_k, opt_flags.block_n,
+            routing_data.n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else B, K, N,
+            mx_scale_stride_k, mx_scale_stride_n, mx_ctx.swizzle_scale, mx_transpose
+        )
+        mx_desc_and_transpose = (mx_desc, mx_transpose)
+
+    return x_tensor_or_desc, w_desc_and_transpose, mx_desc_and_transpose


 def matmul_ogs(x, w, bias,
```
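One detail worth noting in the hunk above: the float round-trip through `math.ceil(old_size / pad)` is replaced by integer ceiling division, which is why the `import math` at the top of the file is dropped. A small sketch of the padding arithmetic (the `old_size` value is made up):

```python
import math
import triton

old_size, pad = 200, 128
# Integer ceiling division: triton.cdiv(a, b) == (a + b - 1) // b.
padded_size = triton.cdiv(old_size, pad) * pad  # 2 * 128 = 256
assert padded_size == math.ceil(old_size / pad) * pad == 256
```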
```diff
@@ -754,41 +716,39 @@ def matmul_ogs(x, w, bias,
     expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
     expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]

-    HAS_TMA_GS = target_info.cuda_capability_geq(10, 0)
-    USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
-    X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
-    _, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
-        x=x, w=w,
-        mx_tensor=mx_ctx.weight_scale,
-        routing_data=routing_data,
-        mx_ctx=mx_ctx,
-        expt_data=expt_data,
-        opt_flags=opt_flags,
-        batch_size=batch_size,
-        K=K,
-        N=N,
-        mx_scale_stride_k=mx_scale_stride_k,
-        mx_scale_stride_n=mx_scale_stride_n,
-        USE_GATHER_TMA=USE_GATHER_TMA,
-        X_USE_LOAD_TMA=X_USE_LOAD_TMA,
-        w_transpose=w.stride(2) != 1,
-        mx_transpose=mx_scale_stride_n != 1,
-    )
+    if opt_flags.is_persistent:
+        x_tensor, w_tensor_and_transpose, mx_tensor_and_tranpose = _create_tma_descriptors(
+            x=x, w=w, mx_tensor=mx_ctx.weight_scale,
+            routing_data=routing_data,
+            mx_ctx=mx_ctx,
+            expt_data=expt_data,
+            opt_flags=opt_flags,
+            B=batch_size,
+            K=K,
+            N=N,
+            mx_scale_stride_k=mx_scale_stride_k,
+            mx_scale_stride_n=mx_scale_stride_n,
+            HAS_GATHER=gather_indx is not None,
+        )
+        w_tensor, w_tma_transpose = w_tensor_and_transpose
+        mx_tensor, mx_tma_transpose = mx_tensor_and_tranpose
+    else:
+        x_tensor = x
+        w_tensor, w_tma_transpose = w, False
+        mx_tensor, mx_tma_transpose = mx_ctx.weight_scale, False
     if isinstance(x_tensor, torch.Tensor):
         x_tensor = flex.lhs_data.reinterpret(x)
     if isinstance(w_tensor, torch.Tensor):
         w_tensor = flex.rhs_data.reinterpret(w)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
         flex.out_data.reinterpret(memory["output"]),
-        flex.out_data.reinterpret(out0), *out0.stride(),
-        *out0_flex,
+        flex.out_data.reinterpret(out0), *out0.stride(), *out0_flex,
         x_tensor, x.stride(0), x.stride(1), x.stride(2),
         flex.lhs_data.scale,
-        w_tensor, w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
+        w_tensor, w.stride(0), w.stride(1), w.stride(2), w_tma_transpose,
         flex.rhs_data.scale,
-        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
+        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_tma_transpose,
         bias, bias_stride,
-        x.shape[1],
         x.shape[1] if routing_data.expt_hist is None else None,
         N, K,
         betas, gammas,
```

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -39,7 +39,7 @@ def _matmul_ogs(
     WScale,
     MxScale, stride_mx_e, stride_mx_k, stride_mx_n, MX_TRANSPOSE: tl.constexpr,
     B, stride_b_e, # Bias
-    NRows, M, N, K, # shapes
+    M, N, K, # shapes
     # expt data
     Betas, Gammas,
     GatherIndx,
```
