Commit 7088c64

[MXFP] Fix packing for mxfp4 type (#5197)
When packing, element 0 should go in the lower bits; until this PR it was placed in the higher bits.
1 parent 1214ac7 commit 7088c64
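As context for the fix (a minimal standalone sketch, not code from this commit; `pack_mxfp4` and `unpack_mxfp4` are hypothetical helpers), the convention being enforced is: two e2m1 (mxfp4) codes share one uint8, with element 0 in the low nibble and element 1 in the high nibble.

```python
import numpy as np

def pack_mxfp4(codes: np.ndarray) -> np.ndarray:
    # Hypothetical helper: pack pairs of 4-bit e2m1 codes into uint8,
    # element 0 of each pair in the lower bits (the convention fixed here).
    codes = codes.astype(np.uint8) & 0xF
    lo = codes[..., 0::2]   # even-indexed elements -> bits 0-3
    hi = codes[..., 1::2]   # odd-indexed elements  -> bits 4-7
    return lo | (hi << 4)

def unpack_mxfp4(packed: np.ndarray) -> np.ndarray:
    # Inverse of pack_mxfp4: the low nibble comes out first.
    lo = packed & 0xF
    hi = packed >> 4
    return np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)

# Codes [0x1, 0x7] pack into the byte 0x71: element 0 sits in bits 0-3.
assert pack_mxfp4(np.array([0x1, 0x7]))[0] == 0x71
assert list(unpack_mxfp4(np.array([0x71], dtype=np.uint8))) == [0x1, 0x7]
```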

4 files changed, +29 -29 lines changed

4 files changed

+29
-29
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 9 additions & 9 deletions
@@ -752,26 +752,26 @@ SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
                                           ArrayRef<Value> values) {
   SmallVector<Value> results;
   for (auto v : values) {
-    auto em0 = and_(v, i8_val(0x70));
-    auto em1 = and_(v, i8_val(0x7));
-    Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)),
-                   shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));
-    Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)),
+    auto em0 = and_(v, i8_val(0x7));
+    auto em1 = and_(v, i8_val(0x70));
+    Value v0 = or_(shl(zext(i16_ty, em0), i16_val(6)),
                    shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12)));
+    Value v1 = or_(shl(zext(i16_ty, em1), i16_val(2)),
+                   shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));

     // Three cases:
     // 1) x is normal and non-zero: Correct bias
-    v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)),
+    v0 = select(icmp_ne(and_(em0, i8_val(0x6)), i8_val(0)),
                 add(v0, i16_val((127 - 1) << 7)), v0);
-    v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)),
+    v1 = select(icmp_ne(and_(em1, i8_val(0x60)), i8_val(0)),
                 add(v1, i16_val((127 - 1) << 7)), v1);

     // 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
     // bf16
-    v0 = bitcast(select(icmp_eq(em0, i8_val(0x10)),
+    v0 = bitcast(select(icmp_eq(em0, i8_val(0x1)),
                         or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0),
                  bf16_ty);
-    v1 = bitcast(select(icmp_eq(em1, i8_val(0x1)),
+    v1 = bitcast(select(icmp_eq(em1, i8_val(0x10)),
                         or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1),
                  bf16_ty);
     // 3) x is zero, nothing to do
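To sanity-check the new ordering, the same e2m1 → bf16 decode can be written out in plain Python for a single packed byte. This is a standalone sketch mirroring the logic above, not code from the repository; `decode_e2m1_pair` is a hypothetical helper.

```python
import struct

def bf16_bits_to_float(bits: int) -> float:
    # A bf16 pattern is the top 16 bits of the fp32 with the same value.
    return struct.unpack(">f", struct.pack(">I", bits << 16))[0]

def decode_e2m1_nibble(n: int) -> float:
    # One 4-bit e2m1 code: bit 3 = sign, bits 2-1 = exponent (bias 1), bit 0 = mantissa.
    em = n & 0x7
    bits = (em << 6) | ((n & 0x8) << 12)   # exponent/mantissa into bf16 bits 8-6, sign into bit 15
    if em & 0x6:                           # 1) normal: rebias the exponent (1 -> 127)
        bits += (127 - 1) << 7
    elif em == 0x1:                        # 2) subnormal 0bs001: map to +-0.5 (0x3F00 == 16128)
        bits = 16128 | (bits & 0x8000)
    # 3) zero: only the sign bit remains
    return bf16_bits_to_float(bits)

def decode_e2m1_pair(byte: int) -> tuple:
    # Element 0 comes from the low nibble, element 1 from the high nibble.
    return decode_e2m1_nibble(byte & 0xF), decode_e2m1_nibble(byte >> 4)

# 0xE2 packs 0x2 (= 1.0) in the low nibble and 0xE (= -4.0) in the high nibble.
assert decode_e2m1_pair(0xE2) == (1.0, -4.0)
```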

python/test/unit/language/test_core.py

Lines changed: 8 additions & 8 deletions
@@ -3469,17 +3469,17 @@ def mxfp_to_bf16_kernel(
         x_bf16 = x_f8.to(tl.bfloat16)
     else:
         # e2m1
-        em0 = x & 0x70
-        em1 = x & 0x7
-        x0 = (em0.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << 8)
-        x1 = (em1.to(tl.uint16) << (2 + 4)) | ((x & 0x8).to(tl.uint16) << (8 + 4))
+        em0 = x & 0x7
+        em1 = x & 0x70
+        x0 = (em0.to(tl.uint16) << 2 + 4) | ((x & 0x8).to(tl.uint16) << 8 + 4)
+        x1 = (em1.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << (8))
         # Three cases:
         # 1) x is normal and non-zero: Correct bias
-        x0 = tl.where((em0 & 0x60) != 0, x0 + ((127 - 1) << 7), x0)
-        x1 = tl.where((em1 & 0x6) != 0, x1 + ((127 - 1) << 7), x1)
+        x0 = tl.where((em0 & 0x6) != 0, x0 + ((127 - 1) << 7), x0)
+        x1 = tl.where((em1 & 0x60) != 0, x1 + ((127 - 1) << 7), x1)
         # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16
-        x0 = tl.where(em0 == 0x10, 16128 | (x0 & 0x8000), x0)
-        x1 = tl.where(em1 == 0x1, 16128 | (x1 & 0x8000), x1)
+        x0 = tl.where(em0 == 0x1, 16128 | (x0 & 0x8000), x0)
+        x1 = tl.where(em1 == 0x10, 16128 | (x1 & 0x8000), x1)
         # 3) x is zero, do nothing
         x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True)
     # Multiplication preserves infs and NaNs in x_bf16

python/test/unit/language/test_pipeliner.py

Lines changed: 8 additions & 8 deletions
@@ -160,17 +160,17 @@ def mxfp_to_bf16_kernel(
         x_bf16 = x_f8.to(tl.bfloat16)
     else:
         # e2m1
-        em0 = x & 0x70
-        em1 = x & 0x7
-        x0 = (em0.to(tl.uint16) << 2) | ((x & 0x80).to(tl.uint16) << 8)
-        x1 = (em1.to(tl.uint16) << (2 + 4)) | ((x & 0x8).to(tl.uint16) << (8 + 4))
+        em0 = x & 0x7
+        em1 = x & 0x70
+        x0 = (em0.to(tl.uint16) << 2 + 4) | ((x & 0x8).to(tl.uint16) << 8 + 4)
+        x1 = (em1.to(tl.uint16) << (2)) | ((x & 0x80).to(tl.uint16) << (8))
         # Three cases:
         # 1) x is normal and non-zero: Correct bias
-        x0 = tl.where((em0 & 0x60) != 0, x0 + ((127 - 1) << 7), x0)
-        x1 = tl.where((em1 & 0x6) != 0, x1 + ((127 - 1) << 7), x1)
+        x0 = tl.where((em0 & 0x6) != 0, x0 + ((127 - 1) << 7), x0)
+        x1 = tl.where((em1 & 0x60) != 0, x1 + ((127 - 1) << 7), x1)
         # 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in bf16
-        x0 = tl.where(em0 == 0x10, 16128 | (x0 & 0x8000), x0)
-        x1 = tl.where(em1 == 0x1, 16128 | (x1 & 0x8000), x1)
+        x0 = tl.where(em0 == 0x1, 16128 | (x0 & 0x8000), x0)
+        x1 = tl.where(em1 == 0x10, 16128 | (x1 & 0x8000), x1)
         # 3) x is zero, do nothing
         x_bf16 = tl.interleave(x0, x1).to(tl.bfloat16, bitcast=True)
     # Multiplication preserves infs and NaNs in x_bf16

python/triton/language/core.py

Lines changed: 4 additions & 4 deletions
@@ -1647,16 +1647,16 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None,
     lhs and rhs use microscaling formats described here:
     https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
     :param lhs: The first tensor to be multiplied.
-    :type lhs: 2D tensor representing fp4 or fp8 elements packed into uint8 for fp4 inputs, or in uint8 or the corresponding fp8 type for fp8 inputs.
+    :type lhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
     :param lhs_scale: Scale factor for lhs tensor.
     :type lhs_scale: e8m0 type represented as an uint8 tensor.
-    :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`}.
+    :param lhs_format: format of the lhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`}.
     :type lhs_format: str
     :param rhs: The second tensor to be multiplied.
-    :type rhs: 2D tensor representing fp8 or bf16 elements in uint8 or the corresponding fp8 type for fp8 inputs or bf16 for bf16 inputs.
+    :type rhs: 2D tensor representing fp4, fp8 or bf16 elements. Fp4 elements are packed into uint8 inputs with the first element in lower bits. Fp8 are stored as uint8 or the corresponding fp8 type.
     :param rhs_scale: Scale factor for rhs tensor.
     :type rhs_scale: e8m0 type represented as an uint8 tensor.
-    :param rhs_format: format of the rhs tensor. Available formats: {:code:`e4m3`, :code:`e5m2`, :code:`bf16`}.
+    :param rhs_format: format of the rhs tensor. Available formats: {:code:`e2m1`, :code:`e4m3`, :code:`e5m2`, :code:`bf16`}.
     :type rhs_format: str
     :param acc: The accumulator tensor. If not None, the result is added to this tensor.
     """
