Skip to content

Commit f12aea1

Browse files
add remove-single-iteration-loop pass.
1 parent 91c0aa5 commit f12aea1

File tree

6 files changed

+199
-0
lines changed

6 files changed

+199
-0
lines changed

mlir/include/mlir/Dialect/Affine/Passes.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,11 @@ std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
117117
/// operations.
118118
std::unique_ptr<Pass> createAffineExpandIndexOpsAsAffinePass();
119119

120+
/// Creates a pass in order to remove invalid loops or to move the IR out of the
121+
/// loop when the loop is only iterated once.
122+
std::unique_ptr<InterfacePass<FunctionOpInterface>>
123+
createRemoveSingleIterationLoopPass();
124+
120125
//===----------------------------------------------------------------------===//
121126
// Registration
122127
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Affine/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,11 @@ def AffineLoopUnrollAndJam : InterfacePass<"affine-loop-unroll-jam", "FunctionOp
229229
];
230230
}
231231

232+
// Function-level pass that inlines the body of `affine.for` loops proven to
// execute exactly once. The command-line argument and constructor are part of
// the public interface and must stay in sync with Passes.h.
def RemoveSingleIterationLoop : InterfacePass<"remove-single-iteration-loop", "FunctionOpInterface"> {
  let summary = "Remove affine.for loops that execute exactly one iteration";
  let description = [{
    Replaces every `affine.for` whose bounds can be proven (through
    `ValueBoundsOpInterface`) to yield exactly one iteration with the loop
    body. The induction variable is replaced by the value of the lower bound
    and the loop results by the values yielded from the body.
  }];
  let constructor = "mlir::affine::createRemoveSingleIterationLoopPass()";
}
236+
232237
def AffinePipelineDataTransfer
233238
: Pass<"affine-pipeline-data-transfer", "func::FuncOp"> {
234239
let summary = "Pipeline non-blocking data transfers between explicitly "

mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
4141
/// `affine.apply` representations.
4242
void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns);
4343

44+
/// Insert pattern to remove single iteration loop. The pattern will detect
45+
/// single iteration loops based on the range returned by ValueBoundsOpInterface.
46+
void populateRemoveSingleIterationLoopPattern(RewritePatternSet &patterns);
47+
4448
/// Helper function to rewrite `op`'s affine map and reorder its operands such
4549
/// that they are in increasing order of hoistability (i.e. the least hoistable)
4650
/// operands come first in the operand list.

mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
1616
ReifyValueBounds.cpp
1717
SuperVectorize.cpp
1818
SimplifyAffineStructures.cpp
19+
RemoveSingleIterationLoop.cpp
1920

2021
ADDITIONAL_HEADER_DIRS
2122
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//===- RemoveSingleIterationLoop.cpp --- remove single iteration loop ---*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Affine/Passes.h"
10+
11+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
12+
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
13+
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
14+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15+
16+
#include "llvm/Support/Debug.h"
17+
18+
namespace mlir {
19+
namespace affine {
20+
#define GEN_PASS_DEF_REMOVESINGLEITERATIONLOOP
21+
#include "mlir/Dialect/Affine/Passes.h.inc"
22+
} // namespace affine
23+
} // namespace mlir
24+
25+
#define DEBUG_TYPE "affine-remove-single-iteration"
26+
27+
using namespace mlir;
28+
using namespace affine;
29+
30+
/// Replaces the given op with the contents of the given single-block region,
31+
/// using the operands of the block terminator to replace operation results.
32+
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
33+
Region &region, ValueRange blockArgs = {}) {
34+
assert(llvm::hasSingleElement(region) && "expected single-region block");
35+
Block *block = &region.front();
36+
Operation *terminator = block->getTerminator();
37+
ValueRange results = terminator->getOperands();
38+
rewriter.inlineBlockBefore(block, op, blockArgs);
39+
rewriter.replaceOp(op, results);
40+
rewriter.eraseOp(terminator);
41+
}
42+
43+
/// Return true if we can prove that the we always run at least the first
44+
/// iteration of the ForOp.
45+
static bool alwaysRunsFirstIteration(AffineForOp op) {
46+
// Can't perform the analysis if the loops's bounds aren't index-typed.
47+
if (!op.getInductionVar().getType().isIndex())
48+
return false;
49+
SmallVector<Value> lowerMapOperands = op.getLowerBoundOperands();
50+
SmallVector<Value> upperMapOperands = op.getUpperBoundOperands();
51+
ValueBoundsConstraintSet::Variable lower(op.getLowerBoundMap(),
52+
lowerMapOperands);
53+
ValueBoundsConstraintSet::Variable upper(op.getUpperBoundMap(),
54+
upperMapOperands);
55+
FailureOr<bool> isLb = ValueBoundsConstraintSet::compare(
56+
lower, ValueBoundsConstraintSet::LT, upper);
57+
return isLb.value_or(false);
58+
}
59+
60+
/// Return true if we can prove that the we never run more than one iteration of
61+
/// the ForOp.
62+
static bool neverRunsSecondIteration(AffineForOp op) {
63+
// Can't perform the analysis if the loops's bounds aren't index-typed.
64+
if (!op.getInductionVar().getType().isIndex())
65+
return false;
66+
67+
// The loop will only loop once if the inducation variable for the next time
68+
// in the loop is greater than or equal to upper.
69+
MLIRContext *context = op.getContext();
70+
SmallVector<Value> lowerMapOperands = op.getLowerBoundOperands();
71+
SmallVector<Value> upperMapOperands = op.getUpperBoundOperands();
72+
SmallVector<AffineExpr> results;
73+
AffineMap lowerMap = op.getLowerBoundMap();
74+
for (AffineExpr expr : lowerMap.getResults()) {
75+
results.push_back(expr + op.getStep().getSExtValue());
76+
}
77+
AffineMap nextItMap = AffineMap::get(
78+
lowerMap.getNumDims(), lowerMap.getNumSymbols(), results, context);
79+
ValueBoundsConstraintSet::Variable nextItVar(nextItMap, lowerMapOperands);
80+
ValueBoundsConstraintSet::Variable upperVar(op.getUpperBoundMap(),
81+
upperMapOperands);
82+
FailureOr<bool> isUpperUnderNextIter = ValueBoundsConstraintSet::compare(
83+
nextItVar, ValueBoundsConstraintSet::LE, upperVar);
84+
return isUpperUnderNextIter.value_or(false);
85+
}
86+
87+
namespace {
88+
89+
/// Rewriting pattern that replaces single-iteration loops with their bodies.
90+
struct SimplifyTrivialLoops : public OpRewritePattern<AffineForOp> {
91+
using OpRewritePattern::OpRewritePattern;
92+
93+
LogicalResult matchAndRewrite(AffineForOp op,
94+
PatternRewriter &rewriter) const override {
95+
if (!(alwaysRunsFirstIteration(op) && neverRunsSecondIteration(op))) {
96+
return failure();
97+
}
98+
99+
// The first iteration is always run and the second iteration is never run
100+
// so the loop always have 1 iteration. Inline its body and remove the loop.
101+
SmallVector<Value> blockArgs;
102+
blockArgs.reserve(op.getInits().size() + 1);
103+
rewriter.setInsertionPointToStart(op.getBody());
104+
Value lower = rewriter.create<AffineApplyOp>(
105+
op.getLoc(), op.getLowerBoundMap(), op.getLowerBoundOperands());
106+
op.getInductionVar().replaceAllUsesWith(lower);
107+
blockArgs.push_back(lower);
108+
llvm::append_range(blockArgs, op.getInits());
109+
replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
110+
return success();
111+
}
112+
};
113+
114+
struct RemoveSingleIterationLoop
115+
: public affine::impl::RemoveSingleIterationLoopBase<
116+
RemoveSingleIterationLoop> {
117+
void runOnOperation() override {
118+
auto funcOp = getOperation();
119+
RewritePatternSet patterns(funcOp.getContext());
120+
populateRemoveSingleIterationLoopPattern(patterns);
121+
if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
122+
return signalPassFailure();
123+
}
124+
};
125+
126+
} // namespace
127+
128+
/// Adds the pattern that inlines single-iteration `affine.for` loops to
/// `patterns`.
void mlir::affine::populateRemoveSingleIterationLoopPattern(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<SimplifyTrivialLoops>(ctx);
}
132+
133+
/// Factory for the single-iteration-loop removal pass.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
mlir::affine::createRemoveSingleIterationLoopPass() {
  std::unique_ptr<InterfacePass<FunctionOpInterface>> pass =
      std::make_unique<RemoveSingleIterationLoop>();
  return pass;
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(remove-single-iteration-loop))' -split-input-file | FileCheck %s
2+
3+
// -----
4+
5+
// CHECK-LABEL: func @loop_once(
6+
func.func @loop_once(%arg : index) -> index{
7+
%0 = affine.for %iv = 2 to 3 step 1 iter_args(%arg1 = %arg) -> index {
8+
%sum = arith.addi %arg1, %iv : index
9+
affine.yield %sum : index
10+
}
11+
return %0 : index
12+
}
13+
// CHECK-SAME: %[[ARG:.*]]: index)
14+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
15+
// CHECK: %[[SUM:.*]] = arith.addi %[[ARG]], %[[C2]] : index
16+
// CHECK: return %[[SUM]] : index
17+
18+
// -----
19+
20+
// CHECK-LABEL: func @invalid_loop(
21+
func.func @invalid_loop(%arg : index) -> index{
22+
%0 = affine.for %iv = 4 to 3 step 1 iter_args(%arg1 = %arg) -> index {
23+
%sum = arith.addi %arg1, %iv : index
24+
affine.yield %sum : index
25+
}
26+
return %0 : index
27+
}
28+
// CHECK-SAME: %[[ARG:.*]]: index)
29+
// CHECK: return %[[ARG]] : index
30+
31+
// -----
32+
33+
// CHECK-LABEL: func @gpu_invalid_loop
34+
func.func @gpu_invalid_loop() {
35+
%0 = arith.constant 0 :index
36+
%1 = arith.constant 2 : index
37+
gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %1, %sz_by = %1, %sz_bz = %1)
38+
threads(%tx, %ty, %tz) in (%sz_tx = %1, %sz_ty = %1, %sz_tz = %1) {
39+
%threadid = gpu.thread_id x
40+
affine.for %iv = %tx to 0 step 2 iter_args(%arg = %0) -> index {
41+
%3 = arith.addi %arg, %0 : index
42+
affine.yield %3 : index
43+
}
44+
gpu.terminator
45+
}
46+
// CHECK-NEXT: return
47+
return
48+
}

0 commit comments

Comments
 (0)