Commit f4e4188

fix: Triton MXFP4 mantissa rounding (ROCm#975)
* fix: MXFP4 mantissa rounding
* fix: mantissa rounding in test_quant_mxfp4
* refactor dynamic_mxfp4_quant
* chore: format
* fix: mxfp4 quantization tests
* chore: format
* fix: mxfp4 quantization test with correct bitwidth and sign
* chore: restore DEBUG_MODE
* chore: align test_quant_mxfp4 with triton kernel

---------

Co-authored-by: lucas-santos-amd <Lucas.Santos@amd.com>
1 parent 8e4d703 commit f4e4188
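
Both files replace the old exponent/mantissa bit surgery with the same three-path fp32 -> e2m1 cast: saturating values (>= 6.0, the e2m1 maximum) fall through to the 0x7 code, denormals (< 1.0) go through a magic-number add, and normal values are rounded in integer arithmetic. On the normal path, val_to_add folds the exponent re-bias together with the rounding bias (1 << 21) - 1, mant_odd adds the LSB of the surviving bits so that ties round to even, and the 22 dropped bits (MBITS_F32 - MBITS_FP4) are then shifted out. A minimal sketch of that rounding trick in plain Python (the helper name is illustrative, not from the diff):

def rne_shift(x: int, k: int) -> int:
    # Drop the low k bits of x, rounding to nearest with ties-to-even.
    # The kernel does this with k = 22 via "rounding bias part 1/2".
    odd = (x >> k) & 1        # LSB of the part that survives the shift
    x += (1 << (k - 1)) - 1   # part 1: just under half the dropped range
    x += odd                  # part 2: ties round toward an even result
    return x >> k

# Halfway cases break toward even (values in units of 2^-k, k = 2):
assert rne_shift(0b0110, 2) == 0b10  # 1.5 -> 2
assert rne_shift(0b0010, 2) == 0b00  # 0.5 -> 0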

2 files changed: +107 -36 lines

aiter/ops/triton/_triton_kernels/quant.py (50 additions, 17 deletions)

@@ -94,6 +94,16 @@ def _mxfp4_quant_op(
     x: [BLOCK_SIZE_M, BLOCK_SIZE_N], fp32

     """
+    EXP_BIAS_FP32: tl.constexpr = 127
+    EXP_BIAS_FP4: tl.constexpr = 1
+    EBITS_F32: tl.constexpr = 8
+    EBITS_FP4: tl.constexpr = 2
+    MBITS_F32: tl.constexpr = 23
+    MBITS_FP4: tl.constexpr = 1
+
+    max_normal: tl.constexpr = 6
+    min_normal: tl.constexpr = 1
+
     NUM_QUANT_BLOCKS: tl.constexpr = BLOCK_SIZE_N // MXFP4_QUANT_BLOCK_SIZE
     x = x.reshape(BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE)
     # Calculate scale
@@ -125,26 +135,49 @@ def _mxfp4_quant_op(
     # S111 -> +/- 6.0
     qx = qx.to(tl.uint32, bitcast=True)

-    # Extract sign, exponents and mantissa fields from FP32
+    # Extract sign
     s = qx & 0x80000000
-    e = (qx >> 23) & 0xFF
-    m = qx & 0x7FFFFF
-    E8_BIAS: tl.constexpr = 127
-    E2_BIAS: tl.constexpr = 1
+    # Set everything to positive, will add sign back at the end
+    qx = qx ^ s
+
+    qx_fp32 = qx.to(tl.float32, bitcast=True)
+    saturate_mask = qx_fp32 >= max_normal
+    denormal_mask = (not saturate_mask) & (qx_fp32 < min_normal)
+    normal_mask = not (saturate_mask | denormal_mask)

     # Denormal numbers
-    # If exponent is less than 127, then it's a denormal number
-    # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa
-    adjusted_exponents = tl.core.sub(E8_BIAS, e + 1, sanitize_overflow=False)
-    m = tl.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
-    # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0.
-    # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that.
-    e = tl.maximum(e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
-
-    # Combine sign, exponent, and mantissa, while saturating
-    # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
-    e2m1_tmp = tl.minimum((((e << 2) | (m >> 21)) + 1) >> 1, 0x7)
-    e2m1_value = ((s >> 28) | e2m1_tmp).to(tl.uint8)
+    denorm_exp: tl.constexpr = (
+        (EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1
+    )
+    denorm_mask_int: tl.constexpr = denorm_exp << MBITS_F32
+    denorm_mask_float: tl.constexpr = tl.cast(denorm_mask_int, tl.float32, bitcast=True)
+
+    denormal_x = qx_fp32 + denorm_mask_float
+    denormal_x = denormal_x.to(tl.uint32, bitcast=True)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(tl.uint8)
+
+    # Normal numbers
+    normal_x = qx
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32) + (1 << 21) - 1
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - MBITS_FP4)
+    normal_x = normal_x.to(tl.uint8)
+
+    # Merge results
+    e2m1_value = tl.full(qx.type.get_block_shapes(), 0x7, dtype=tl.uint8)
+    e2m1_value = tl.where(normal_mask, normal_x, e2m1_value)
+    e2m1_value = tl.where(denormal_mask, denormal_x, e2m1_value)
+    # add sign back
+    sign_lp = s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4)
+    sign_lp = sign_lp.to(tl.uint8)
+    e2m1_value = e2m1_value | sign_lp
     e2m1_value = tl.reshape(
         e2m1_value, [BLOCK_SIZE_M, NUM_QUANT_BLOCKS, MXFP4_QUANT_BLOCK_SIZE // 2, 2]
     )
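
The denormal path relies on the classic magic-number trick: denorm_exp works out to (127 - 1) + (23 - 1) + 1 = 149, so denorm_mask_float is 2.0**22. Adding it to a positive input below 1.0 lets the FPU's own round-to-nearest-even land the rounded result, which is exactly the e2m1 code (0 -> 0.0, 1 -> the subnormal 0.5, 2 -> the minimum normal 1.0), in the low mantissa bits of the sum. A standalone check of that constant (NumPy used here only for the bitcasts; hypothetical helper, not part of the change):

import numpy as np

def e2m1_denormal_code(x: float) -> int:
    # denorm_exp = (127 - 1) + (23 - 1) + 1 = 149, as in the diff.
    denorm_mask_int = 149 << 23
    magic = np.int32(denorm_mask_int).view(np.float32)  # == 2.0**22
    # For 0 <= x < 1.0 the float add rounds x to a multiple of 0.5 and
    # leaves the e2m1 code in the low bits of the sum.
    bits = (np.float32(x) + magic).view(np.int32)
    return int(bits) - denorm_mask_int

assert e2m1_denormal_code(0.24) == 0b000  # rounds down to 0.0
assert e2m1_denormal_code(0.5) == 0b001   # the single e2m1 subnormal, 0.5
assert e2m1_denormal_code(0.9) == 0b010   # rounds up into the minimum normal, 1.0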

op_tests/triton_tests/test_quant_mxfp4.py (57 additions, 19 deletions)

@@ -29,6 +29,16 @@ def torch_dynamic_mxfp4_quant(
     """
     # Create padded x. Needed because mxfp4 works with block of 32 elements
     MXFP4_QUANT_BLOCK_SIZE = 32
+    EXP_BIAS_FP32 = 127
+    EXP_BIAS_FP4 = 1
+    EBITS_F32 = 8
+    EBITS_FP4 = 2
+    MBITS_F32 = 23
+    MBITS_FP4 = 1
+    max_normal = 6
+    min_normal = 1
+    sign_mask = 1 << (EBITS_FP4 + MBITS_FP4)
+
     x_shape = x.shape
     if x.shape[-1] % MXFP4_QUANT_BLOCK_SIZE != 0:
         shape = list(x_shape)
@@ -78,29 +88,57 @@ def torch_dynamic_mxfp4_quant(
     # Convert quantized fp32 tensor to int32 before converting to mxfp4 format
     qx = qx.view(torch.int32)

-    # Extract sign, exponents and mantissa fields from int32
+    # Extract sign
     s = qx & 0x80000000
-    e = (qx >> 23) & 0xFF
-    m = qx & 0x7FFFFF
+    # Set everything to positive, will add sign back at the end
+    qx = qx ^ s

-    E8_BIAS = 127
-    E2_BIAS = 1
+    qx_fp32 = qx.view(torch.float32)
+    saturate_mask = qx_fp32 >= max_normal
+    denormal_mask = torch.logical_and(
+        torch.logical_not(saturate_mask), qx_fp32 < min_normal
+    )
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

     # Denormal numbers
-    # If exponent is less than 127, then it's a denormal number
-    # See above, for denormal number mantissa is always 1 and we set bit 1 of mantissa
-    adjusted_exponents = E8_BIAS - e - 1
-    m = torch.where(e < E8_BIAS, (0x400000 | (m >> 1)) >> adjusted_exponents, m)
-
-    # For normal numbers, bias is changed from 127 to 1, and for subnormals, we keep exponent as 0.
-    # Note: E8_BIAS - E2_BIAS = 126, so for normals we subtract that.
-    e = torch.where(e > E8_BIAS - E2_BIAS, e, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS)
-
-    # Combine sign, exponent, and mantissa, while saturating
-    # rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right
-    combined_val = (((e << 2) | (m >> 21)) + 1) >> 1
-    e2m1_tmp = torch.where(combined_val < 0x7, combined_val, 0x7)
-    e2m1_value = (((s >> 28) & 0xF) | e2m1_tmp).to(torch.uint8)
+    denorm_exp = (EXP_BIAS_FP32 - EXP_BIAS_FP4) + (MBITS_F32 - MBITS_FP4) + 1
+    denorm_mask_int = denorm_exp << MBITS_F32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
+        torch.float32
+    )
+
+    denormal_x = qx_fp32 + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    # Normal numbers
+    normal_x = qx
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - MBITS_FP4)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((EXP_BIAS_FP4 - EXP_BIAS_FP32) << MBITS_F32) + (1 << 21) - 1
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - MBITS_FP4)
+    normal_x = normal_x.to(torch.uint8)
+
+    # Merge results
+    e2m1_value = torch.full_like(qx, 0x7, dtype=torch.uint8)
+    e2m1_value = torch.where(normal_mask, normal_x, e2m1_value)
+    e2m1_value = torch.where(denormal_mask, denormal_x, e2m1_value)
+
+    # add sign back
+    sign_lp = s >> (MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have an uint32 dtype, we mask out these bits to get just the
+    # f4 sign bit
+    sign_lp = sign_lp & sign_mask
+    e2m1_value = e2m1_value | sign_lp

     # Pack 2 4-bit values into 8-bit
     x_mxfp4 = e2m1_value[..., ::2] | (e2m1_value[..., 1::2] << 4)
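
The extra sign_mask step exists only in the torch reference because PyTorch has no uint32 dtype: s is a negative int32 whenever the fp32 sign bit is set, so the right shift by MBITS_F32 + EBITS_F32 - MBITS_FP4 - EBITS_FP4 = 28 can smear copies of the sign bit into the result. A small demonstration, assuming a build where the signed shift is arithmetic (the case the mask guards against):

import torch

s = torch.tensor([-2.0]).view(torch.int32)     # 0xC0000000: fp32 sign bit set
sign_lp = (s >> 28).to(torch.uint8)            # arithmetic shift smears 1s: 0xFC
sign_mask = 1 << (2 + 1)                       # 1 << (EBITS_FP4 + MBITS_FP4)
assert (sign_lp & sign_mask).item() == 0b1000  # only the f4 sign bit survives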
