Skip to content

Commit 397448a

Browse files
Handle boundary condition.
1 parent ae17efd commit 397448a

File tree

2 files changed

+91
-6
lines changed

2 files changed

+91
-6
lines changed

mlir/lib/Transforms/HoistPureOps.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
#include "mlir/IR/Operation.h"
1717
#include "mlir/Pass/Pass.h"
1818
#include "mlir/Transforms/Passes.h"
19+
#include "llvm/Support/DebugLog.h"
1920

2021
namespace mlir {
2122
#define GEN_PASS_DEF_HOISTPUREOPS
2223
#include "mlir/Transforms/Passes.h.inc"
2324
} // namespace mlir
2425

26+
#define DEBUG_TYPE "hoist-pure-ops"
27+
2528
using namespace mlir;
2629

2730
namespace {
@@ -30,13 +33,22 @@ namespace {
3033
static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
3134
Block *aB = a.getParentBlock();
3235
Block *bB = b.getParentBlock();
33-
if (isa_and_present<BlockArgument>(a) && isa_and_present<BlockArgument>(b)) {
34-
return dominanceInfo.dominates(aB, bB) ? b : a;
35-
} else if (isa_and_present<BlockArgument>(a) ||
36-
isa_and_present<BlockArgument>(b)) {
37-
if (aB == bB)
38-
return b;
36+
if (isa<BlockArgument>(a) && isa<BlockArgument>(b)) {
3937
return dominanceInfo.dominates(aB, bB) ? b : a;
38+
} else if (isa<BlockArgument>(a) || isa<BlockArgument>(b)) {
39+
if (aB != bB)
40+
return dominanceInfo.dominates(aB, bB) ? b : a;
41+
if (auto aArg = dyn_cast<BlockArgument>(a)) {
42+
Operation *aFrontOp = &aArg.getOwner()->front();
43+
if (aFrontOp == b.getDefiningOp())
44+
return b;
45+
return dominanceInfo.dominates(aFrontOp, b.getDefiningOp()) ? b : a;
46+
}
47+
auto bArg = cast<BlockArgument>(b);
48+
Operation *bFrontOp = &bArg.getOwner()->front();
49+
if (bFrontOp == a.getDefiningOp())
50+
return a;
51+
return dominanceInfo.dominates(a.getDefiningOp(), bFrontOp) ? b : a;
4052
} else {
4153
Operation *aDefineOp = a.getDefiningOp();
4254
Operation *bDefineOp = b.getDefiningOp();
@@ -64,10 +76,16 @@ static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
6476
return;
6577

6678
if (Operation *defineOp = pos.getDefiningOp()) {
79+
LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
80+
<< " after " << OpWithFlags(op, OpPrintingFlags().skipRegions());
6781
rewriter.moveOpAfter(op, defineOp);
6882
return;
6983
}
7084
auto argument = cast<BlockArgument>(pos);
85+
LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
86+
<< " before "
87+
<< OpWithFlags(&argument.getOwner()->front(),
88+
OpPrintingFlags().skipRegions());
7189
rewriter.moveOpBefore(op, &argument.getOwner()->front());
7290
}
7391

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: mlir-opt %s -hoist-pure-ops -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
// Test: both successors of the conditional branch cast the same function
// argument. The CHECK lines below require both memref.cast ops to be printed
// ahead of the cf.cond_br (the second cast immediately before it), i.e. no
// longer inside ^bb1/^bb2. The return in ^bb1 is matched against the second
// printed cast and the return in ^bb2 against the first.
// CHECK-LABEL: func @hoist_cast_pos
4+
// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>,
5+
// CHECK-SAME: %[[ARG1:.*]]: i1
6+
func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
7+
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
8+
// CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
9+
// CHECK-NEXT: cf.cond_br %[[ARG1]]
10+
cf.cond_br %arg1, ^bb1, ^bb2
11+
^bb1:
12+
%cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
13+
// CHECK: return %[[CAST_1]]
14+
return %cast : memref<?xf32>
15+
^bb2:
16+
%cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
17+
// CHECK: return %[[CAST_0]]
18+
return %cast1 : memref<?xf32>
19+
}
20+
21+
// -----
22+
23+
// Test: same shape as a cast-of-argument case, but the cast operand is the
// result of a memref.alloc in the entry block rather than a block argument.
// The CHECK lines require the alloc to be printed first, then both
// memref.cast ops of that alloc, with the second cast immediately followed by
// the cf.cond_br. The return in ^bb1 is matched against the second printed
// cast and the return in ^bb2 against the first.
// CHECK-LABEL: func.func @hoist_cast_pos_alloc
24+
// CHECK-SAME: %[[ARG0:.*]]: i1
25+
func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
26+
// CHECK: %[[ALLOC_0:.*]] = memref.alloc()
27+
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
28+
// CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
29+
// CHECK-NEXT: cf.cond_br %[[ARG0]]
30+
%alloc = memref.alloc() : memref<10xf32>
31+
cf.cond_br %arg, ^bb1, ^bb2
32+
^bb1:
33+
%cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
34+
// CHECK: return %[[CAST_1]]
35+
return %cast : memref<?xf32>
36+
^bb2:
37+
%cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
38+
// CHECK: return %[[CAST_0]]
39+
return %cast1 : memref<?xf32>
40+
}
41+
42+
// -----
43+
44+
// Test: three nested scf.for loops with all three arith.addi ops written in
// the innermost body. %add0 uses only %iv0 and %iv1; the CHECK-NEXT lines
// require it to be printed directly after the second scf.for header and
// directly before the third scf.for header. %add1 (uses %iv2) and %add2 (uses
// the innermost iter_arg) are required to follow the third scf.for header.
// NOTE(review): %add2 is computed but the innermost loop yields %add1, so
// %add2 has no users in this input -- confirm that is intentional.
// CHECK-LABEL: func @mult_scf_sum(
45+
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
46+
func.func @mult_scf_sum(%arg0: index, %arg1: index, %arg2: index) -> index {
47+
%c0 = arith.constant 0 : index
48+
%res0 = scf.for %iv0 = %arg0 to %arg1 step %arg2 iter_args(%sum0 = %c0) -> index {
49+
%res1 = scf.for %iv1 = %arg0 to %arg1 step %arg2 iter_args(%sum1 = %sum0) -> index {
50+
%res2 = scf.for %iv2 = %arg0 to %arg1 step %arg2 iter_args(%sum2 = %sum1) -> index {
51+
%add0 = arith.addi %iv0, %iv1 : index
52+
%add1 = arith.addi %add0, %iv2 : index
53+
%add2 = arith.addi %add1, %sum2 : index
54+
scf.yield %add1 : index
55+
}
56+
scf.yield %res2 : index
57+
}
58+
scf.yield %res1 : index
59+
}
60+
// CHECK: %[[FOR_0:.*]] = scf.for %[[IV_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
61+
// CHECK-NEXT: %[[FOR_1:.*]] = scf.for %[[IV_1:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
62+
// CHECK-NEXT: %[[ADDI_0:.*]] = arith.addi %[[IV_0]], %[[IV_1]] : index
63+
// CHECK-NEXT: %[[FOR_2:.*]] = scf.for %[[IV_3:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] iter_args(%[[ITER:.*]] = %{{.*}})
64+
// CHECK-NEXT: %[[ADDI_1:.*]] = arith.addi %[[ADDI_0]], %[[IV_3]] : index
65+
// CHECK-NEXT: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[ITER]] : index
66+
return %res0 : index
67+
}

0 commit comments

Comments
 (0)