Skip to content

Commit 2bc0672

Browse files
authored
[TRITON_KERNELS] fix swizzling numerics (#7632)
1 parent 3df0da5 commit 2bc0672

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

python/triton_kernels/tests/test_tensor_details/test_layout_hopper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
66
from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
77
from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
8+
from triton_kernels.target_info import cuda_capability_geq
89
import triton.language as tl
910
import triton
1011
import torch
@@ -71,6 +72,7 @@ def _upcast_mxfp4_to_bf16(Y, X, XScale, x_stride_m, x_stride_n, x_scale_stride_m
7172

7273

7374
@pytest.mark.skipif(not is_cuda(), reason="Only supported on cuda")
75+
@pytest.mark.skipif(not cuda_capability_geq(9), reason="Only supported for capability >= 9")
7476
def test_upcast_mxfp4_to_bf16():
7577
mx_axis = 0
7678
num_warps = 4

python/triton_kernels/triton_kernels/tensor_details/layout_details/hopper_value.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,17 @@ def _unpack_fp4_to_bf16_triton(x):
234234
r"""
235235
{
236236
.reg .b32 b, c, d<7>, scale;
237+
.reg .b32 bias;
238+
mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
237239
// We add the missing bias to the scale directly
238240
and.b32 $0, $4, 0b10000001110000001000000111000000;
241+
mul.bf16x2 $0, $0, bias;
239242
shl.b32 b, $4, 3;
240243
and.b32 $1, b, 0b10000001110000001000000111000000;
244+
mul.bf16x2 $1, $1, bias;
241245
shl.b32 c, $4, 6;
242246
and.b32 $2, c, 0b10000001110000001000000111000000;
247+
mul.bf16x2 $2, $2, bias;
243248
// Unpack last two elements
244249
shl.b32 d0, $4, 1;
245250
and.b32 d1, d0, 0b10000000000000001000000000000000;
@@ -249,6 +254,7 @@ def _unpack_fp4_to_bf16_triton(x):
249254
shr.b32 d5, $4, 7;
250255
and.b32 d6, d5, 0b00000000010000000000000001000000;
251256
or.b32 $3, d4, d6;
257+
mul.bf16x2 $3, $3, bias;
252258
}
253259
""",
254260
constraints="=r,=r,=r,=r,r",
@@ -289,15 +295,12 @@ def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
289295
# upcast scale to bfloat16
290296
# Add bias missing from the bf16 upcasting sequence
291297
# triton / LLVM generates terrible code for this sequence
292-
# scale += 126
293-
#scale = scale.to(tl.uint16)
294-
#scale = scale << 7
295-
#scale = scale.to(tl.bfloat16, bitcast=True)
298+
# scale = scale.to(tl.uint16)
299+
# scale = scale << 7
300+
# scale = scale.to(tl.bfloat16, bitcast=True)
296301
scale = tl.inline_asm_elementwise(
297302
r"""
298303
{
299-
// Assumes no overflow
300-
add.u32 $2, $2, 0x7E7E7E7E;
301304
prmt.b32 $0, $2, 0, 0x5140;
302305
shl.b32 $0, $0, 7;
303306
prmt.b32 $1, $2, 0, 0x7362;

0 commit comments

Comments (0)