Commit 984b694

[mxfp/easy] add MXFP_BLOCK_SIZE constant (#7567)
Replace the hard-coded 32 with a named constant for better readability.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because it is covered by the existing test `python/triton_kernels/tests/test_mxfp.py`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 67f647a commit 984b694

File tree

6 files changed, +39 -23 lines changed


python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 2 additions & 1 deletion

@@ -4,6 +4,7 @@
 from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
 from triton_kernels.tensor_details.layout_details.hopper_value import unswizzle_mxfp4_value_hopper
 from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale
+from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string

 # fmt: off
@@ -75,7 +76,7 @@ def _matmul_ogs(

     Y = Out  # Y is passed for the purposes of annotation; replace it with Out
     is_microscaled_format: tl.constexpr = MxScale is not None
-    MX_PACK_DIVISOR: tl.constexpr = 32
+    MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
     if is_microscaled_format:
         w_type: tl.constexpr = W.dtype.element_ty
         is_mxfp4: tl.constexpr = w_type == tl.uint8
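With `MX_PACK_DIVISOR` now tied to the shared constant, the K extent of the MX scale tensor (stored as `torch.uint8` elsewhere in this PR) follows directly from it: one 8-bit scale per 32 values along K. A minimal sketch of that arithmetic in plain Python; `mx_scale_k` is a hypothetical helper, not part of this diff:

```python
import triton

MXFP_BLOCK_SIZE = 32  # plain-int stand-in for the tl.constexpr added in this PR

def mx_scale_k(k: int) -> int:
    # One 8-bit scale is shared by MXFP_BLOCK_SIZE consecutive values along K,
    # so the scale tensor needs ceil(k / 32) entries per row.
    return triton.cdiv(k, MXFP_BLOCK_SIZE)

assert mx_scale_k(4096) == 128
assert mx_scale_k(100) == 4  # a ragged tail still gets its own scale
```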

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 8 additions & 2 deletions

@@ -3,7 +3,13 @@
 import triton.language as tl
 from triton_kernels import target_info
 from triton_kernels.tensor_details.layout_details.blackwell_scale import unswizzle_mx_scale_bw
-from triton_kernels.numerics_details.flexpoint import float_to_flex, load_scale, nan_propagating_absmax_reduce, compute_scale
+from triton_kernels.numerics_details.flexpoint import (
+    float_to_flex,
+    load_scale,
+    nan_propagating_absmax_reduce,
+    compute_scale,
+)
+from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
 from ._common import make_matmul_repr, matmul_launch_metadata, swizzle2d, xcd_swizzle, get_scaled_dot_format_string

 # fmt: off
@@ -147,7 +153,7 @@ def _p_matmul_ogs(
     Y = Out  # Y is passed for the purposes of annotation; replace it with Out

     is_microscaled_format: tl.constexpr = MxScale is not None
-    MX_PACK_DIVISOR: tl.constexpr = 32
+    MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
     if is_microscaled_format:
         w_type: tl.constexpr = get_dtype(W)
         tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
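The pattern used in both matmul kernels, a module-level `tl.constexpr` read inside `@triton.jit` code as if it were a literal, works because Triton treats constexpr globals as compile-time constants. A minimal, self-contained sketch of the same pattern on a GPU; the kernel and all names are illustrative, not from this PR:

```python
import torch
import triton
import triton.language as tl

GROUP = tl.constexpr(32)  # module-level constant, like MXFP_BLOCK_SIZE

@triton.jit
def _group_absmax(x_ptr, out_ptr, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    x = tl.abs(tl.load(x_ptr + offs))
    # GROUP is usable anywhere a compile-time constant is expected.
    x = tl.reshape(x, [BLOCK // GROUP, GROUP])
    out_offs = pid * (BLOCK // GROUP) + tl.arange(0, BLOCK // GROUP)
    tl.store(out_ptr + out_offs, tl.max(x, axis=1))

x = torch.randn(1024, device="cuda")
out = torch.empty(1024 // 32, device="cuda")
_group_absmax[(1024 // 128,)](x, out, BLOCK=128)  # one amax per 32 values
```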

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 3 additions & 2 deletions

@@ -2,6 +2,7 @@
 import triton
 from triton_kernels import target_info
 from triton_kernels.tensor import bitwidth, FP4
+from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE


 def compute_grid_size(routing_data, m, n, block_m, block_n):
@@ -97,9 +98,9 @@ def compute_num_stages(
         smem_capacity -= int((block_m + 4) * acc_block_n * acc_size)
     if precision_config.weight_scale is not None:
         # mx scales
-        stage_size += block_n * (block_k // 32)
+        stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
     elif has_native_mxfp:
         # mx scales
-        stage_size += block_n * (block_k // 32)
+        stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
     num_stages = min(4, smem_capacity // int(stage_size))
     return num_stages
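Note the `int(...)` cast: this file runs as ordinary host-side Python, where arithmetic on a `tl.constexpr` can yield another `constexpr` rather than a plain `int`, so the value is unwrapped before it feeds integer division. A rough sketch of the shared-memory accounting with made-up tile sizes and capacity (none of these numbers are from the PR):

```python
import triton.language as tl

MXFP_BLOCK_SIZE = tl.constexpr(32)

block_n, block_k = 128, 128
stage_size = block_n * block_k                             # pretend per-stage tile bytes
stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))  # + 512 bytes of MX scales
smem_capacity = 228 * 1024                                 # e.g. roughly Hopper-class SMEM
num_stages = min(4, smem_capacity // stage_size)           # -> 4 here
```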

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 10 additions & 8 deletions

@@ -1,9 +1,11 @@
+# isort: off
+# fmt: off
 from enum import Enum
 import triton
 import torch
 import torch.nn.functional as F
 from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp
-from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp
+from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, MXFP_BLOCK_SIZE

 # -----------------------------------------------------------------------------
 # Dequantization / Quantization Utilities
@@ -39,7 +41,7 @@ def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis
     if is_fp4:
         assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
     out_shape = src_tensor.shape[:-1] + (L // divisor, )
-    out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, 32), )
+    out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )

     out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
     out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
@@ -49,7 +51,7 @@ def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis
     kernel_scale = out_scale.view(-1, out_scale.shape[-1])

     BLOCK_OUT_DIM = 128
-    BLOCK_QUANT_DIM = 32
+    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE
     grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
     grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)

@@ -90,7 +92,7 @@ def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dty
     reshaped_tensor = tensor.view(-1, tensor.shape[-1])
     reshaped_scale = scale.view(-1, scale.shape[-1])
     BLOCK_OUT_DIM = 128
-    BLOCK_QUANT_DIM = 32
+    BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE
     blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
     blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
     _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
@@ -153,7 +155,7 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype
     axis_shape = src.shape[-1]

     # Pad the axis to be divisible by 32, in case it is not.
-    next_multiple = (axis_shape + 31) // 32 * 32
+    next_multiple = triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
     pad_amount = next_multiple - axis_shape
     padded_src = F.pad(src, (0, pad_amount))
     valid_mask = F.pad(torch.ones_like(src, dtype=torch.bool), (0, pad_amount))
@@ -164,7 +166,7 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype
     abs_f = torch.abs(padded_src)
     abs_f = torch.where(valid_mask, abs_f, torch.tensor(-1.0, device=device, dtype=padded_src.dtype))
     # Reshape the last dimension into groups of 32.
-    new_shape = padded_src.shape[:-1] + (padded_axis_shape // 32, 32)
+    new_shape = padded_src.shape[:-1] + (padded_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
     abs_groups = abs_f.view(*new_shape)
     # Compute maximum along the group dimension (of size 32).
     max_val, _ = abs_groups.max(dim=-1, keepdim=True)
@@ -277,12 +279,12 @@ def upcast_from_mxfp_torch(tensor: torch.Tensor, scale: torch.Tensor, target_dty

     logical_quant_dim = tensor.shape[-1] * (2 if tensor.dtype == torch.uint8 else 1)
     axis_shape = fp32_tensor.size(-1)
-    padded_axis_shape = triton.cdiv(logical_quant_dim, 32) * 32
+    padded_axis_shape = triton.cdiv(logical_quant_dim, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE
     pad_size = padded_axis_shape - axis_shape
     padded_tensor = F.pad(fp32_tensor, (0, pad_size))

     new_axis_shape = padded_tensor.shape[-1]
-    new_shape = padded_tensor.shape[:-1] + (new_axis_shape // 32, 32)
+    new_shape = padded_tensor.shape[:-1] + (new_axis_shape // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
     padded_tensor = padded_tensor.view(*new_shape)
     dq_scale_padded = dq_scale.unsqueeze(-1)  # shape: [..., ceil(axis_shape/32), 1]
     out_padded = padded_tensor * dq_scale_padded
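`triton.cdiv(axis_shape, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE` computes the same round-up as the old `(axis_shape + 31) // 32 * 32`, just without a second hard-coded 32. A quick check of that equivalence, plus the group-of-32 amax pattern used above (shapes here are illustrative):

```python
import torch
import triton

MXFP_BLOCK_SIZE = 32

for n in (1, 31, 32, 33, 100):
    assert (n + 31) // 32 * 32 == triton.cdiv(n, MXFP_BLOCK_SIZE) * MXFP_BLOCK_SIZE

# Group-of-32 amax, as in downcast_to_mxfp_torch:
x = torch.randn(4, 96)
groups = x.abs().view(4, 96 // MXFP_BLOCK_SIZE, MXFP_BLOCK_SIZE)
max_val = groups.max(dim=-1, keepdim=True).values  # one max per 32-wide group
```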

python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py

Lines changed: 10 additions & 6 deletions

@@ -3,6 +3,10 @@

 # fmt: off

+
+MXFP_BLOCK_SIZE = tl.constexpr(32)
+
+
 @triton.jit
 def _get_max_quant_val(dtype: tl.constexpr):
     if dtype == tl.uint8:
@@ -20,13 +24,13 @@ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.con
     is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
     BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0]
     BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1]
-    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // 32
+    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // MXFP_BLOCK_SIZE

     # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16
     f32_tensor = src_tensor.to(tl.float32)
     abs_tensor = tl.abs(f32_tensor)
     abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation
-    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32])
+    abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
     max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
     dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
     if DEQUANT_SCALE_ROUNDING_MODE == 0:
@@ -44,7 +48,7 @@ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.con
     dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
     quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)

-    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32])
+    f32_tensor = tl.reshape(f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
     quant_tensor = f32_tensor * quant_scale

     # Reshape the tensors after scaling
@@ -94,7 +98,7 @@ def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.cons
                       DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):

     tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
-    tl.static_assert(BLOCK_SIZE_QUANT_DIM % 32 == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
+    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")

     # uint8 signifies two fp4 e2m1 values packed into a single byte
     mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
@@ -110,7 +114,7 @@ def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.cons
     quant_block = tl.program_id(1).to(tl.int64)

     K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
-    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32
+    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
     BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

     start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
@@ -134,7 +138,7 @@ def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.cons
     mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
     full_mask_mxt = mask_mxt_quant & mask_n

-    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, 32)
+    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
     full_scale_mask = scale_mask_k & mask_n

     src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
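The rounding branch visible above bitcasts an exponent back to fp32, i.e. the dequant scale is kept as a power of two (MX scales are stored as 8-bit exponents in a `torch.uint8` tensor). A torch sketch of one such rounding, truncating the fp32 mantissa so only exponent bits survive; the diff does not show exactly what each `DEQUANT_SCALE_ROUNDING_MODE` does, so treat this as illustrative:

```python
import torch

def round_scale_down_to_pow2(scale: torch.Tensor) -> torch.Tensor:
    # Keep only the 8 exponent bits of the fp32 value (sign and mantissa zeroed).
    bits = scale.view(torch.int32) & 0x7F800000
    return bits.view(torch.float32)

s = torch.tensor([0.75, 3.0, 100.0])
print(round_scale_down_to_pow2(s))  # tensor([ 0.5000,  2.0000, 64.0000])
```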

python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_upcast_from_mxfp.py

Lines changed: 6 additions & 4 deletions

@@ -1,14 +1,16 @@
 import triton
 import triton.language as tl
+from ._downcast_to_mxfp import MXFP_BLOCK_SIZE


+# fmt: off
 @triton.jit
 def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_scale_ptr, stride_scale_outer,
                       stride_scale_quant, mx_tensor_ptr, stride_tensor_outer, stride_tensor_quant: tl.constexpr,
                       outer_dim, quant_dim, BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr):

     tl.static_assert(stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx")
-    tl.static_assert(BLOCK_SIZE_QUANT_DIM % 32 == 0, "BLOCK_SIZE_K must be a multiple of 32")
+    tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, "BLOCK_SIZE_K must be a multiple of 32")
     # uint8 signifies two fp4 e2m1 values packed into a single byte
     mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
     dst_dtype: tl.constexpr = out_ptr.dtype.element_ty
@@ -23,7 +25,7 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
     is_fp4: tl.constexpr = mx_tensor_dtype == tl.uint8
     is_fp8: tl.constexpr = mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5
     K_DIVISOR: tl.constexpr = 2 if is_fp4 else 1
-    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32
+    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // MXFP_BLOCK_SIZE
     BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM // K_DIVISOR

     # Compute starting indices for the quantized (packed) dimension and the outer dimension.
@@ -52,7 +54,7 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
     mask_src_quant = start_mxt_quant + offs_src_quant < tl.cdiv(quant_dim, K_DIVISOR)
     full_mask_src = mask_src_quant & mask_outer

-    mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, 32)
+    mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
     full_scale_mask = mask_scale & mask_outer

     tensor_offsets = offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer
@@ -109,7 +111,7 @@ def _upcast_from_mxfp(out_ptr, stride_o_outer, stride_o_quant: tl.constexpr, mx_
     dst_tensor = tl.interleave(x0, x1).to(dst_dtype, bitcast=True)

     # Reshape for proper broadcasting: the scale was stored with a 32-sized "inner" grouping.
-    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32])
+    dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
     dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1])
     scale = scale.reshape(dst_scale.shape)

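The final reshape pairs each decoded 32-value group with its scale via broadcasting: values become `[out, k_blocks, 32]` and scales `[out, k_blocks, 1]`. The same pattern in plain torch, with illustrative shapes:

```python
import torch

MXFP_BLOCK_SIZE = 32
rows, k_blocks = 8, 4

vals = torch.randn(rows, k_blocks * MXFP_BLOCK_SIZE)     # decoded values
scales = torch.rand(rows, k_blocks)                      # one scale per 32 values
out = vals.view(rows, k_blocks, MXFP_BLOCK_SIZE) * scales.unsqueeze(-1)
out = out.view(rows, -1)                                 # back to [rows, K]
```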
