Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class GreedyRewriteConfig;
#define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS
#define GEN_PASS_DECL_CONTROLFLOWSINK
#define GEN_PASS_DECL_GENERATERUNTIMEVERIFICATION
#define GEN_PASS_DECL_HOISTPUREOPS
#define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION
#define GEN_PASS_DECL_INLINER
#define GEN_PASS_DECL_MEM2REG
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -636,4 +636,8 @@ def BubbleDownMemorySpaceCasts :
}];
}

def HoistPureOps :
    Pass<"hoist-pure-ops"> {
  let summary = "Hoist pure operations as early as their operands allow";
  let description = [{
    Walks the operations nested under the operation the pass runs on and moves
    every pure, non-terminator operation to the earliest position permitted by
    SSA dominance of the values it uses (its own operands as well as values
    used inside its regions but defined outside of it):

    * if the last such value to become available is an operation result, the
      operation is moved directly after the defining operation;
    * if it is a block argument, the operation is moved to the start of the
      block owning that argument.

    Operations that do not use any value are left in place.
  }];
}

#endif // MLIR_TRANSFORMS_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_mlir_library(MLIRTransforms
SymbolPrivatize.cpp
TopologicalSort.cpp
ViewOpGraph.cpp
HoistPureOps.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
Expand Down
136 changes: 136 additions & 0 deletions mlir/lib/Transforms/HoistPureOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
//===- HoistPureOps.cpp - Hoist Pure ops ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that hoists pure operations based on SSA
// dominance.
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/DebugLog.h"

namespace mlir {
#define GEN_PASS_DEF_HOISTPUREOPS
#include "mlir/Transforms/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "hoist-pure-ops"

using namespace mlir;

namespace {

/// Return whichever of `a` and `b` is dominated by the other, i.e. the value
/// that becomes available later in the IR.
///
/// The values are compared as follows:
///  * two operation results: by dominance of their defining operations;
///  * two block arguments, or values living in different blocks: by dominance
///    of their parent blocks;
///  * a block argument and an operation result of the same block: the
///    operation result is the dominated one, because block arguments are
///    available from the very start of the block.
///
/// When `a` dominates `b` (including the case where both are defined at the
/// same place) `b` is returned, otherwise `a` is returned.
static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
  const bool aIsArgument = isa<BlockArgument>(a);
  const bool bIsArgument = isa<BlockArgument>(b);

  // Neither value is a block argument: compare the defining operations.
  if (!aIsArgument && !bIsArgument)
    return dominanceInfo.dominates(a.getDefiningOp(), b.getDefiningOp()) ? b
                                                                         : a;

  // Two block arguments, or a block argument and an operation result that live
  // in different blocks: dominance between the blocks decides.
  Block *aBlock = a.getParentBlock();
  Block *bBlock = b.getParentBlock();
  if ((aIsArgument && bIsArgument) || aBlock != bBlock)
    return dominanceInfo.dominates(aBlock, bBlock) ? b : a;

  // Exactly one block argument, and both values belong to the same block. No
  // dominance query is needed: the operation result is the dominated value
  // regardless of the order in which the two values were passed in.
  return aIsArgument ? b : a;
}

/// Return true if `block` is nested, at any depth, inside one of the regions
/// of `op`. The check climbs the chain of parent operations starting from the
/// operation that owns `block`.
static bool isOpContainBlock(Operation *op, Block *block) {
  Operation *ancestor = block->getParentOp();
  for (; ancestor != nullptr; ancestor = ancestor->getParentOp())
    if (ancestor == op)
      break;
  // Either `op` was found on the way up, or the top of the chain was reached.
  return ancestor == op;
}

/// Find the hoisting position for the pure op.
///
/// The candidates are all values `op` depends on from the outside: its own
/// operands, plus every operand of an operation nested in the regions of `op`
/// that is defined outside of `op`. The candidates are reduced pairwise with
/// getDomaincedValue, and the surviving value is returned; `op` may be placed
/// right after the point where that value is defined.
///
/// Returns a null Value when `op` does not depend on any outside value.
static Value getDestPos(Operation *op) {
  DominanceInfo dominanceInfo(op);

  SmallVector<Value> candidates(op->getOperands());
  op->walk([&](Operation *nested) {
    // The walk also visits `op` itself; its operands are already recorded, so
    // skip it to avoid collecting and comparing each of them twice.
    if (nested == op)
      return;
    for (Value operand : nested->getOperands())
      if (!isOpContainBlock(op, operand.getParentBlock()))
        candidates.push_back(operand);
  });
  if (candidates.empty())
    return {};

  Value dest = candidates.front();
  for (Value candidate : ArrayRef<Value>(candidates).drop_front())
    dest = getDomaincedValue(dominanceInfo, dest, candidate);
  return dest;
}

/// Hoist a single pure op to the position reported by getDestPos.
///
/// If that position is a block argument, `op` is moved in front of the first
/// operation of the block owning the argument. If it is an operation result,
/// `op` is moved directly after the defining operation. Nothing happens when
/// getDestPos returns a null Value.
static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
  // Debug-printing helper: print an operation without its regions.
  auto brief = [](Operation *toPrint) {
    return OpWithFlags(toPrint, OpPrintingFlags().skipRegions());
  };

  LDBG() << "hoistPureOp: " << brief(op);
  Value pos = getDestPos(op);
  if (!pos)
    return;

  // Destination is a block argument: go to the very start of its block.
  if (auto argument = dyn_cast<BlockArgument>(pos)) {
    Operation *blockFront = &argument.getOwner()->front();
    LDBG() << "move " << brief(op) << " before " << brief(blockFront);
    rewriter.moveOpBefore(op, blockFront);
    return;
  }

  // Destination is an operation result: go right behind its definition.
  Operation *defineOp = pos.getDefiningOp();
  if (op == defineOp)
    return;

  LDBG() << "move " << brief(op) << " after " << brief(defineOp);
  rewriter.moveOpAfter(op, defineOp);
}

/// Pass implementation for `hoist-pure-ops`; the base class is generated from
/// the pass definition through GEN_PASS_DEF_HOISTPUREOPS.
class HoistPureOps : public impl::HoistPureOpsBase<HoistPureOps> {
public:
  void runOnOperation() override;
};
} // namespace

void HoistPureOps::runOnOperation() {
Operation *module = getOperation();
IRRewriter rewriter(module->getContext());
module->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op->hasTrait<mlir::OpTrait::IsTerminator>())
return;
if (isPure(op)) {
hoistPureOp(rewriter, op);
}
});
}
67 changes: 67 additions & 0 deletions mlir/test/Transforms/hoist-pure-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// RUN: mlir-opt %s -hoist-pure-ops -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: func @hoist_cast_pos
// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>,
// CHECK-SAME: %[[ARG1:.*]]: i1
// Both successor blocks cast the function argument %arg. The only operand of
// each cast is an argument of the entry block, so both casts are expected at
// the start of the entry block, ahead of the conditional branch. The cast
// coming from ^bb2 is expected first in the output and the one from ^bb1
// second, as the captured names in the return checks show.
func.func @hoist_cast_pos(%arg: memref<10xf32>, %arg1: i1) -> (memref<?xf32>) {
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ARG0]]
// CHECK: %[[CAST_1:.*]] = memref.cast %[[ARG0]]
// CHECK-NEXT: cf.cond_br %[[ARG1]]
cf.cond_br %arg1, ^bb1, ^bb2
^bb1:
%cast = memref.cast %arg : memref<10xf32> to memref<?xf32>
// CHECK: return %[[CAST_1]]
return %cast : memref<?xf32>
^bb2:
%cast1 = memref.cast %arg : memref<10xf32> to memref<?xf32>
// CHECK: return %[[CAST_0]]
return %cast1 : memref<?xf32>
}

// -----

// CHECK-LABEL: func.func @hoist_cast_pos_alloc
// CHECK-SAME: %[[ARG0:.*]]: i1
// Same shape as the previous test, but the cast operand is the result of
// memref.alloc instead of a block argument. Both casts are expected after the
// allocation and ahead of the conditional branch in the entry block; the cast
// coming from ^bb2 is expected first and the one from ^bb1 second.
func.func @hoist_cast_pos_alloc(%arg: i1) -> (memref<?xf32>) {
// CHECK: %[[ALLOC_0:.*]] = memref.alloc()
// CHECK: %[[CAST_0:.*]] = memref.cast %[[ALLOC_0]]
// CHECK: %[[CAST_1:.*]] = memref.cast %[[ALLOC_0]]
// CHECK-NEXT: cf.cond_br %[[ARG0]]
%alloc = memref.alloc() : memref<10xf32>
cf.cond_br %arg, ^bb1, ^bb2
^bb1:
%cast = memref.cast %alloc : memref<10xf32> to memref<?xf32>
// CHECK: return %[[CAST_1]]
return %cast : memref<?xf32>
^bb2:
%cast1 = memref.cast %alloc : memref<10xf32> to memref<?xf32>
// CHECK: return %[[CAST_0]]
return %cast1 : memref<?xf32>
}

// -----

// CHECK-LABEL: func @mult_scf_sum(
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
// Three nested scf.for loops with all additions written in the innermost
// body. %add0 only uses %iv0 and %iv1, so it is expected to be hoisted into
// the body of the middle loop, directly in front of the innermost loop.
// %add1 and %add2 use %iv2 and %sum2 respectively and are expected to remain
// inside the innermost loop. Note that the innermost loop yields %add1;
// %add2 has no users.
func.func @mult_scf_sum(%arg0: index, %arg1: index, %arg2: index) -> index {
%c0 = arith.constant 0 : index
%res0 = scf.for %iv0 = %arg0 to %arg1 step %arg2 iter_args(%sum0 = %c0) -> index {
%res1 = scf.for %iv1 = %arg0 to %arg1 step %arg2 iter_args(%sum1 = %sum0) -> index {
%res2 = scf.for %iv2 = %arg0 to %arg1 step %arg2 iter_args(%sum2 = %sum1) -> index {
%add0 = arith.addi %iv0, %iv1 : index
%add1 = arith.addi %add0, %iv2 : index
%add2 = arith.addi %add1, %sum2 : index
scf.yield %add1 : index
}
scf.yield %res2 : index
}
scf.yield %res1 : index
}
// CHECK: %[[FOR_0:.*]] = scf.for %[[IV_0:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
// CHECK-NEXT: %[[FOR_1:.*]] = scf.for %[[IV_1:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]]
// CHECK-NEXT: %[[ADDI_0:.*]] = arith.addi %[[IV_0]], %[[IV_1]] : index
// CHECK-NEXT: %[[FOR_2:.*]] = scf.for %[[IV_3:.*]] = %[[ARG0]] to %[[ARG1]] step %[[ARG2]] iter_args(%[[ITER:.*]] = %{{.*}})
// CHECK-NEXT: %[[ADDI_1:.*]] = arith.addi %[[ADDI_0]], %[[IV_3]] : index
// CHECK-NEXT: %[[ADDI_2:.*]] = arith.addi %[[ADDI_1]], %[[ITER]] : index
return %res0 : index
}