Skip to content

Commit ea4c56b

Browse files
committed
Add lit tests for fast math flags
1 parent bfca113 commit ea4c56b

File tree

2 files changed

+87
-3
lines changed

2 files changed

+87
-3
lines changed

mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
4444
Location loc = op.getLoc();
4545
Value x = op.getLhs();
4646
arith::FastMathFlags fmf = op.getFastmathAttr().getValue();
47+
arith::FastMathFlags intermediateFmf = arith::bitEnumClear(
48+
fmf, arith::FastMathFlags::reassoc | arith::FastMathFlags::contract |
49+
arith::FastMathFlags::arcp);
4750

4851
FloatAttr scalarExponent;
4952
DenseFPElementsAttr vectorExponent;
@@ -85,7 +88,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
8588

8689
// Replace `pow(x, 3.0)` with `x * x * x`.
8790
if (isExponentValue(3.0)) {
88-
Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf);
91+
Value square = arith::MulFOp::create(rewriter, loc, x, x, intermediateFmf);
8992
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, square, fmf);
9093
return success();
9194
}
@@ -113,8 +116,9 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
113116

114117
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
115118
if (isExponentValue(0.75)) {
116-
Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf);
117-
Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf);
119+
Value powHalf = math::SqrtOp::create(rewriter, loc, x, intermediateFmf);
120+
Value powQuarter =
121+
math::SqrtOp::create(rewriter, loc, powHalf, intermediateFmf);
118122
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, powHalf, powQuarter, fmf);
119123
return success();
120124
}

mlir/test/Dialect/Math/algebraic-simplification.mlir

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,18 @@ func.func @pow_square(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
2222
return %0, %1 : f32, vector<4xf32>
2323
}
2424

25+
// CHECK-LABEL: @pow_square_fast
26+
func.func @pow_square_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
27+
// CHECK: %[[SCALAR:.*]] = arith.mulf %arg0, %arg0 fastmath<fast>
28+
// CHECK: %[[VECTOR:.*]] = arith.mulf %arg1, %arg1 fastmath<fast>
29+
// CHECK: return %[[SCALAR]], %[[VECTOR]]
30+
%c = arith.constant 2.0 : f32
31+
%v = arith.constant dense <2.0> : vector<4xf32>
32+
%0 = math.powf %arg0, %c fastmath<fast> : f32
33+
%1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
34+
return %0, %1 : f32, vector<4xf32>
35+
}
36+
2537
// CHECK-LABEL: @pow_cube
2638
func.func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
2739
// CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0
@@ -36,6 +48,20 @@ func.func @pow_cube(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
3648
return %0, %1 : f32, vector<4xf32>
3749
}
3850

51+
// CHECK-LABEL: @pow_cube_fast
52+
func.func @pow_cube_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
53+
// CHECK: %[[TMP_S:.*]] = arith.mulf %arg0, %arg0 fastmath<nnan,ninf,nsz,afn>
54+
// CHECK: %[[SCALAR:.*]] = arith.mulf %arg0, %[[TMP_S]] fastmath<fast>
55+
// CHECK: %[[TMP_V:.*]] = arith.mulf %arg1, %arg1 fastmath<nnan,ninf,nsz,afn>
56+
// CHECK: %[[VECTOR:.*]] = arith.mulf %arg1, %[[TMP_V]] fastmath<fast>
57+
// CHECK: return %[[SCALAR]], %[[VECTOR]]
58+
%c = arith.constant 3.0 : f32
59+
%v = arith.constant dense <3.0> : vector<4xf32>
60+
%0 = math.powf %arg0, %c fastmath<fast> : f32
61+
%1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
62+
return %0, %1 : f32, vector<4xf32>
63+
}
64+
3965
// CHECK-LABEL: @pow_recip
4066
func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
4167
// CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32
@@ -50,6 +76,20 @@ func.func @pow_recip(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
5076
return %0, %1 : f32, vector<4xf32>
5177
}
5278

79+
// CHECK-LABEL: @pow_recip_fast
80+
func.func @pow_recip_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
81+
// CHECK-DAG: %[[CST_S:.*]] = arith.constant 1.0{{.*}} : f32
82+
// CHECK-DAG: %[[CST_V:.*]] = arith.constant dense<1.0{{.*}}> : vector<4xf32>
83+
// CHECK: %[[SCALAR:.*]] = arith.divf %[[CST_S]], %arg0 fastmath<fast>
84+
// CHECK: %[[VECTOR:.*]] = arith.divf %[[CST_V]], %arg1 fastmath<fast>
85+
// CHECK: return %[[SCALAR]], %[[VECTOR]]
86+
%c = arith.constant -1.0 : f32
87+
%v = arith.constant dense <-1.0> : vector<4xf32>
88+
%0 = math.powf %arg0, %c fastmath<fast> : f32
89+
%1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
90+
return %0, %1 : f32, vector<4xf32>
91+
}
92+
5393
// CHECK-LABEL: @pow_sqrt
5494
func.func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
5595
// CHECK: %[[SCALAR:.*]] = math.sqrt %arg0
@@ -62,6 +102,18 @@ func.func @pow_sqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
62102
return %0, %1 : f32, vector<4xf32>
63103
}
64104

105+
// CHECK-LABEL: @pow_sqrt_fast
106+
func.func @pow_sqrt_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
107+
// CHECK: %[[SCALAR:.*]] = math.sqrt %arg0 fastmath<fast>
108+
// CHECK: %[[VECTOR:.*]] = math.sqrt %arg1 fastmath<fast>
109+
// CHECK: return %[[SCALAR]], %[[VECTOR]]
110+
%c = arith.constant 0.5 : f32
111+
%v = arith.constant dense <0.5> : vector<4xf32>
112+
%0 = math.powf %arg0, %c fastmath<fast> : f32
113+
%1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
114+
return %0, %1 : f32, vector<4xf32>
115+
}
116+
65117
// CHECK-LABEL: @pow_rsqrt
66118
func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
67119
// CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0
@@ -74,6 +126,18 @@ func.func @pow_rsqrt(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>)
74126
return %0, %1 : f32, vector<4xf32>
75127
}
76128

129+
// CHECK-LABEL: @pow_rsqrt_fast
130+
func.func @pow_rsqrt_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
131+
// CHECK: %[[SCALAR:.*]] = math.rsqrt %arg0 fastmath<fast>
132+
// CHECK: %[[VECTOR:.*]] = math.rsqrt %arg1 fastmath<fast>
133+
// CHECK: return %[[SCALAR]], %[[VECTOR]]
134+
%c = arith.constant -0.5 : f32
135+
%v = arith.constant dense <-0.5> : vector<4xf32>
136+
%0 = math.powf %arg0, %c fastmath<fast> : f32
137+
%1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
138+
return %0, %1 : f32, vector<4xf32>
139+
}
140+
77141
// CHECK-LABEL: @pow_0_75
78142
func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
79143
// CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0
@@ -90,6 +154,22 @@ func.func @pow_0_75(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
90154
return %0, %1 : f32, vector<4xf32>
91155
}
92156

157+
// CHECK-LABEL: @pow_0_75_fast
158+
func.func @pow_0_75_fast(%arg0: f32, %arg1 : vector<4xf32>) -> (f32, vector<4xf32>) {
159+
// CHECK: %[[SQRT1S:.*]] = math.sqrt %arg0 fastmath<nnan,ninf,nsz,afn>
160+
// CHECK: %[[SQRT2S:.*]] = math.sqrt %[[SQRT1S]] fastmath<nnan,ninf,nsz,afn>
161+
// CHECK: %[[SCALAR:.*]] = arith.mulf %[[SQRT1S]], %[[SQRT2S]] fastmath<fast>
162+
// CHECK: %[[SQRT1V:.*]] = math.sqrt %arg1 fastmath<nnan,ninf,nsz,afn>
163+
// CHECK: %[[SQRT2V:.*]] = math.sqrt %[[SQRT1V]] fastmath<nnan,ninf,nsz,afn>
164+
// CHECK: %[[VECTOR:.*]] = arith.mulf %[[SQRT1V]], %[[SQRT2V]] fastmath<fast>
165+
// CHECK: return %[[SCALAR]], %[[VECTOR]]
166+
%c = arith.constant 0.75 : f32
167+
%v = arith.constant dense <0.75> : vector<4xf32>
168+
%0 = math.powf %arg0, %c fastmath<fast> : f32
169+
%1 = math.powf %arg1, %v fastmath<fast> : vector<4xf32>
170+
return %0, %1 : f32, vector<4xf32>
171+
}
172+
93173
// CHECK-LABEL: @ipowi_zero_exp(
94174
// CHECK-SAME: %[[ARG0:.+]]: i32
95175
// CHECK-SAME: %[[ARG1:.+]]: vector<4xi32>

0 commit comments

Comments
 (0)