Skip to content

Commit 3e072be

Browse files
Insertion point should be the dominant op
1 parent c86c241 commit 3e072be

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#include "mlir/Dialect/Math/IR/Math.h"
1010
#include "mlir/Dialect/Math/Transforms/Passes.h"
1111
#include "mlir/IR/PatternMatch.h"
12-
#include "mlir/Pass/Pass.h"
12+
// #include "mlir/Pass/Pass.h"
1313
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1414

1515
using namespace mlir;
@@ -39,10 +39,14 @@ struct SincosFusionPattern : OpRewritePattern<math::SinOp> {
3939
if (!cosOp)
4040
return failure();
4141

42+
Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation()
43+
: cosOp.getOperation();
44+
rewriter.setInsertionPoint(firstOp);
45+
4246
Type elemType = sinOp.getType();
43-
auto sincos = rewriter.create<math::SincosOp>(
44-
sinOp.getLoc(), TypeRange{elemType, elemType}, operand,
45-
sinOp.getFastmathAttr());
47+
auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(),
48+
TypeRange{elemType, elemType}, operand,
49+
sinOp.getFastmathAttr());
4650

4751
rewriter.replaceOp(sinOp, sincos.getSin());
4852
rewriter.replaceOp(cosOp, sincos.getCos());

mlir/test/Dialect/Math/sincos-fusion.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,28 @@ func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
1717
func.return %0, %1, %2, %3 : f32, f32, f32, f32
1818
}
1919

20+
func.func private @sink(%arg0 : f32)
21+
22+
// CHECK: func.func private @sink(f32)
23+
// CHECK-LABEL: func.func @sincos_ensure_ssa_dominance(
24+
// CHECK-SAME: %[[ARG0:.*]]: f32,
25+
// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) {
26+
// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32
27+
// CHECK: call @sink(%[[VAL_0]]) : (f32) -> ()
28+
// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32
29+
// CHECK: call @sink(%[[VAL_3]]) : (f32) -> ()
30+
// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32
31+
// CHECK: }
32+
func.func @sincos_ensure_ssa_dominance(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) {
33+
%0 = math.sin %arg0 : f32
34+
func.call @sink(%0) : (f32) -> ()
35+
%1 = math.cos %arg0 : f32
36+
%2 = math.cos %arg1 : f32
37+
func.call @sink(%2) : (f32) -> ()
38+
%3 = math.sin %arg1 : f32
39+
func.return %0, %1, %2, %3 : f32, f32, f32, f32
40+
}
41+
2042
// CHECK-LABEL: func.func @sincos_fusion_no_match_fmf(
2143
// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) {
2244
// CHECK: %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath<contract> : f32

0 commit comments

Comments
 (0)