Commit a21bbbc
[MXFP] mxfp conversions speedup (#8610)
This PR improves the throughput of the mxfp8 and mxfp4 upcast and downcast operations. I included a commit from @jongsoo-openai (original PR [here](triton-lang/triton#8179)) and added the improvements below on top of it. The PR is functionally a no-op, which is verified by the tests in ``python/triton_kernels/tests/test_mxfp.py``.

Upcast improvements:
- Added native packed e2m1 conversion to fp16 (for Blackwell+).
- Added tensor descriptors to utilize TMA for reading the input mxfp value tensor and writing the output.
  - Note that this required padding the innermost dimension of IO tensors that do not meet the tensor descriptor specification requirements, and unpadding the output afterwards.
- Tuned tile dimensions and num_warps.

Downcast improvements:
- Enabled vectorized stores of mxfp4 value tensors (h/t to @ThomasRaoux), as opposed to byte-level stores.
- Tuned the tile dimensions as well as num_warps.
- Unlike for upcast, tensor descriptors did not give a consistent performance improvement; I left that performance tuning as a TODO for a subsequent PR.

### Performance comparison (bandwidth in GB/s)

Measured via ``python/triton_kernels/tests/test_mxfp.py``.

**Before -- GB200**
```
MXFP8 (e4m3fn):
   M     N  quant_dtype            quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.float8_e4m3fn              1985.94             2053.35                2154.61               2347.56
4096  8192  torch.float8_e4m3fn              3479.79             3518.71                3243.02               3753.85

MXFP4 (e2m1):
   M     N  quant_dtype      quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.uint8                808.089             815.124                647.589                 713.9
4096  8192  torch.uint8               1045.23             1041.91                811.089                888.624
```

**After -- GB200**
```
MXFP8 (e4m3fn):
   M     N  quant_dtype            quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.float8_e4m3fn              2259.86             2404.99                2119.76               2361.66
4096  8192  torch.float8_e4m3fn              4106.69             4268.29                4038.16               4059

MXFP4 (e2m1):
   M     N  quant_dtype      quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.uint8               1334.75             1332.03                1424.7                1397.36
4096  8192  torch.uint8               2027.41             2028.98                2097.15               2275.56
```

**Before -- H100**
```
MXFP8 (e4m3fn):
   M     N  quant_dtype            quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.float8_e4m3fn              1250.29             1244.35                1595.2                1588.75
4096  8192  torch.float8_e4m3fn              1805.81             1799.62                2080.51               2118.34

MXFP4 (e2m1):
   M     N  quant_dtype      quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.uint8               418.493             416.102                572.367                627.739
4096  8192  torch.uint8               489.531             490.08                 687.861                758.08
```

**After -- H100**
```
MXFP8 (e4m3fn):
   M     N  quant_dtype            quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.float8_e4m3fn              1604.96             1624.86                1732.23               1751.52
4096  8192  torch.float8_e4m3fn              2347.56             2337.09                2386.74               2292.8

MXFP4 (e2m1):
   M     N  quant_dtype      quant_bw_bfloat16    quant_bw_float16    dequant_bw_bfloat16    dequant_bw_float16
----  ----  -------------  -------------------  ------------------  ---------------------  --------------------
1024  8192  torch.uint8               731.429             745.575                892.861                917.871
4096  8192  torch.uint8               882.343             894.995               1102.37                1165.08
```

Co-authored-by: jongsoo-openai <[email protected]>
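As background for the padding note above, here is a minimal host-side sketch of the pad/unpad flow (illustrative only, not the code in this PR; the helper name and the 16-element multiple are made up for the example, while the real alignment requirements come from the tensor descriptor specification and the MXFP block size):

```python
import torch
import torch.nn.functional as F

def pad_last_dim_to_multiple(x: torch.Tensor, multiple: int) -> tuple[torch.Tensor, int]:
    """Zero-pad the innermost dimension of `x` up to a multiple of `multiple`
    and return the padded tensor plus the original length, so the result can
    be sliced back after the kernel runs."""
    length = x.shape[-1]
    padded = (length + multiple - 1) // multiple * multiple
    if padded != length:
        x = F.pad(x, (0, padded - length))
    return x, length

x = torch.randn(4, 100, dtype=torch.bfloat16)
x_padded, orig_len = pad_last_dim_to_multiple(x, 16)  # hypothetical 16-element alignment
out_padded = x_padded.clone()                          # stand-in for the kernel's padded output
out = out_padded[..., :orig_len]                       # unpad afterwards
assert out.shape == x.shape
```

The extra work is a single pad of the input and a slice of the output, which is what makes the TMA path worthwhile despite tensors whose innermost dimension is not naturally aligned.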
1 parent c33b2d9 commit a21bbbc

File tree: 4 files changed, +218 −80 lines changed


python/triton_kernels/tests/test_mxfp.py

Lines changed: 39 additions & 14 deletions
```diff
@@ -1,3 +1,4 @@
+import itertools
 from functools import partial
 
 import pytest
@@ -23,19 +24,25 @@ def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
 @pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])
 def test_mxfp4_rounding_cases(dst_dtype, device):
     dst_dtype = dtype_str_to_torch(dst_dtype)
-    x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3, 1.25, -1.25]).to(device).bfloat16().view(1, -1, 1)
+    two_point_five_plus_ulp = {
+        torch.bfloat16: 0.251953125,
+        torch.float16: 0.250244140625,
+        torch.float32: 0.2500000298023223877,
+    }[dst_dtype]
+    # Construct an example where scale is 1 (when max value is 6.0, the maximum value of e2m1)
+    x = torch.tensor([6, 0, 0.24, 0.25, 0.75, 0.99, 1.2, 1.3, -1.25, two_point_five_plus_ulp], dtype=dst_dtype,
+                     device=device).view(1, -1, 1)
     quant, scale = downcast_to_mxfp(x, torch.uint8, axis=1)
     dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
     # Tie-breaking cases (RTNE):
     # - 0.25 is exactly halfway between 0.0 and 0.5. RTNE selects the even quantized value 0.0
     #   (binary LSB of target is 0). Rounding away from zero would pick 0.5; towards zero also picks 0.0.
     # - 0.75 is halfway between 0.5 and 1.0. RTNE selects the even value 1.0 (LSB 0). Away-from-zero would pick 1.0;
     #   towards-zero would pick 0.5.
-    # - 1.25 is halfway between 1.0 and 1.5. RTNE selects the even value 1.0. Away-from-zero would pick 1.5;
-    #   towards-zero would pick 1.0.
     # - -1.25 is halfway between -1.0 and -1.5. RTNE selects -1.0 (even). Away-from-zero would pick -1.5;
     #   towards-zero would pick -1.0.
-    assert dequant.flatten().tolist() == [6, 0, 0, 0.0, 1.0, 1.0, 1.0, 1.5, 1.0, -1.0], f"{dequant=}"
+    # - two_point_five_plus_ulp is slightly bigger than 0.25, so it rounds to 0.5.
+    assert dequant.flatten().tolist() == [6, 0, 0, 0.0, 1.0, 1.0, 1.0, 1.5, -1.0, 0.5], f"{dequant=}"
 
     quant_torch, scale_torch = downcast_to_mxfp_torch(x, torch.uint8, axis=1)
     assert_equal(quant_torch, quant)
@@ -153,6 +160,7 @@ def test_mxfp_casting(
 ):
     if "float8" in quant_dtype and (is_cuda() and torch.cuda.get_device_capability()[0] < 9):
         pytest.skip("Float8 not tested on A100")
+    torch.manual_seed(0)
     quant_torch_type = dtype_str_to_torch(quant_dtype)
     dequant_torch_type = dtype_str_to_torch(dequant_dtype)
     # Generate random input tensor that is contiguous once axis is the last dimension
@@ -220,15 +228,32 @@ def _benchmark_mxfp_dequantization(shape, src_quant_dtype: torch.dtype, target_d
     ]
 
     table = []
-    for shape, dtype in tests:
-        mxfp8_q_bw = _benchmark_mxfp_quantization(shape, dtype, torch.float8_e4m3fn)
-        mxfp8_dq_bw = _benchmark_mxfp_dequantization(shape, torch.float8_e4m3fn, dtype)
-        mxfp4_q_bw = _benchmark_mxfp_quantization(shape, dtype, torch.uint8)
-        mxfp4_dq_bw = _benchmark_mxfp_dequantization(shape, torch.uint8, dtype)
-        table.append(shape + (dtype, mxfp8_q_bw, mxfp8_dq_bw, mxfp4_q_bw, mxfp4_dq_bw))
+    shapes = [(1024, 8192), (4096, 8192)]
+    source_dtypes = [torch.bfloat16, torch.float16]
+    for shape, quant_dtype in itertools.product(shapes, [torch.float8_e4m3fn, torch.uint8]):
+        results = [*shape, quant_dtype]
+        for src_dtype in source_dtypes:
+            results.append(_benchmark_mxfp_quantization(shape, src_dtype, quant_dtype))
+        for src_dtype in source_dtypes:
+            results.append(_benchmark_mxfp_dequantization(shape, quant_dtype, src_dtype))
+        table.append(results)
 
     from tabulate import tabulate
-    print(
-        tabulate(
-            table,
-            headers=["M", "N", "dtype", "mxfp8_quant_bw", "mxfp8_dequant_bw", "mxfp4_quant_bw", "mxfp4_dequant_bw"]))
+
+    headers = [
+        "M",
+        "N",
+        "quant_dtype",
+        "quant_bw_bfloat16",
+        "quant_bw_float16",
+        "dequant_bw_bfloat16",
+        "dequant_bw_float16",
+    ]
+    mxfp8_rows = [row for row in table if row[2] == torch.float8_e4m3fn]
+    mxfp4_rows = [row for row in table if row[2] == torch.uint8]
+
+    print("MXFP8 (e4m3fn):")
+    print(tabulate(mxfp8_rows, headers=headers))
+    print()
+    print("MXFP4 (e2m1):")
+    print(tabulate(mxfp4_rows, headers=headers))
```
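For reference, the `two_point_five_plus_ulp` constants introduced in `test_mxfp4_rounding_cases` are 0.25 plus one unit in the last place of the destination dtype. A quick sketch (not part of the PR) reproduces them from each dtype's machine epsilon:

```python
import torch

# 0.25 is a power of two, so its upward ULP is 0.25 * eps, where eps is the spacing at 1.0.
for dtype in (torch.bfloat16, torch.float16, torch.float32):
    ulp_at_quarter = 0.25 * torch.finfo(dtype).eps
    print(dtype, 0.25 + ulp_at_quarter)
# bfloat16 -> 0.251953125, float16 -> 0.250244140625, float32 -> ~0.25000003
```

These are the smallest representable values above 0.25 in each dtype, which is why they must round up to 0.5 under RTNE while 0.25 itself is a tie that rounds to 0.0.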

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 79 additions & 26 deletions
```diff
@@ -7,6 +7,7 @@
 import torch.nn.functional as F
 from .mxfp_details._upcast_from_mxfp import _upcast_from_mxfp
 from .mxfp_details._downcast_to_mxfp import _downcast_to_mxfp, MXFP_BLOCK_SIZE, _quantize_mxfp8_fn
+from triton.tools.tensor_descriptor import TensorDescriptor
 
 # -----------------------------------------------------------------------------
 # Dequantization / Quantization Utilities
@@ -20,7 +21,6 @@ class DequantScaleRoundingMode(Enum):
     # chance of clipping the max value.
     ROUND_DOWN = 1
 
-
 def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis: int,
                      DEQUANT_SCALE_ROUNDING_MODE: DequantScaleRoundingMode = DequantScaleRoundingMode.ROUND_UP):
     """
@@ -44,26 +44,40 @@ def downcast_to_mxfp(src_tensor: torch.Tensor, out_quant_type: torch.dtype, axis
     L = src_tensor.shape[-1]
     if is_fp4:
         assert L % 2 == 0, f"axis dim must be divisible by 2 for e2m1. Got {L}"
-    out_shape = src_tensor.shape[:-1] + (L // divisor, )
+    # Ensure last dimension is a multiple of MXFP_BLOCK_SIZE. This is expected by the kernel.
+    padded_L = triton.cdiv(L, MXFP_BLOCK_SIZE.value) * MXFP_BLOCK_SIZE.value
+    needs_padding = padded_L != L
+    out_shape_padded = src_tensor.shape[:-1] + (padded_L // divisor, )
     out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, MXFP_BLOCK_SIZE), )
 
-    out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type)
+    out_quant_tensor = src_tensor.new_empty(out_shape_padded, dtype=out_quant_type)
     out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8)
 
     if src_tensor.numel() > 0:
-        kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1])
+        src_tensor_padded = F.pad(src_tensor, (0, padded_L - L)) if needs_padding else src_tensor
+        kernel_src_tensor = src_tensor_padded.reshape(-1, src_tensor_padded.shape[-1])
         kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1])
         kernel_scale = out_scale.view(-1, out_scale.shape[-1])
 
-        BLOCK_OUT_DIM = 128
-        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
-        grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
-        grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
-
-        _downcast_to_mxfp[(grid_out, grid_quant)](kernel_quant_tensor, *kernel_quant_tensor.stride(), kernel_scale,
-                                                  *kernel_scale.stride(), kernel_src_tensor, *kernel_src_tensor.stride(),
-                                                  *kernel_src_tensor.shape, BLOCK_OUT_DIM, BLOCK_QUANT_DIM,
-                                                  DEQUANT_SCALE_ROUNDING_MODE.value, num_warps=8)
+        # performance hyper-parameters
+        BLOCK_OUT_DIM = 32
+        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value * 4
+        NUM_WARPS = 4 if src_tensor.dtype == torch.float32 else 8
+
+        blocks_out_dim = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM)
+        blocks_quant_dim = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM)
+        _downcast_to_mxfp[(blocks_out_dim, blocks_quant_dim)](
+            kernel_quant_tensor, *kernel_quant_tensor.stride(),
+            kernel_scale, *kernel_scale.stride(),
+            kernel_src_tensor, *kernel_src_tensor.stride(), *kernel_src_tensor.shape,
+            BLOCK_OUT_DIM,
+            BLOCK_QUANT_DIM,
+            DEQUANT_SCALE_ROUNDING_MODE.value,
+            num_warps=NUM_WARPS,
+        )
+
+    if needs_padding:
+        out_quant_tensor = out_quant_tensor[..., : (L // divisor)]
 
     out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1)
     out_scale = out_scale.transpose(axis, src_tensor.ndim - 1)
@@ -89,23 +103,56 @@ def upcast_from_mxfp(tensor: torch.Tensor, scale: torch.Tensor, target_dtype: to
     assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}"
     assert target_dtype in (torch.float16, torch.bfloat16, torch.float32), f"Invalid output dtype {target_dtype=}"
     # upcast
-    logical_quant_dim = tensor.shape[axis] * (2 if tensor.dtype == torch.uint8 else 1)
+    pack_multiple = 2 if tensor.dtype == torch.uint8 else 1
+    logical_quant_dim = tensor.shape[axis] * pack_multiple
     tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous()
     scale = scale.transpose(axis, scale.ndim - 1).contiguous()
-    out = torch.empty((*tensor.shape[:-1], logical_quant_dim), dtype=target_dtype, device=tensor.device)
+    original_out_shape = tensor.shape[:-1] + (logical_quant_dim, )
 
     if tensor.numel() > 0:
-        reshaped_out = out.view(-1, out.shape[-1])
         reshaped_tensor = tensor.view(-1, tensor.shape[-1])
         reshaped_scale = scale.view(-1, scale.shape[-1])
-        BLOCK_OUT_DIM = 128
-        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value
+
+        # Pad the tensor and output if needed for tensor descriptor spec requirements.
+        TENSOR_DESC_PAD_REQ = 16
+        needs_padding = reshaped_tensor.shape[-1] % TENSOR_DESC_PAD_REQ != 0
+        if needs_padding:
+            tensor_pad_amount = TENSOR_DESC_PAD_REQ - (reshaped_tensor.shape[-1] % TENSOR_DESC_PAD_REQ)
+            reshaped_tensor = F.pad(reshaped_tensor, (0, tensor_pad_amount), "constant", 0)
+            pad_elems_count = tensor_pad_amount * pack_multiple
+            out_shape = original_out_shape[:-1] + (original_out_shape[-1] + pad_elems_count, )
+        else:
+            out_shape = original_out_shape
+        out = torch.empty(out_shape, dtype=target_dtype, device=tensor.device)
+        reshaped_out = out.view(-1, out.shape[-1])
+
+        is_fp4 = reshaped_tensor.dtype == torch.uint8
+
+        # performance hyper-parameters
+        BLOCK_OUT_DIM = 64
+        BLOCK_QUANT_DIM = MXFP_BLOCK_SIZE.value * 4
+        NUM_WARPS = 4
+
         blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM)
         blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM)
-        _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](reshaped_out, *reshaped_out.stride(), reshaped_scale,
-                                                              *reshaped_scale.stride(), reshaped_tensor,
-                                                              *reshaped_tensor.stride(), *reshaped_out.shape, BLOCK_OUT_DIM,
-                                                              BLOCK_QUANT_DIM, num_warps=8)
+        k_divisor = 2 if is_fp4 else 1
+        block_size_quant_mx_tensor = BLOCK_QUANT_DIM // k_divisor
+        out_desc = TensorDescriptor.from_tensor(reshaped_out, [BLOCK_OUT_DIM, BLOCK_QUANT_DIM])
+        tensor_desc = TensorDescriptor.from_tensor(reshaped_tensor, [BLOCK_OUT_DIM, block_size_quant_mx_tensor])
+        _upcast_from_mxfp[(blocks_out_dim, blocks_quant_dim)](
+            out_desc,
+            tensor_desc,
+            reshaped_scale,
+            *reshaped_scale.stride(),
+            *reshaped_out.shape,
+            BLOCK_OUT_DIM,
+            BLOCK_QUANT_DIM,
+            num_warps=NUM_WARPS,
+        )
+        if needs_padding:
+            out = out[..., :original_out_shape[-1]]
+    else:
+        out = torch.empty(original_out_shape, dtype=target_dtype, device=tensor.device)
     out = out.transpose(axis, scale.ndim - 1).contiguous()
     return out
 
@@ -218,19 +265,25 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype
     # Extract sign, exponent, and mantissa.
     signs = q_int & 0x80000000
     exponents = right_shift_unsigned(q_int, 23) & 0xFF
-    mantissas = q_int & 0x7FFFFF
+    mantissas_orig = q_int & 0x7FFFFF
 
     E8_BIAS = 127
     E2_BIAS = 1
     # Adjust mantissas for subnormals.
-    mantissas = torch.where(exponents < E8_BIAS, (0x400000 | right_shift_unsigned(mantissas, 1)) >>
-                            (E8_BIAS - exponents - 1), mantissas)
+    is_subnormal = exponents < E8_BIAS
+    shift = E8_BIAS - exponents - 1
+    mantissas_pre = (0x400000 | right_shift_unsigned(mantissas_orig, 1))
+    bit0_dropped = (mantissas_orig & 0x1) != 0
+    mask = (1 << shift.clamp(max=31)) - 1
+    dropped_post = (mantissas_pre & mask) != 0
+    sticky = is_subnormal & (bit0_dropped | dropped_post)
+    mantissas = torch.where(is_subnormal, mantissas_pre >> shift, mantissas_orig)
     exponents = torch.maximum(exponents, torch.tensor(E8_BIAS - E2_BIAS, device=device)) - (E8_BIAS - E2_BIAS)
     # Round to nearest, ties to even (RTNE)
     m2bits = right_shift_unsigned(mantissas, 21) & 0x3
     lsb_keep = right_shift_unsigned(m2bits, 1) & 0x1
     guard = m2bits & 0x1
-    sticky = (mantissas & ((1 << 21) - 1)) != 0
+    sticky |= (mantissas & ((1 << 21) - 1)) != 0
    round_inc = guard & (sticky.to(torch.int32) | lsb_keep)
     e2m1_tmp = right_shift_unsigned(((exponents << 2) | m2bits) + round_inc, 1)
     e2m1_tmp = torch.minimum(e2m1_tmp, torch.tensor(0x7, device=device))
```
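The sticky-bit change in `downcast_to_mxfp_torch` (and its Triton counterpart below) exists because bits shifted out while denormalizing the mantissa must still participate in the round-to-nearest-even decision. A standalone scalar sketch of that guard/sticky logic (illustrative only, not the library code):

```python
# Minimal RTNE with guard/sticky bits: drop `shift` low bits of an integer
# mantissa, rounding to nearest, ties to even.
def rtne_shift(mantissa: int, shift: int) -> int:
    if shift == 0:
        return mantissa
    kept = mantissa >> shift
    guard = (mantissa >> (shift - 1)) & 1                 # first dropped bit
    sticky = (mantissa & ((1 << (shift - 1)) - 1)) != 0   # any other dropped bit
    lsb_keep = kept & 1
    round_inc = guard & (int(sticky) | lsb_keep)
    return kept + round_inc

# Exact tie: guard=1, sticky=0, kept LSB even -> stays at 0b10 (ties to even).
assert rtne_shift(0b1010_00, 4) == 0b10
# One bit above the tie: sticky=1 forces rounding up to 0b11.
assert rtne_shift(0b1010_01, 4) == 0b11
```

Without OR-ing the dropped bits into `sticky`, the second case would be treated as an exact tie and round down; that is the kind of fp32 misrounding (e.g. 0.25 plus one ULP landing on 0.0 instead of 0.5) that the new handling guards against.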

python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py

Lines changed: 46 additions & 11 deletions
```diff
@@ -1,5 +1,6 @@
 import triton
 import triton.language as tl
+from triton_kernels.target_info import cuda_capability_geq
 
 # fmt: off
 
@@ -72,18 +73,42 @@ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.con
     # Now we must convert the tensors to the mx format.
     if is_fp8:
         out_tensor = quant_tensor.to(mx_tensor_dtype)
+    elif cuda_capability_geq(10, 0):
+        # Convert scaled values to two f32 lanes and use PTX cvt to e2m1x2 with two f32 operands.
+        pairs = tl.reshape(quant_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM // 2, 2])
+        lo_f, hi_f = tl.split(pairs)
+        lo_f32 = lo_f.to(tl.float32)
+        hi_f32 = hi_f.to(tl.float32)
+
+        # Inline PTX: cvt.rn.satfinite.e2m1x2.f32 takes two f32 sources and produces one .b8 packed e2m1x2.
+        out_tensor = tl.inline_asm_elementwise(
+            """
+            {
+            .reg .b8 r;
+            cvt.rn.satfinite.e2m1x2.f32 r, $1, $2;
+            mov.b32 $0, {r, r, r, r};
+            }
+            """,
+            constraints="=r,f,f",
+            args=[hi_f32, lo_f32],
+            dtype=tl.uint8,
+            is_pure=True,
+            pack=1,
+        )
     else:
         quant_tensor = quant_tensor.to(tl.uint32, bitcast=True)
         signs = quant_tensor & 0x80000000
         exponents = (quant_tensor >> 23) & 0xFF
-        mantissas = (quant_tensor & 0x7FFFFF)
+        mantissas_orig = (quant_tensor & 0x7FFFFF)
 
         # For RTNE: 0.25 < x < 0.75 maps to 0.5 (denormal); exactly 0.25 maps to 0.0
         E8_BIAS = 127
         E2_BIAS = 1
         # Move implicit bit 1 at the beginning to mantissa for denormals
+        is_subnormal = exponents < E8_BIAS
         adjusted_exponents = tl.core.sub(E8_BIAS, exponents + 1, sanitize_overflow=False)
-        mantissas = tl.where(exponents < E8_BIAS, (0x400000 | (mantissas >> 1)) >> adjusted_exponents, mantissas)
+        mantissas_pre = (0x400000 | (mantissas_orig >> 1))
+        mantissas = tl.where(is_subnormal, mantissas_pre >> adjusted_exponents, mantissas_orig)
 
         # For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0.
         exponents = tl.maximum(exponents, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
@@ -93,7 +118,15 @@ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.con
         m2bits = mantissas >> 21
         lsb_keep = (m2bits >> 1) & 0x1
         guard = m2bits & 0x1
-        sticky = ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
+        IS_SRC_FP32: tl.constexpr = src_tensor.dtype == tl.float32
+        if IS_SRC_FP32:
+            bit0_dropped = (mantissas_orig & 0x1) != 0
+            mask = (1 << tl.minimum(adjusted_exponents, 31)) - 1
+            dropped_post = (mantissas_pre & mask) != 0
+            sticky = is_subnormal & (bit0_dropped | dropped_post)
+            sticky |= ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
+        else:
+            sticky = ((mantissas & 0x1FFFFF) != 0).to(tl.uint32)
         round_inc = guard & (sticky | lsb_keep)
         e2m1_tmp = tl.minimum((((exponents << 2) | m2bits) + round_inc) >> 1, 0x7)
         e2m1_value = ((signs >> 28) | e2m1_tmp).to(tl.uint8)
@@ -105,12 +138,14 @@ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.con
     return out_tensor, dequant_scale_exponent
 
 @triton.jit
-def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
-                      mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
-                      src_ptr, stride_src_outer, stride_src_quant,
-                      outer_dim, quant_dim,
-                      BLOCK_SIZE_OUT_DIM: tl.constexpr, BLOCK_SIZE_QUANT_DIM: tl.constexpr,
-                      DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr):
+def _downcast_to_mxfp(
+    mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.constexpr,
+    mx_scale_ptr, stride_mx_scale_outer, stride_mx_scale_quant,
+    src_ptr, stride_src_outer, stride_src_quant, outer_dim, quant_dim,
+    BLOCK_SIZE_OUT_DIM: tl.constexpr,
+    BLOCK_SIZE_QUANT_DIM: tl.constexpr,
+    DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr,
+):
 
     tl.static_assert(stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1.")
     tl.static_assert(BLOCK_SIZE_QUANT_DIM % MXFP_BLOCK_SIZE == 0, f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32")
@@ -150,10 +185,10 @@ def _downcast_to_mxfp(mx_tensor_ptr, stride_mxt_outer, stride_mxt_quant: tl.cons
     mask_n = start_out + offs_outer < outer_dim
     full_mask_src = mask_src_quant & mask_n
 
-    mask_mxt_quant = start_mx_quant + offs_mxt_quant < tl.cdiv(quant_dim, K_DIVISOR)
+    mask_mxt_quant = start_mx_quant + offs_mxt_quant < quant_dim // K_DIVISOR  # requires quant_dim % K_DIVISOR == 0
     full_mask_mxt = mask_mxt_quant & mask_n
 
-    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, MXFP_BLOCK_SIZE)
+    scale_mask_k = start_mx_scale_quant + offs_scale_quant < quant_dim // MXFP_BLOCK_SIZE  # requires quant_dim % MXFP_BLOCK_SIZE == 0
    full_scale_mask = scale_mask_k & mask_n
 
     src_tensor_offsets = offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
```
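As context for the packed `e2m1x2` bytes that both the PTX fast path and the bit-twiddling fallback produce: mxfp4 stores two e2m1 values per uint8. A small torch-side sketch of unpacking such bytes via the 16 e2m1 code points (illustrative only, not the library's API; it assumes the low nibble holds the first element of each pair, which may differ from the layout the kernels use):

```python
import torch

# The 16 e2m1 code points: codes 0-7 are +{0, 0.5, 1, 1.5, 2, 3, 4, 6},
# codes 8-15 are their negatives (sign bit set).
E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def unpack_e2m1x2(packed: torch.Tensor) -> torch.Tensor:
    """Unpack a uint8 tensor holding two e2m1 values per byte into float32,
    assuming the low nibble is the first element of each pair."""
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    vals = torch.stack([E2M1_VALUES[lo.long()], E2M1_VALUES[hi.long()]], dim=-1)
    return vals.flatten(start_dim=-2)  # interleave low-nibble and high-nibble values

# Example: byte 0x21 packs codes 1 (low nibble) and 2 (high nibble) -> [0.5, 1.0].
print(unpack_e2m1x2(torch.tensor([0x21], dtype=torch.uint8)))
```

This is also why the vectorized mxfp4 stores mentioned in the description pay off: each byte already carries two values, so writing wider words amortizes the store traffic further.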
