Skip to content

Commit b73c37d

Browse files
jtuylsgithub-actions[bot]
authored andcommitted
Automerge: [mlir][arith][transforms] Fix f4E2M1FN to f32 cast (#160121)
The signed i4 bitcast was used when setting the exponent and mantissa and instead the sign should be omitted in the comparisons. Without this, for example the following incorrect conversion from `-0.5` f4 to `-3.0` f32 will happen: | Binary | F4E2M1 | f32[23:32] | f32 | 1001 | -0.5 | ~~1 1000 000 01~~ | ~~-3.0~~ **Walkthrough:** Bits 23 and 24 are set based on: ``` Value isHalf = arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1); ``` Because `1001 (i4) != 1`, bit 23 and 24 are set to the leading two bits of `1001 << 2`, which is `01`. The correct bits are `00`. Bits 25 through 31 are set based on the i4 value being greater or equal to 4: ``` Value useLargerExp = arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4); ``` As `1001` is a negative i4 value, this is false and those bits are incorrectly set to `1000 000` instead of `0111 111`.
2 parents 7ac61db + faf5f28 commit b73c37d

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -387,12 +387,15 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
387387
Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter);
388388
Value c0x2 = createConst(loc, i4Ty, 0x2, rewriter);
389389
Value c0x4 = createConst(loc, i4Ty, 0x4, rewriter);
390+
Value c0x7 = createConst(loc, i4Ty, 0x7, rewriter);
391+
392+
Value i4BitsNoSign = arith::AndIOp::create(b, i4Bits, c0x7);
390393

391394
// Set last Exponent bit and Mantissa.
392395
Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter);
393-
Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2);
396+
Value bits1To24 = arith::ShLIOp::create(b, i4BitsNoSign, c0x2);
394397
Value isHalf =
395-
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1);
398+
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x1);
396399
bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24);
397400
bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24);
398401
bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014);
@@ -402,11 +405,11 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
402405
Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter);
403406
Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter);
404407
Value useLargerExp =
405-
arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4);
408+
arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4BitsNoSign, c0x4);
406409
Value bits25To31 =
407410
arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits);
408411
Value zeroExp =
409-
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0);
412+
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4BitsNoSign, c0x0);
410413
bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31);
411414

412415
// Set sign.

mlir/test/Integration/Dialect/Arith/CPU/test-arith-expand-truncf-extf.mlir

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,18 @@ func.func @entry() {
2828
%zero = arith.constant 0.0 : f32
2929
%half = arith.constant 0.5 : f32
3030
%one = arith.constant 1.0 : f32
31+
%oneAndAHalf = arith.constant 1.5 : f32
32+
%two = arith.constant 2.0 : f32
33+
%three = arith.constant 3.0 : f32
34+
%four = arith.constant 4.0 : f32
3135
%max = arith.constant 6.0 : f32
36+
%minZero = arith.constant -0.0 : f32
37+
%minHalf = arith.constant -0.5 : f32
38+
%minOne = arith.constant -1.0 : f32
39+
%minOneAndAHalf = arith.constant -1.5 : f32
40+
%minTwo = arith.constant -2.0 : f32
41+
%minThree = arith.constant -3.0 : f32
42+
%minFour = arith.constant -4.0 : f32
3243
%min = arith.constant -6.0 : f32
3344
%lowerThanMin = arith.constant -1000000.0 : f32
3445
%higherThanMax = arith.constant 1000000.0 : f32
@@ -41,8 +52,28 @@ func.func @entry() {
4152
func.call @check_truncf(%half) : (f32) -> ()
4253
// CHECK: 2
4354
func.call @check_truncf(%one) : (f32) -> ()
55+
// CHECK: 3
56+
func.call @check_truncf(%oneAndAHalf) : (f32) -> ()
57+
// CHECK: 4
58+
func.call @check_truncf(%two) : (f32) -> ()
59+
// CHECK: 5
60+
func.call @check_truncf(%three) : (f32) -> ()
61+
// CHECK: 6
62+
func.call @check_truncf(%four) : (f32) -> ()
4463
// CHECK: 7
4564
func.call @check_truncf(%max) : (f32) -> ()
65+
// CHECK: 9
66+
func.call @check_truncf(%minHalf) : (f32) -> ()
67+
// CHECK: 10
68+
func.call @check_truncf(%minOne) : (f32) -> ()
69+
// CHECK: 11
70+
func.call @check_truncf(%minOneAndAHalf) : (f32) -> ()
71+
// CHECK: 12
72+
func.call @check_truncf(%minTwo) : (f32) -> ()
73+
// CHECK: 13
74+
func.call @check_truncf(%minThree) : (f32) -> ()
75+
// CHECK: 14
76+
func.call @check_truncf(%minFour) : (f32) -> ()
4677
// CHECK: 15
4778
func.call @check_truncf(%min) : (f32) -> ()
4879
// CHECK: 7
@@ -60,9 +91,45 @@ func.func @entry() {
6091
// CHECK: 0.5
6192
%halfF4 = arith.truncf %half : f32 to f4E2M1FN
6293
func.call @check_extf(%halfF4) : (f4E2M1FN) -> ()
94+
// CHECK: 1
95+
%oneF4 = arith.truncf %one : f32 to f4E2M1FN
96+
func.call @check_extf(%oneF4) : (f4E2M1FN) -> ()
97+
// CHECK: 1.5
98+
%oneAndAHalfF4 = arith.truncf %oneAndAHalf : f32 to f4E2M1FN
99+
func.call @check_extf(%oneAndAHalfF4) : (f4E2M1FN) -> ()
100+
// CHECK: 2
101+
%twoF4 = arith.truncf %two : f32 to f4E2M1FN
102+
func.call @check_extf(%twoF4) : (f4E2M1FN) -> ()
103+
// CHECK: 3
104+
%threeF4 = arith.truncf %three : f32 to f4E2M1FN
105+
func.call @check_extf(%threeF4) : (f4E2M1FN) -> ()
106+
// CHECK: 4
107+
%fourF4 = arith.truncf %four : f32 to f4E2M1FN
108+
func.call @check_extf(%fourF4) : (f4E2M1FN) -> ()
63109
// CHECK: 6
64110
%higherThanMaxF4 = arith.truncf %higherThanMax : f32 to f4E2M1FN
65111
func.call @check_extf(%higherThanMaxF4) : (f4E2M1FN) -> ()
112+
// CHECK: -0
113+
%minZeroF4 = arith.truncf %minZero : f32 to f4E2M1FN
114+
func.call @check_extf(%minZeroF4) : (f4E2M1FN) -> ()
115+
// CHECK: -0.5
116+
%minHalfF4 = arith.truncf %minHalf : f32 to f4E2M1FN
117+
func.call @check_extf(%minHalfF4) : (f4E2M1FN) -> ()
118+
// CHECK: -1
119+
%minOneF4 = arith.truncf %minOne : f32 to f4E2M1FN
120+
func.call @check_extf(%minOneF4) : (f4E2M1FN) -> ()
121+
// CHECK: -1.5
122+
%minOneAndAHalfF4 = arith.truncf %minOneAndAHalf : f32 to f4E2M1FN
123+
func.call @check_extf(%minOneAndAHalfF4) : (f4E2M1FN) -> ()
124+
// CHECK: -2
125+
%minTwoF4 = arith.truncf %minTwo : f32 to f4E2M1FN
126+
func.call @check_extf(%minTwoF4) : (f4E2M1FN) -> ()
127+
// CHECK: -3
128+
%minThreeF4 = arith.truncf %minThree : f32 to f4E2M1FN
129+
func.call @check_extf(%minThreeF4) : (f4E2M1FN) -> ()
130+
// CHECK: -4
131+
%minFourF4 = arith.truncf %minFour : f32 to f4E2M1FN
132+
func.call @check_extf(%minFourF4) : (f4E2M1FN) -> ()
66133
// CHECK: -6
67134
%lowerThanMinF4 = arith.truncf %lowerThanMin : f32 to f4E2M1FN
68135
func.call @check_extf(%lowerThanMinF4) : (f4E2M1FN) -> ()

0 commit comments

Comments
 (0)