Skip to content

Commit 2632942

Browse files
committed
Propagate fastmath in ComplexToROCDL.
Fix Targetmachine build error.
1 parent 9528edd commit 2632942

File tree

3 files changed

+76
-64
lines changed

3 files changed

+76
-64
lines changed

flang/lib/Frontend/FrontendActions.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,6 @@ void CodeGenAction::generateLLVMIR() {
720720
const CodeGenOptions &opts = invoc.getCodeGenOpts();
721721
const auto &mathOpts = invoc.getLoweringOpts().getMathOptions();
722722
llvm::OptimizationLevel level = mapToLevel(opts);
723-
const llvm::TargetMachine &targetMachine = ci.getTargetMachine();
724723
mlir::DefaultTimingManager &timingMgr = ci.getTimingManager();
725724
mlir::TimingScope &timingScopeRoot = ci.getTimingScopeRoot();
726725

@@ -739,7 +738,8 @@ void CodeGenAction::generateLLVMIR() {
739738
pm.enableVerifier(/*verifyPasses=*/true);
740739

741740
MLIRToLLVMPassPipelineConfig config(level, opts, mathOpts);
742-
config.SkipConvertComplexPow = targetMachine.getTargetTriple().isAMDGCN();
741+
llvm::Triple pipelineTriple(invoc.getTargetOpts().triple);
742+
config.SkipConvertComplexPow = pipelineTriple.isAMDGCN();
743743
fir::registerDefaultInlinerPass(config);
744744

745745
if (auto vsr = getVScaleRange(ci)) {

flang/test/Transforms/convert-complex-pow.fir

Lines changed: 68 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -2,101 +2,110 @@
22

33
module {
44
func.func @pow_c4_i4(%arg0: complex<f32>, %arg1: i32) -> complex<f32> {
5-
%c0 = arith.constant 0.000000e+00 : f32
6-
%c1 = fir.convert %arg1 : (i32) -> f32
7-
%c2 = complex.create %c1, %c0 : complex<f32>
8-
%0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f32>
9-
return %0 : complex<f32>
5+
%c0 = arith.constant 0.0 : f32
6+
%0 = fir.convert %arg1 : (i32) -> f32
7+
%1 = complex.create %0, %c0 : complex<f32>
8+
%2 = complex.pow %arg0, %1 : complex<f32>
9+
return %2 : complex<f32>
1010
}
1111

1212
func.func @pow_c4_i8(%arg0: complex<f32>, %arg1: i64) -> complex<f32> {
13-
%c0 = arith.constant 0.000000e+00 : f32
14-
%c1 = fir.convert %arg1 : (i64) -> f32
15-
%c2 = complex.create %c1, %c0 : complex<f32>
16-
%0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f32>
17-
return %0 : complex<f32>
18-
}
19-
20-
func.func @pow_c4_c4(%arg0: complex<f32>, %arg1: complex<f32>) -> complex<f32> {
21-
%0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f32>
22-
return %0 : complex<f32>
13+
%c0 = arith.constant 0.0 : f32
14+
%0 = fir.convert %arg1 : (i64) -> f32
15+
%1 = complex.create %0, %c0 : complex<f32>
16+
%2 = complex.pow %arg0, %1 : complex<f32>
17+
return %2 : complex<f32>
2318
}
2419

2520
func.func @pow_c8_i4(%arg0: complex<f64>, %arg1: i32) -> complex<f64> {
26-
%c0 = arith.constant 0.000000e+00 : f64
27-
%c1 = fir.convert %arg1 : (i32) -> f64
28-
%c2 = complex.create %c1, %c0 : complex<f64>
29-
%0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f64>
30-
return %0 : complex<f64>
21+
%c0 = arith.constant 0.0 : f64
22+
%0 = fir.convert %arg1 : (i32) -> f64
23+
%1 = complex.create %0, %c0 : complex<f64>
24+
%2 = complex.pow %arg0, %1 : complex<f64>
25+
return %2 : complex<f64>
3126
}
3227

3328
func.func @pow_c8_i8(%arg0: complex<f64>, %arg1: i64) -> complex<f64> {
34-
%c0 = arith.constant 0.000000e+00 : f64
35-
%c1 = fir.convert %arg1 : (i64) -> f64
36-
%c2 = complex.create %c1, %c0 : complex<f64>
37-
%0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f64>
38-
return %0 : complex<f64>
39-
}
40-
41-
func.func @pow_c8_c8(%arg0: complex<f64>, %arg1: complex<f64>) -> complex<f64> {
42-
%0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f64>
43-
return %0 : complex<f64>
29+
%c0 = arith.constant 0.0 : f64
30+
%0 = fir.convert %arg1 : (i64) -> f64
31+
%1 = complex.create %0, %c0 : complex<f64>
32+
%2 = complex.pow %arg0, %1 : complex<f64>
33+
return %2 : complex<f64>
4434
}
4535

4636
func.func @pow_c16_i4(%arg0: complex<f128>, %arg1: i32) -> complex<f128> {
47-
%c0 = arith.constant 0.000000e+00 : f128
48-
%c1 = fir.convert %arg1 : (i32) -> f128
49-
%c2 = complex.create %c1, %c0 : complex<f128>
50-
%0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f128>
51-
return %0 : complex<f128>
37+
%c0 = arith.constant 0.0 : f128
38+
%0 = fir.convert %arg1 : (i32) -> f128
39+
%1 = complex.create %0, %c0 : complex<f128>
40+
%2 = complex.pow %arg0, %1 : complex<f128>
41+
return %2 : complex<f128>
5242
}
5343

5444
func.func @pow_c16_i8(%arg0: complex<f128>, %arg1: i64) -> complex<f128> {
55-
%c0 = arith.constant 0.000000e+00 : f128
56-
%c1 = fir.convert %arg1 : (i64) -> f128
57-
%c2 = complex.create %c1, %c0 : complex<f128>
58-
%0 = complex.pow %arg0, %c2 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f128>
59-
return %0 : complex<f128>
45+
%c0 = arith.constant 0.0 : f128
46+
%0 = fir.convert %arg1 : (i64) -> f128
47+
%1 = complex.create %0, %c0 : complex<f128>
48+
%2 = complex.pow %arg0, %1 : complex<f128>
49+
return %2 : complex<f128>
50+
}
51+
52+
func.func @pow_c4_fast(%arg0: complex<f32>, %arg1: f32) -> complex<f32> {
53+
%c1 = arith.constant 1.0 : f32
54+
%0 = complex.create %arg1, %c1 : complex<f32>
55+
%1 = complex.pow %arg0, %0 fastmath<fast> : complex<f32>
56+
return %1 : complex<f32>
57+
}
58+
59+
func.func @pow_c8_complex(%arg0: complex<f64>, %arg1: f64) -> complex<f64> {
60+
%c2 = arith.constant 2.0 : f64
61+
%0 = complex.create %arg1, %c2 : complex<f64>
62+
%1 = complex.pow %arg0, %0 : complex<f64>
63+
return %1 : complex<f64>
6064
}
6165

62-
func.func @pow_c16_c16(%arg0: complex<f128>, %arg1: complex<f128>) -> complex<f128> {
63-
%0 = complex.pow %arg0, %arg1 {fastmath = #arith.fastmath<reassoc, contract, nsz>} : complex<f128>
64-
return %0 : complex<f128>
66+
func.func @pow_c16_complex(%arg0: complex<f128>, %arg1: f128) -> complex<f128> {
67+
%c3 = arith.constant 3.0 : f128
68+
%0 = complex.create %arg1, %c3 : complex<f128>
69+
%1 = complex.pow %arg0, %0 : complex<f128>
70+
return %1 : complex<f128>
6571
}
6672
}
6773

6874
// CHECK-LABEL: func.func @pow_c4_i4(
69-
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f32>, i32) -> complex<f32>
75+
// CHECK: fir.call @_FortranAcpowi(%{{.*}}, %{{.*}}) : (complex<f32>, i32) -> complex<f32>
7076
// CHECK-NOT: complex.pow
7177

7278
// CHECK-LABEL: func.func @pow_c4_i8(
73-
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f32>, i64) -> complex<f32>
74-
// CHECK-NOT: complex.pow
75-
76-
// CHECK-LABEL: func.func @pow_c4_c4(
77-
// CHECK: fir.call @cpowf(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f32>, complex<f32>) -> complex<f32>
79+
// CHECK: fir.call @_FortranAcpowk(%{{.*}}, %{{.*}}) : (complex<f32>, i64) -> complex<f32>
7880
// CHECK-NOT: complex.pow
7981

8082
// CHECK-LABEL: func.func @pow_c8_i4(
81-
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f64>, i32) -> complex<f64>
83+
// CHECK: fir.call @_FortranAzpowi(%{{.*}}, %{{.*}}) : (complex<f64>, i32) -> complex<f64>
8284
// CHECK-NOT: complex.pow
8385

8486
// CHECK-LABEL: func.func @pow_c8_i8(
85-
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f64>, i64) -> complex<f64>
86-
// CHECK-NOT: complex.pow
87-
88-
// CHECK-LABEL: func.func @pow_c8_c8(
89-
// CHECK: fir.call @cpow(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f64>, complex<f64>) -> complex<f64>
87+
// CHECK: fir.call @_FortranAzpowk(%{{.*}}, %{{.*}}) : (complex<f64>, i64) -> complex<f64>
9088
// CHECK-NOT: complex.pow
9189

9290
// CHECK-LABEL: func.func @pow_c16_i4(
93-
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f128>, i32) -> complex<f128>
91+
// CHECK: fir.call @_FortranAcqpowi(%{{.*}}, %{{.*}}) : (complex<f128>, i32) -> complex<f128>
9492
// CHECK-NOT: complex.pow
9593

9694
// CHECK-LABEL: func.func @pow_c16_i8(
97-
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f128>, i64) -> complex<f128>
95+
// CHECK: fir.call @_FortranAcqpowk(%{{.*}}, %{{.*}}) : (complex<f128>, i64) -> complex<f128>
9896
// CHECK-NOT: complex.pow
9997

100-
// CHECK-LABEL: func.func @pow_c16_c16(
101-
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %{{.*}}) fastmath<reassoc,nsz,contract> : (complex<f128>, complex<f128>) -> complex<f128>
98+
// CHECK-LABEL: func.func @pow_c4_fast(
99+
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f32>
100+
// CHECK: fir.call @cpowf(%{{.*}}, %[[EXP]]) fastmath<fast> : (complex<f32>, complex<f32>) -> complex<f32>
102101
// CHECK-NOT: complex.pow
102+
103+
// CHECK-LABEL: func.func @pow_c8_complex(
104+
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f64>
105+
// CHECK: fir.call @cpow(%{{.*}}, %[[EXP]]) : (complex<f64>, complex<f64>) -> complex<f64>
106+
// CHECK-NOT: complex.pow
107+
108+
// CHECK-LABEL: func.func @pow_c16_complex(
109+
// CHECK: %[[EXP:.*]] = complex.create %{{.*}}, %{{.*}} : complex<f128>
110+
// CHECK: fir.call @_FortranACPowF128(%{{.*}}, %[[EXP]]) : (complex<f128>, complex<f128>) -> complex<f128>
111+
// CHECK-NOT: complex.pow

mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,12 @@ struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> {
6464
LogicalResult matchAndRewrite(complex::PowOp op,
6565
PatternRewriter &rewriter) const final {
6666
Location loc = op.getLoc();
67-
Value logBase = complex::LogOp::create(rewriter, loc, op.getLhs());
68-
Value mul = complex::MulOp::create(rewriter, loc, op.getRhs(), logBase);
69-
Value exp = complex::ExpOp::create(rewriter, loc, mul);
67+
auto fastmath = op.getFastmathAttr();
68+
Value logBase =
69+
complex::LogOp::create(rewriter, loc, op.getLhs(), fastmath);
70+
Value mul =
71+
complex::MulOp::create(rewriter, loc, op.getRhs(), logBase, fastmath);
72+
Value exp = complex::ExpOp::create(rewriter, loc, mul, fastmath);
7073
rewriter.replaceOp(op, exp);
7174
return success();
7275
}

0 commit comments

Comments
 (0)