Skip to content

Commit 48dad4a

Browse files
committed
[SCFToAffine] Add a pass to raise scf to affine ops.
This patch supports the conversion from `scf.for` to `affine.for`.
1 parent d803a93 commit 48dad4a

File tree

7 files changed

+250
-0
lines changed

7 files changed

+250
-0
lines changed

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
5959
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
6060
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
61+
#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
6162
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
6263
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
6364
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,18 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
10271027
}];
10281028
}
10291029

1030+
//===----------------------------------------------------------------------===//
1031+
// SCFToAffine
1032+
//===----------------------------------------------------------------------===//
1033+
1034+
def RaiseSCFToAffinePass : Pass<"raise-scf-to-affine"> {
1035+
let summary = "Raise SCF to affine ops";
1036+
let dependentDialects = [
1037+
"affine::AffineDialect",
1038+
"scf::SCFDialect",
1039+
];
1040+
}
1041+
10301042
//===----------------------------------------------------------------------===//
10311043
// SCFToControlFlow
10321044
//===----------------------------------------------------------------------===//
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- SCFToAffine.h - SCF to Affine Pass entrypoint ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_
10+
#define MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
class RewritePatternSet;
17+
18+
#define GEN_PASS_DECL_RAISESCFTOAFFINEPASS
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
/// Collect a set of patterns to convert SCF operations to Affine operations.
22+
void populateSCFToAffineConversionPatterns(RewritePatternSet &patterns);
23+
24+
} // namespace mlir
25+
26+
#endif // MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ add_subdirectory(OpenACCToSCF)
5151
add_subdirectory(OpenMPToLLVM)
5252
add_subdirectory(PDLToPDLInterp)
5353
add_subdirectory(ReconcileUnrealizedCasts)
54+
add_subdirectory(SCFToAffine)
5455
add_subdirectory(SCFToControlFlow)
5556
add_subdirectory(SCFToEmitC)
5657
add_subdirectory(SCFToGPU)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRSCFToAffine
2+
SCFToAffine.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToAffine
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRArithDialect
12+
MLIRAffineDialect
13+
MLIRLLVMDialect
14+
MLIRSCFDialect
15+
MLIRSCFTransforms
16+
MLIRTransforms
17+
)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
//===- SCFToAffine.cpp - SCF to Affine conversion -------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to raise scf.for, scf.if and loop.terminator
10+
// ops into affine ops.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
15+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
16+
#include "mlir/Dialect/SCF/IR/SCF.h"
17+
#include "mlir/IR/Verifier.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Transforms/Passes.h"
20+
21+
namespace mlir {
22+
#define GEN_PASS_DEF_RAISESCFTOAFFINEPASS
23+
#include "mlir/Conversion/Passes.h.inc"
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
28+
namespace {
29+
30+
struct SCFToAffinePass
31+
: public impl::RaiseSCFToAffinePassBase<SCFToAffinePass> {
32+
void runOnOperation() override;
33+
};
34+
35+
bool canRaiseToAffine(scf::ForOp op) {
36+
return affine::isValidDim(op.getLowerBound()) &&
37+
affine::isValidDim(op.getUpperBound()) &&
38+
affine::isValidSymbol(op.getStep());
39+
}
40+
41+
struct ForOpRewrite : public OpRewritePattern<scf::ForOp> {
42+
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
43+
44+
std::pair<affine::AffineForOp, Value>
45+
createAffineFor(scf::ForOp op, PatternRewriter &rewriter) const {
46+
if (auto constantStep = op.getStep().getDefiningOp<arith::ConstantOp>()) {
47+
int64_t step = cast<IntegerAttr>(constantStep.getValue()).getInt();
48+
if (step > 0)
49+
return positiveConstantStep(op, step, rewriter);
50+
}
51+
return genericBounds(op, rewriter);
52+
}
53+
54+
std::pair<affine::AffineForOp, Value>
55+
positiveConstantStep(scf::ForOp op, int64_t step,
56+
PatternRewriter &rewriter) const {
57+
auto affineFor = affine::AffineForOp::create(
58+
rewriter, op.getLoc(), ValueRange(op.getLowerBound()),
59+
AffineMap::get(1, 0, rewriter.getAffineDimExpr(0)),
60+
ValueRange(op.getUpperBound()),
61+
AffineMap::get(1, 0, rewriter.getAffineDimExpr(0)), step,
62+
op.getInits());
63+
return std::make_pair(affineFor, affineFor.getInductionVar());
64+
}
65+
66+
std::pair<affine::AffineForOp, Value>
67+
genericBounds(scf::ForOp op, PatternRewriter &rewriter) const {
68+
Value lower = op.getLowerBound();
69+
Value upper = op.getUpperBound();
70+
Value step = op.getStep();
71+
AffineExpr lowerExpr = rewriter.getAffineDimExpr(0);
72+
AffineExpr upperExpr = rewriter.getAffineDimExpr(1);
73+
AffineExpr stepExpr = rewriter.getAffineSymbolExpr(0);
74+
auto affineFor = affine::AffineForOp::create(
75+
rewriter, op.getLoc(), ValueRange(), rewriter.getConstantAffineMap(0),
76+
ValueRange({lower, upper, step}),
77+
AffineMap::get(
78+
2, 1, (upperExpr - lowerExpr + stepExpr - 1).floorDiv(stepExpr)),
79+
1, op.getInits());
80+
81+
rewriter.setInsertionPointToStart(affineFor.getBody());
82+
auto actualIndexMap = AffineMap::get(
83+
2, 1, lowerExpr + rewriter.getAffineDimExpr(1) * stepExpr);
84+
auto actualIndex = affine::AffineApplyOp::create(
85+
rewriter, op.getLoc(), actualIndexMap,
86+
ValueRange({lower, affineFor.getInductionVar(), step}));
87+
return std::make_pair(affineFor, actualIndex.getResult());
88+
}
89+
90+
LogicalResult matchAndRewrite(scf::ForOp op,
91+
PatternRewriter &rewriter) const override {
92+
if (!canRaiseToAffine(op))
93+
return failure();
94+
95+
auto [affineFor, actualIndex] = createAffineFor(op, rewriter);
96+
Block *affineBody = affineFor.getBody();
97+
98+
if (affineBody->mightHaveTerminator())
99+
rewriter.eraseOp(affineBody->getTerminator());
100+
101+
SmallVector<Value> argValues;
102+
argValues.push_back(actualIndex);
103+
llvm::append_range(argValues, affineFor.getRegionIterArgs());
104+
rewriter.inlineBlockBefore(op.getBody(), affineBody, affineBody->end(),
105+
argValues);
106+
107+
auto scfYieldOp = cast<scf::YieldOp>(affineBody->getTerminator());
108+
rewriter.setInsertionPointToEnd(affineBody);
109+
rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(
110+
scfYieldOp, scfYieldOp->getOperands());
111+
112+
rewriter.replaceOp(op, affineFor);
113+
return success();
114+
}
115+
};
116+
117+
} // namespace
118+
119+
void mlir::populateSCFToAffineConversionPatterns(RewritePatternSet &patterns) {
120+
patterns.add<ForOpRewrite>(patterns.getContext());
121+
}
122+
123+
void SCFToAffinePass::runOnOperation() {
124+
MLIRContext &ctx = getContext();
125+
RewritePatternSet patterns(&ctx);
126+
populateSCFToAffineConversionPatterns(patterns);
127+
128+
// Configure conversion to raise SCF operations.
129+
ConversionTarget target(ctx);
130+
target.addDynamicallyLegalOp<scf::ForOp>(
131+
[](scf::ForOp op) { return !canRaiseToAffine(op); });
132+
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
133+
if (failed(
134+
applyPartialConversion(getOperation(), target, std::move(patterns))))
135+
signalPassFailure();
136+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-opt -raise-scf-to-affine -split-input-file %s | FileCheck %s
2+
3+
// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> ((d1 - d0 + s0 - 1) floordiv s0)>
4+
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
5+
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> (d0)>
6+
// CHECK-LABEL: func.func @simple_loop(
7+
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi32>,
8+
// CHECK-SAME: %[[ARG1:.*]]: memref<3xindex>) {
9+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
10+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
11+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
12+
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
13+
// CHECK: %[[VAL_4:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_1]]] : memref<3xindex>
14+
// CHECK: %[[VAL_5:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_2]]] : memref<3xindex>
15+
// CHECK: %[[VAL_6:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_3]]] : memref<3xindex>
16+
// CHECK: affine.for %[[VAL_7:.*]] = 0 to #[[$ATTR_0]](%[[VAL_4]], %[[VAL_5]]){{\[}}%[[VAL_6]]] {
17+
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_4]], %[[VAL_7]]){{\[}}%[[VAL_6]]]
18+
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
19+
// CHECK: }
20+
// CHECK: return
21+
// CHECK: }
22+
23+
func.func @simple_loop(%arg0: memref<?xi32>, %arg1: memref<3xindex>) {
24+
%c0_i32 = arith.constant 0 : i32
25+
%c0 = arith.constant 0 : index
26+
%c1 = arith.constant 1 : index
27+
%c2 = arith.constant 2 : index
28+
%0 = memref.load %arg1[%c0] : memref<3xindex>
29+
%1 = memref.load %arg1[%c1] : memref<3xindex>
30+
%2 = memref.load %arg1[%c2] : memref<3xindex>
31+
scf.for %arg2 = %0 to %1 step %2 {
32+
memref.store %c0_i32, %arg0[%arg2] : memref<?xi32>
33+
}
34+
return
35+
}
36+
37+
// CHECK-LABEL: func.func @loop_with_constant_step(
38+
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi32>,
39+
// CHECK-SAME: %[[ARG1:.*]]: index,
40+
// CHECK-SAME: %[[ARG2:.*]]: index) {
41+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
42+
// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
43+
// CHECK: affine.for %[[VAL_2:.*]] = #[[$ATTR_2]](%[[ARG1]]) to #[[$ATTR_2]](%[[ARG2]]) step 3 {
44+
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_2]]] : memref<?xi32>
45+
// CHECK: }
46+
// CHECK: return
47+
// CHECK: }
48+
49+
func.func @loop_with_constant_step(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
50+
%c0_i32 = arith.constant 0 : i32
51+
%c3 = arith.constant 3 : index
52+
scf.for %arg3 = %arg1 to %arg2 step %c3 {
53+
memref.store %c0_i32, %arg0[%arg3] : memref<?xi32>
54+
}
55+
return
56+
}
57+

0 commit comments

Comments
 (0)