@@ -234,12 +234,17 @@ def _unpack_fp4_to_bf16_triton(x):
         r"""
         {
             .reg .b32 b, c, d<7>, scale;
+            .reg .b32 bias;
+            mov.b32 bias, 0x7e807e80; // 2 ** 126 == 2 ** (bias_bf16 - bias_fp2)
             // We add the missing bias to the scale directly
             and.b32 $0, $4, 0b10000001110000001000000111000000;
+            mul.bf16x2 $0, $0, bias;
             shl.b32 b, $4, 3;
             and.b32 $1, b, 0b10000001110000001000000111000000;
+            mul.bf16x2 $1, $1, bias;
             shl.b32 c, $4, 6;
             and.b32 $2, c, 0b10000001110000001000000111000000;
+            mul.bf16x2 $2, $2, bias;
             // Unpack last two elements
             shl.b32 d0, $4, 1;
             and.b32 d1, d0, 0b10000000000000001000000000000000;
@@ -249,6 +254,7 @@ def _unpack_fp4_to_bf16_triton(x):
             shr.b32 d5, $4, 7;
             and.b32 d6, d5, 0b00000000010000000000000001000000;
             or.b32 $3, d4, d6;
+            mul.bf16x2 $3, $3, bias;
         }
         """,
         constraints="=r,=r,=r,=r,r",
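Note on the new `bias` constant (a verification sketch, not part of the diff): assuming the packed values are FP4 E2M1 (exponent bias 1), the `and`/`shl` sequence above drops the FP4 sign, exponent and mantissa bits straight into the BF16 fields, which leaves every value short by a factor of 2 ** (127 - 1) = 2 ** 126. The BF16 bit pattern for 2 ** 126 is 0x7E80, so 0x7e807e80 is two such constants packed side by side, and each `mul.bf16x2` fixes both unpacked halves at once. A minimal plain-Python check of that arithmetic:

```python
# Sketch only: verifies the 2**126 bias constant, assuming FP4 here means E2M1
# (exponent bias 1) and that the unpack places FP4 bits into BF16 fields as
# sign -> bit 15, exponent -> bits 8..7, mantissa -> bit 6 (per the masks above).
import struct

def bf16_bits_to_float(bits: int) -> float:
    # A BF16 value is the upper 16 bits of an IEEE-754 float32.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

# 0x7E80 as BF16 is exactly 2**126 == 2**(bias_bf16 - bias_e2m1).
assert bf16_bits_to_float(0x7E80) == 2.0 ** 126

# Example E2M1 nibble 0b0101: sign=0, exponent=0b10, mantissa=1 -> 1.5 * 2**(2 - 1) == 3.0.
nibble = 0b0101
sign, exp, man = nibble >> 3, (nibble >> 1) & 0b11, nibble & 0b1
raw = (sign << 15) | (exp << 7) | (man << 6)          # FP4 fields in BF16 positions
assert bf16_bits_to_float(raw) * 2.0 ** 126 == 3.0    # multiply by the bias constant
```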
@@ -289,15 +295,12 @@ def mxfp4_to_bf16_triton(x, scale, mx_axis: tl.constexpr):
     # upcast scale to bfloat16
     # Add bias missing from the bf16 upcasting sequence
     # triton / LLVM generates terrible code for this sequence
-    # scale += 126
-    #scale = scale.to(tl.uint16)
-    #scale = scale << 7
-    #scale = scale.to(tl.bfloat16, bitcast=True)
+    # scale = scale.to(tl.uint16)
+    # scale = scale << 7
+    # scale = scale.to(tl.bfloat16, bitcast=True)
     scale = tl.inline_asm_elementwise(
         r"""
         {
-            // Assumes no overflow
-            add.u32 $2, $2, 0x7E7E7E7E;
             prmt.b32 $0, $2, 0, 0x5140;
             shl.b32 $0, $0, 7;
             prmt.b32 $1, $2, 0, 0x7362;
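For reference, the surviving `prmt`/`shl` sequence is the hand-written equivalent of the commented-out Triton version of the scale upcast: assuming the scale is an OCP MX E8M0 byte `s` encoding `2 ** (s - 127)`, shifting the byte left by 7 places it directly in the BF16 exponent field, so no addition is needed once the FP4 bias correction lives in `_unpack_fp4_to_bf16_triton` (the removed `add.u32` added 126 to each scale byte to fold that correction into the scale, with the noted overflow caveat). A small sketch of that equivalence, skipping the E8M0 edge encodings 0x00 (2 ** -127, below BF16's normal range) and 0xFF (NaN):

```python
# Sketch only: E8M0 scale byte -> BF16 via a 7-bit shift, as in the asm above.
import struct

def bf16_bits_to_float(bits: int) -> float:
    # A BF16 value is the upper 16 bits of an IEEE-754 float32.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

def e8m0_to_float(s: int) -> float:
    # Assumption: E8M0 stores a single biased exponent, value = 2**(s - 127).
    return 2.0 ** (s - 127)

for s in (1, 64, 127, 130, 254):
    # shl.b32 ..., 7: the byte lands in the BF16 exponent field (bits 14..7).
    assert bf16_bits_to_float(s << 7) == e8m0_to_float(s)
```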