Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def SCF_Dialect : Dialect {
and then lowered to some final target like LLVM or SPIR-V.
}];

let dependentDialects = ["arith::ArithDialect"];
let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
}

// Base class for SCF dialect ops.
Expand Down Expand Up @@ -138,7 +138,9 @@ def ForOp : SCF_Op<"for",
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getLoopInductionVars", "getLoopLowerBounds", "getLoopSteps",
"getLoopUpperBounds", "getYieldedValuesMutable",
"moveOutOfLoopWithGuard",
"promoteIfSingleIteration", "replaceWithAdditionalYields",
"wrapInTripCountCheck", "unwrapTripCountCheck",
"yieldTiledValuesAndReplace"]>,
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
ConditionallySpeculatable,
Expand Down Expand Up @@ -302,7 +304,7 @@ def ForallOp : SCF_Op<"forall", [
AttrSizedOperandSegments,
AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
["getInitsMutable", "getRegionIterArgs", "getLoopInductionVars",
"getLoopLowerBounds", "getLoopUpperBounds", "getLoopSteps",
"promoteIfSingleIteration", "yieldTiledValuesAndReplace"]>,
RecursiveMemoryEffects,
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Interfaces/LoopLikeInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
/*methodBody=*/"",
/*defaultImplementation=*/"op->moveBefore($_op);"
>,
InterfaceMethod<[{
Moves the given loop-invariant operation out of the loop with a
trip-count guard.
}],
/*retTy=*/"void",
/*methodName=*/"moveOutOfLoopWithGuard",
/*args=*/(ins "::mlir::Operation *":$op),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return;
}]
>,
InterfaceMethod<[{
Promotes the loop body to its containing block if the loop is known to
have a single iteration. Returns "success" if the promotion was
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,10 @@ bool wouldOpBeTriviallyDead(Operation *op);
/// conditions are satisfied.
bool isMemoryEffectFree(Operation *op);

/// Returns true if the given operation is free of memory effects, or if its
/// only memory effects (including those of nested operations) are reads.
bool isMemoryEffectFreeOrOnlyRead(Operation *op);

/// Returns the side effects of an operation. If the operation has
/// RecursiveMemoryEffects, include all side effects of child operations.
///
Expand Down
16 changes: 10 additions & 6 deletions mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,19 @@ class Value;
/// }
/// }
/// ```
///
/// Users must supply three callbacks.
/// Users must supply four callbacks.
///
/// - `isDefinedOutsideRegion` returns true if the given value is invariant with
/// respect to the given region. A common implementation might be:
/// `value.getParentRegion()->isProperAncestor(region)`.
/// - `shouldMoveOutOfRegion` returns true if the provided operation can be
/// moved of the given region, e.g. if it is side-effect free.
/// - `moveOutOfRegion` moves the operation out of the given region. A common
/// implementation might be: `op->moveBefore(region->getParentOp())`.
/// moved out of the given region, e.g. if it is side-effect free or has
/// only read side effects.
/// - `moveOutOfRegionWithoutGuard` moves the operation out of the given region
/// without a guard. A common implementation might be:
/// `op->moveBefore(region->getParentOp())`.
/// - `moveOutOfRegionWithGuard` moves the operation out of the given region
/// with a guard.
///
/// An operation is moved if all of its operands satisfy
/// `isDefinedOutsideRegion` and it satisfies `shouldMoveOutOfRegion`.
Expand All @@ -66,7 +69,8 @@ size_t moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion);
function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
function_ref<void(Operation *)> moveOutOfRegionWithGuard);

/// Move side-effect free loop invariant code out of a loop-like op using
/// methods provided by the interface.
Expand Down
40 changes: 37 additions & 3 deletions mlir/lib/Dialect/SCF/IR/SCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
Expand Down Expand Up @@ -395,6 +396,40 @@ std::optional<SmallVector<OpFoldResult>> ForOp::getLoopUpperBounds() {

std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }

/// Moves the op out of the loop with a guard that checks if the loop has at
/// least one iteration.
///
/// The op is wrapped in an `scf.if` whose condition compares the loop bounds;
/// the "then" branch executes the op and yields its results, while the "else"
/// branch yields `ub.poison` values of matching types (the op never runs on
/// that path, so its results must not be observed). All external uses of the
/// op's results are rerouted to the results of the guarding `scf.if`.
void ForOp::moveOutOfLoopWithGuard(Operation *op) {
  IRRewriter rewriter(this->getContext());
  OpBuilder::InsertionGuard insertGuard(rewriter);
  rewriter.setInsertionPoint(this->getOperation());
  Location loc = this->getLoc();
  // Guard condition: lower bound < upper bound, i.e. at least one iteration.
  // NOTE(review): `ult` treats the bounds as unsigned; confirm this matches
  // the trip-count comparison semantics intended for scf.for bounds.
  arith::CmpIOp cmpIOp = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::ult, this->getLowerBound(),
      this->getUpperBound());
  // Create the trip-count check.
  scf::YieldOp thenYield;
  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
      loc, cmpIOp,
      [&](OpBuilder &builder, Location loc) {
        // The hoisted op's results are yielded from the "then" branch.
        thenYield = builder.create<scf::YieldOp>(loc, op->getResults());
      },
      [&](OpBuilder &builder, Location loc) {
        // Create the poison ops with the callback's `builder` (not the outer
        // rewriter) so they are placed at the callback's insertion point
        // inside the "else" region.
        SmallVector<Value> poisonResults;
        poisonResults.reserve(op->getNumResults());
        for (Type type : op->getResultTypes()) {
          ub::PoisonOp poisonOp =
              builder.create<ub::PoisonOp>(loc, type, nullptr);
          poisonResults.push_back(poisonOp);
        }
        builder.create<scf::YieldOp>(loc, poisonResults);
      });
  // Reroute all uses of the op's results — except the new then-yield that
  // forwards them — to the corresponding results of the guarding scf.if.
  for (auto [opResult, ifOpResult] :
       llvm::zip(op->getResults(), ifOp.getResults()))
    rewriter.replaceAllUsesExcept(opResult, ifOpResult, thenYield);
  // Move the op into the then block, ahead of its yield.
  rewriter.moveOpBefore(op, thenYield);
}

/// Promotes the loop body of a forOp to its containing block if the forOp
/// it can be determined that the loop has a single iteration.
LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
Expand Down Expand Up @@ -3397,9 +3432,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {

if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
<< "expected as many input types as operands "
<< "(expected " << operands.size() << " got "
<< functionType.getNumInputs() << ")";
<< "expected as many input types as operands " << "(expected "
<< operands.size() << " got " << functionType.getNumInputs() << ")";
}

// Resolve input operands.
Expand Down
10 changes: 10 additions & 0 deletions mlir/lib/Interfaces/SideEffectInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,16 @@ mlir::getEffectsRecursively(Operation *rootOp) {
return effects;
}

/// An operation qualifies when its recursively-gathered memory effects are
/// either absent entirely or consist solely of reads. Operations whose
/// effects cannot be determined (no effect interface and no recursive
/// effects) are conservatively rejected.
bool mlir::isMemoryEffectFreeOrOnlyRead(Operation *op) {
  std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
      getEffectsRecursively(op);
  // Unknown effects: be conservative and report "not read-only".
  if (!effects)
    return false;
  // Reject as soon as any non-read effect is found.
  for (const MemoryEffects::EffectInstance &instance : *effects)
    if (!isa<MemoryEffects::Read>(instance.getEffect()))
      return false;
  return true;
}

bool mlir::isSpeculatable(Operation *op) {
auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
if (!conditionallySpeculatable)
Expand Down
34 changes: 27 additions & 7 deletions mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,18 @@ size_t mlir::moveLoopInvariantCode(
ArrayRef<Region *> regions,
function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
function_ref<void(Operation *, Region *)> moveOutOfRegion) {
function_ref<void(Operation *)> moveOutOfRegionWithoutGuard,
function_ref<void(Operation *)> moveOutOfRegionWithGuard) {
size_t numMoved = 0;

for (Region *region : regions) {
LLVM_DEBUG(llvm::dbgs() << "Original loop:\n"
<< *region->getParentOp() << "\n");

bool anyOpHoistedWithGuard = false;
bool loopSideEffectFreeOrHasOnlyReadSideEffect =
isMemoryEffectFreeOrOnlyRead(region->getParentOp());

std::queue<Operation *> worklist;
// Add top-level operations in the loop body to the worklist.
for (Operation &op : region->getOps())
Expand All @@ -84,12 +89,26 @@ size_t mlir::moveLoopInvariantCode(
continue;

LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n");

if (!shouldMoveOutOfRegion(op, region) ||
!canBeHoisted(op, definedOutside))
continue;
// Can only hoist pure ops (side-effect free) when there is an op with
// write and/or unknown side effects in the loop.
if (!loopSideEffectFreeOrHasOnlyReadSideEffect && !isMemoryEffectFree(op))
continue;

LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n");
moveOutOfRegion(op, region);
bool moveWithoutGuard = !anyOpHoistedWithGuard && isMemoryEffectFree(op);
if (moveWithoutGuard) {
LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op
<< " without guard\n");
moveOutOfRegionWithoutGuard(op);
} else {
LLVM_DEBUG(llvm::dbgs()
<< "Moving loop-invariant op: " << *op << " with guard\n");
moveOutOfRegionWithGuard(op);
anyOpHoistedWithGuard = true;
}
++numMoved;

// Since the op has been moved, we need to check its users within the
Expand All @@ -106,13 +125,14 @@ size_t mlir::moveLoopInvariantCode(
size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
  // A value is invariant iff it is not defined inside any of the loop's
  // regions.
  auto definedOutside = [&](Value value, Region *region) {
    return !region->isAncestor(value.getParentRegion());
  };
  // Only speculatable ops that are pure or read-only are candidates.
  auto canHoist = [&](Operation *op, Region *) {
    return isSpeculatable(op) && isMemoryEffectFreeOrOnlyRead(op);
  };
  // Unguarded hoisting moves the op directly before the loop.
  auto hoist = [&](Operation *op) { loopLike.moveOutOfLoop(op); };
  // Guarded hoisting wraps the op in a trip-count check.
  auto hoistGuarded = [&](Operation *op) {
    loopLike.moveOutOfLoopWithGuard(op);
  };
  return moveLoopInvariantCode(loopLike.getLoopRegions(), definedOutside,
                               canHoist, hoist, hoistGuarded);
}

namespace {
Expand Down
144 changes: 144 additions & 0 deletions mlir/test/Transforms/loop-invariant-code-motion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,150 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
return
}

// A speculatable, read-only op is hoisted out of the loop, wrapped in an
// scf.if trip-count guard that yields ub.poison on the zero-trip path.
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: test.always_speculatable_op
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
// CHECK-NEXT: scf.if %[[CMP]]
// CHECK-NEXT: test.speculatable_op_with_memread
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
// CHECK-NOT: test.always_speculatable_op
// CHECK-NOT: test.speculatable_op_with_memread
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%i_cast = arith.index_cast %i: index to i32
%add = arith.addi %acc, %i_cast : i32
%sum = arith.addi %add, %only_read : i32
scf.yield %sum : i32
}
return %sum_result : i32
}

// Same as above, but the hoisted read-only op has two results: the guard's
// else branch must yield one ub.poison per result type (i32 and f32).
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_multiple_result_success
func.func @test_speculatable_op_with_read_side_effect_multiple_result_success(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: test.always_speculatable_op
// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
// CHECK-NEXT: scf.if %[[CMP]]
// CHECK-NEXT: test.speculatable_op_with_memread
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK-NEXT: ub.poison : f32
// CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
// CHECK-NOT: test.always_speculatable_op
// CHECK-NOT: test.speculatable_op_with_memread
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read:2 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> (i32, f32)
%i_cast = arith.index_cast %i: index to i32
%add = arith.addi %acc, %i_cast : i32
%sum = arith.addi %add, %only_read#0 : i32
scf.yield %sum : i32
}
return %sum_result : i32
}

// A chain of hoistable ops where later ops consume earlier hoisted results:
// each hoisted op gets its own scf.if guard, and dependents read the guard's
// results rather than the op's results directly.
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: %[[ALWAYS:.*]] = "test.always_speculatable_op"
// CHECK-NEXT: %[[CMP0:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
// CHECK-NEXT: %[[IF0:.*]] = scf.if %[[CMP0]]
// CHECK-NEXT: test.speculatable_op_with_memread
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: %[[CMP1:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
// CHECK-NEXT: %[[IF1:.*]] = scf.if %[[CMP1]]
// CHECK-NEXT: arith.addi %[[ALWAYS]], %[[IF0]]
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: %[[CMP2:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
// CHECK-NEXT: %[[IF2:.*]] = scf.if %[[CMP2]]
// CHECK-NEXT: test.speculatable_op_with_memread
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: %[[CMP3:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
// CHECK-NEXT: %{{.*}} = scf.if %[[CMP3]]
// CHECK-NEXT: arith.addi %[[IF1]], %[[IF2]]
// CHECK: else
// CHECK-NEXT: ub.poison : i32
// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
// CHECK-NOT: test.always_speculatable_op
// CHECK-NOT: test.speculatable_op_with_memread
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read_0 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%add_0 = arith.addi %always_speculate, %only_read_0 : i32
%only_read_1 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%add_1 = arith.addi %add_0, %only_read_1 : i32
%i_cast = arith.index_cast %i: index to i32
%sum = arith.addi %add_1, %i_cast : i32
scf.yield %sum : i32
}
return %sum_result : i32
}

// Negative test: the loop body also contains a write op, so the read-only op
// must stay inside the loop (only the pure op is hoisted, unguarded).
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: test.always_speculatable_op
// CHECK-NEXT: scf.for
// CHECK-NOT: test.always_speculatable_op
// CHECK: test.speculatable_op_with_memread
// CHECK: test.speculatable_op_with_memwrite
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%i_cast = arith.index_cast %i: index to i32
%add = arith.addi %acc, %i_cast : i32
%sum = arith.addi %add, %only_read : i32
%write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
scf.yield %sum : i32
}
return %sum_result : i32
}

// Negative test: a write op nested inside an inner loop/if still counts as a
// write effect of the outer loop (effects are gathered recursively), so the
// read-only op is not hoisted.
// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
// CHECK: test.always_speculatable_op
// CHECK-NEXT: scf.for
// CHECK-NOT: test.always_speculatable_op
// CHECK: test.speculatable_op_with_memread
// CHECK: scf.for
// CHECK: scf.if
// CHECK: test.speculatable_op_with_memwrite
%cst_0 = arith.constant 0 : i32
%cst_42 = arith.constant dense<42> : tensor<64xi32>
%ind_42 = arith.constant 42 : index
%sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
%always_speculate = "test.always_speculatable_op"() : () -> i32
%only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
%i_cast = arith.index_cast %i: index to i32
%add = arith.addi %acc, %i_cast : i32
%sum = arith.addi %add, %only_read : i32
scf.for %j = %lb to %ub step %step {
%eq42 = arith.cmpi eq, %j, %ind_42 : index
scf.if %eq42 {
%always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
}
}
scf.yield %sum : i32
}
return %sum_result : i32
}

// -----

func.func @speculate_tensor_dim_unknown_rank_unknown_dim(
Expand Down
Loading
Loading