Skip to content

Commit 6faad70

Browse files
committed
Add fastmath attribute.
Update op description. Update tests.
1 parent 52182f1 commit 6faad70

File tree

8 files changed

+63
-47
lines changed

8 files changed

+63
-47
lines changed

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,9 +1332,11 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
13321332
auto complexTy = mlir::cast<mlir::ComplexType>(mathLibFuncType.getInput(0));
13331333
mlir::Value exp = args[1];
13341334
mlir::Value result;
1335-
if (mlir::isa<mlir::IntegerType>(exp.getType()) ||
1336-
mlir::isa<mlir::IndexType>(exp.getType())) {
1337-
result = builder.create<mlir::complex::PowiOp>(loc, args[0], exp);
1335+
auto fmfAttr = mlir::arith::FastMathFlagsAttr::get(
1336+
builder.getContext(), builder.getFastMathFlags());
1337+
if (mlir::isa<mlir::IntegerType>(exp.getType())) {
1338+
result = builder.create<mlir::complex::PowiOp>(
1339+
loc, mathLibFuncType.getResult(0), args[0], args[1], fmfAttr);
13381340
} else {
13391341
if (!mlir::isa<mlir::ComplexType>(exp.getType())) {
13401342
auto realTy = complexTy.getElementType();
@@ -1343,7 +1345,7 @@ mlir::Value genComplexPow(fir::FirOpBuilder &builder, mlir::Location loc,
13431345
exp = builder.create<mlir::complex::CreateOp>(loc, complexTy, realExp,
13441346
zero);
13451347
}
1346-
result = builder.create<mlir::complex::PowOp>(loc, args[0], exp);
1348+
result = builder.create<mlir::complex::PowOp>(loc, args[0], exp, fmfAttr);
13471349
}
13481350
result = builder.createConvert(loc, mathLibFuncType.getResult(0), result);
13491351
return result;

flang/test/Lower/HLFIR/binary-ops.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ subroutine complex_to_int_power(x, y, z)
193193
! CHECK: %[[VAL_5:.*]]:2 = hlfir.declare %{{.*}}z"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
194194
! CHECK: %[[VAL_6:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<complex<f32>>
195195
! CHECK: %[[VAL_7:.*]] = fir.load %[[VAL_5]]#0 : !fir.ref<i32>
196-
! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] : complex<f32>, i32
196+
! CHECK: %[[VAL_8:.*]] = complex.powi %[[VAL_6]], %[[VAL_7]] fastmath<contract> : complex<f32>, i32
197197

198198
subroutine extremum(c, n, l)
199199
integer(8), intent(in) :: l

flang/test/Lower/Intrinsics/pow_complex16i.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
55

66
! PRECISE: fir.call @_FortranAcqpowi({{.*}}){{.*}}: (complex<f128>, i32) -> complex<f128>
7-
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
7+
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
88
complex(16) :: a
99
integer(4) :: b
1010
b = a ** b

flang/test/Lower/Intrinsics/pow_complex16k.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
! RUN: %flang_fc1 -emit-fir %s -o - | FileCheck %s
55

66
! PRECISE: fir.call @_FortranAcqpowk({{.*}}){{.*}}: (complex<f128>, i64) -> complex<f128>
7-
! CHECK: complex.powi %{{.*}}, %{{.*}} : complex<f128>
7+
! CHECK: complex.powi %{{.*}}, %{{.*}} fastmath<contract> : complex<f128>
88
complex(16) :: a
99
integer(8) :: b
1010
b = a ** b

flang/test/Transforms/convert-complex-pow.fir

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,51 +2,38 @@
22

33
module {
44
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
5-
%c0 = arith.constant 0.0 : f32
6-
%0 = fir.convert %arg1 : (i32) -> f32
7-
%1 = complex.create %0, %c0 : complex<f32>
8-
%2 = complex.pow %arg0, %1 : complex<f32>
9-
return %2 : complex<f32>
5+
%0 = complex.powi %arg0, %arg1 : complex<f32>, i32
6+
return %0 : complex<f32>
7+
}
8+
9+
func.func @pow_c4_i4_fast(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
10+
%0 = complex.powi %arg0, %arg1 fastmath<fast> : complex<f32>, i32
11+
return %0 : complex<f32>
1012
}
1113

1214
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
13-
%c0 = arith.constant 0.0 : f32
14-
%0 = fir.convert %arg1 : (i64) -> f32
15-
%1 = complex.create %0, %c0 : complex<f32>
16-
%2 = complex.pow %arg0, %1 : complex<f32>
17-
return %2 : complex<f32>
15+
%0 = complex.powi %arg0, %arg1 : complex<f32>, i64
16+
return %0 : complex<f32>
1817
}
1918

2019
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
21-
%c0 = arith.constant 0.0 : f64
22-
%0 = fir.convert %arg1 : (i32) -> f64
23-
%1 = complex.create %0, %c0 : complex<f64>
24-
%2 = complex.pow %arg0, %1 : complex<f64>
25-
return %2 : complex<f64>
20+
%0 = complex.powi %arg0, %arg1 : complex<f64>, i32
21+
return %0 : complex<f64>
2622
}
2723

2824
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
29-
%c0 = arith.constant 0.0 : f64
30-
%0 = fir.convert %arg1 : (i64) -> f64
31-
%1 = complex.create %0, %c0 : complex<f64>
32-
%2 = complex.pow %arg0, %1 : complex<f64>
33-
return %2 : complex<f64>
25+
%0 = complex.powi %arg0, %arg1 : complex<f64>, i64
26+
return %0 : complex<f64>
3427
}
3528

3629
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
37-
%c0 = arith.constant 0.0 : f128
38-
%0 = fir.convert %arg1 : (i32) -> f128
39-
%1 = complex.create %0, %c0 : complex<f128>
40-
%2 = complex.pow %arg0, %1 : complex<f128>
41-
return %2 : complex<f128>
30+
%0 = complex.powi %arg0, %arg1 : complex<f128>, i32
31+
return %0 : complex<f128>
4232
}
4333

4434
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
45-
%c0 = arith.constant 0.0 : f128
46-
%0 = fir.convert %arg1 : (i64) -> f128
47-
%1 = complex.create %0, %c0 : complex<f128>
48-
%2 = complex.pow %arg0, %1 : complex<f128>
49-
return %2 : complex<f128>
35+
%0 = complex.powi %arg0, %arg1 : complex<f128>, i64
36+
return %0 : complex<f128>
5037
}
5138

5239
func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
@@ -74,26 +61,37 @@ module {
7461
// CHECK-LABEL: func.func @pow_c4_i4(
7562
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
7663
// CHECK-NOT: complex.pow
64+
// CHECK-NOT: complex.powi
65+
66+
// CHECK-LABEL: func.func @pow_c4_i4_fast(
67+
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<fast> : (complex<f32>, i32) -> complex<f32>
68+
// CHECK-NOT: complex.pow
69+
// CHECK-NOT: complex.powi
7770

7871
// CHECK-LABEL: func.func @pow_c4_i8(
7972
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
8073
// CHECK-NOT: complex.pow
74+
// CHECK-NOT: complex.powi
8175

8276
// CHECK-LABEL: func.func @pow_c8_i4(
8377
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
8478
// CHECK-NOT: complex.pow
79+
// CHECK-NOT: complex.powi
8580

8681
// CHECK-LABEL: func.func @pow_c8_i8(
8782
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
8883
// CHECK-NOT: complex.pow
84+
// CHECK-NOT: complex.powi
8985

9086
// CHECK-LABEL: func.func @pow_c16_i4(
9187
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
9288
// CHECK-NOT: complex.pow
89+
// CHECK-NOT: complex.powi
9390

9491
// CHECK-LABEL: func.func @pow_c16_i8(
9592
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
9693
// CHECK-NOT: complex.pow
94+
// CHECK-NOT: complex.powi
9795

9896
// CHECK-LABEL: func.func @pow_c4_fast(
9997
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
@@ -108,4 +106,4 @@ module {
108106
// CHECK-LABEL: func.func @pow_c16_complex(
109107
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
110108
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
111-
// CHECK-NOT: complex.pow
109+
// CHECK-NOT: complex.pow

mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -449,10 +449,13 @@ def PowOp : ComplexArithmeticOp<"pow"> {
449449

450450
def PowiOp : Complex_Op<"powi",
451451
[Pure, Elementwise, SameOperandsAndResultShape,
452-
AllTypesMatch<["lhs", "result"]>]> {
453-
let summary = "complex number raised to integer power";
452+
AllTypesMatch<["lhs", "result"]>,
453+
DeclareOpInterfaceMethods<ArithFastMathInterface>]> {
454+
let summary = "complex number raised to signed integer power";
454455
let description = [{
455-
The `powi` operation takes a complex number and an integer exponent.
456+
The `powi` operation takes a `base` operand of complex type and a `power`
457+
operand of signed integer type and returns one result of the same type
458+
as `base`. The result is `base` raised to the power of `power`.
456459

457460
Example:
458461

@@ -462,11 +465,12 @@ def PowiOp : Complex_Op<"powi",
462465
}];
463466

464467
let arguments = (ins Complex<AnyFloat>:$lhs,
465-
AnySignlessInteger:$rhs);
468+
AnySignlessInteger:$rhs,
469+
OptionalAttr<Arith_FastMathAttr>:$fastmath);
466470
let results = (outs Complex<AnyFloat>:$result);
467471

468472
let assemblyFormat =
469-
"$lhs `,` $rhs attr-dict `:` type($result) `,` type($rhs)";
473+
"$lhs `,` $rhs (`fastmath` `` $fastmath^)? attr-dict `:` type($result) `,` type($rhs)";
470474
}
471475

472476
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ struct PowiOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowiOp> {
100100
loc, op.getLhs().getType(), exponentReal, zeroImag);
101101

102102
rewriter.replaceOpWithNewOp<complex::PowOp>(op, op.getType(), op.getLhs(),
103-
exponent);
103+
exponent, op.getFastmathAttr());
104104
return success();
105105
}
106106
};

mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,25 @@ PowIStrengthReduction<PowIOpTy, DivOpTy, MulOpTy>::matchAndRewrite(
217217
// `[fi]powi(x, negative_exponent)`
218218
// with:
219219
// (1 / x) * (1 / x) * (1 / x) * ...
220+
auto buildMul = [&](Value lhs, Value rhs) {
221+
if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
222+
return rewriter.create<MulOpTy>(loc, op.getType(), lhs, rhs,
223+
op.getFastmathAttr());
224+
else
225+
return MulOpTy::create(rewriter, loc, lhs, rhs);
226+
};
220227
for (unsigned i = 1; i < exponentValue; ++i)
221-
result = MulOpTy::create(rewriter, loc, result, base);
228+
result = buildMul(result, base);
222229

223230
// Inverse the base for negative exponent, i.e. for
224231
// `[fi]powi(x, negative_exponent)` set `x` to `1 / x`.
225-
if (exponentIsNegative)
226-
result = DivOpTy::create(rewriter, loc, bcast(one), result);
232+
if (exponentIsNegative) {
233+
if constexpr (std::is_same_v<PowIOpTy, complex::PowiOp>)
234+
result = rewriter.create<DivOpTy>(loc, op.getType(), bcast(one), result,
235+
op.getFastmathAttr());
236+
else
237+
result = DivOpTy::create(rewriter, loc, bcast(one), result);
238+
}
227239

228240
rewriter.replaceOp(op, result);
229241
return success();

0 commit comments

Comments
 (0)