Skip to content

Commit 230fa93

Browse files
committed
[SCFToAffine] Add a pass to raise scf to affine ops.
This patch supports the conversion from `scf.for` to `affine.for`.
1 parent d803a93 commit 230fa93

File tree

7 files changed

+237
-0
lines changed

7 files changed

+237
-0
lines changed

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
5959
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
6060
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
61+
#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
6162
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
6263
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
6364
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,6 +1027,18 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
10271027
}];
10281028
}
10291029

1030+
//===----------------------------------------------------------------------===//
1031+
// SCFToAffine
1032+
//===----------------------------------------------------------------------===//
1033+
1034+
def RaiseSCFToAffinePass : Pass<"raise-scf-to-affine"> {
1035+
let summary = "Raise SCF to affine ops";
1036+
let dependentDialects = [
1037+
"affine::AffineDialect",
1038+
"scf::SCFDialect",
1039+
];
1040+
}
1041+
10301042
//===----------------------------------------------------------------------===//
10311043
// SCFToControlFlow
10321044
//===----------------------------------------------------------------------===//
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- SCFToAffine.h - SCF to Affine Pass entrypoint ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_
10+
#define MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_
11+
12+
#include <memory>
13+
14+
namespace mlir {
15+
class Pass;
16+
class RewritePatternSet;
17+
18+
#define GEN_PASS_DECL_RAISESCFTOAFFINEPASS
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
/// Collect a set of patterns to convert SCF operations to Affine operations.
22+
void populateSCFToAffineConversionPatterns(RewritePatternSet &patterns);
23+
24+
} // namespace mlir
25+
26+
#endif // MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ add_subdirectory(OpenACCToSCF)
5151
add_subdirectory(OpenMPToLLVM)
5252
add_subdirectory(PDLToPDLInterp)
5353
add_subdirectory(ReconcileUnrealizedCasts)
54+
add_subdirectory(SCFToAffine)
5455
add_subdirectory(SCFToControlFlow)
5556
add_subdirectory(SCFToEmitC)
5657
add_subdirectory(SCFToGPU)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
add_mlir_conversion_library(MLIRSCFToAffine
2+
SCFToAffine.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToAffine
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRArithDialect
12+
MLIRAffineDialect
13+
MLIRLLVMDialect
14+
MLIRSCFDialect
15+
MLIRSCFTransforms
16+
MLIRTransforms
17+
)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
//===- SCFToAffine.cpp - SCF to Affine conversion -------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements a pass to raise scf.for, scf.if and loop.terminator
10+
// ops into affine ops.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
15+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
16+
#include "mlir/Dialect/SCF/IR/SCF.h"
17+
#include "mlir/IR/Verifier.h"
18+
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Transforms/Passes.h"
20+
21+
namespace mlir {
22+
#define GEN_PASS_DEF_RAISESCFTOAFFINEPASS
23+
#include "mlir/Conversion/Passes.h.inc"
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
28+
namespace {
29+
30+
struct SCFToAffinePass
31+
: public impl::RaiseSCFToAffinePassBase<SCFToAffinePass> {
32+
void runOnOperation() override;
33+
};
34+
35+
struct ForOpRewrite : public OpRewritePattern<scf::ForOp> {
36+
using OpRewritePattern<scf::ForOp>::OpRewritePattern;
37+
38+
std::pair<affine::AffineForOp, Value>
39+
createAffineFor(scf::ForOp op, PatternRewriter &rewriter) const {
40+
if (auto constantStep = op.getStep().getDefiningOp<arith::ConstantOp>()) {
41+
int64_t step = cast<IntegerAttr>(constantStep.getValue()).getInt();
42+
if (step > 0)
43+
return positiveConstantStep(op, step, rewriter);
44+
}
45+
return genericBounds(op, rewriter);
46+
}
47+
48+
std::pair<affine::AffineForOp, Value>
49+
positiveConstantStep(scf::ForOp op, int64_t step,
50+
PatternRewriter &rewriter) const {
51+
auto affineFor = affine::AffineForOp::create(
52+
rewriter, op.getLoc(), ValueRange(op.getLowerBound()),
53+
AffineMap::get(1, 0, rewriter.getAffineDimExpr(0)),
54+
ValueRange(op.getUpperBound()),
55+
AffineMap::get(1, 0, rewriter.getAffineDimExpr(0)), step,
56+
op.getInits());
57+
return std::make_pair(affineFor, affineFor.getInductionVar());
58+
}
59+
60+
std::pair<affine::AffineForOp, Value>
61+
genericBounds(scf::ForOp op, PatternRewriter &rewriter) const {
62+
Value lower = op.getLowerBound();
63+
Value upper = op.getUpperBound();
64+
Value step = op.getStep();
65+
AffineExpr lowerExpr = rewriter.getAffineDimExpr(0);
66+
AffineExpr upperExpr = rewriter.getAffineDimExpr(1);
67+
AffineExpr stepExpr = rewriter.getAffineSymbolExpr(0);
68+
auto affineFor = affine::AffineForOp::create(
69+
rewriter, op.getLoc(), ValueRange(), rewriter.getConstantAffineMap(0),
70+
ValueRange({lower, upper, step}),
71+
AffineMap::get(
72+
2, 1, (upperExpr - lowerExpr + stepExpr - 1).floorDiv(stepExpr)),
73+
1, op.getInits());
74+
75+
rewriter.setInsertionPointToStart(affineFor.getBody());
76+
auto actualIndexMap = AffineMap::get(
77+
2, 1, lowerExpr + rewriter.getAffineDimExpr(1) * stepExpr);
78+
auto actualIndex = affine::AffineApplyOp::create(
79+
rewriter, op.getLoc(), actualIndexMap,
80+
ValueRange({lower, affineFor.getInductionVar(), step}));
81+
return std::make_pair(affineFor, actualIndex.getResult());
82+
}
83+
84+
LogicalResult matchAndRewrite(scf::ForOp op,
85+
PatternRewriter &rewriter) const override {
86+
if (!affine::isValidDim(op.getLowerBound()) ||
87+
!affine::isValidDim(op.getUpperBound()) ||
88+
!affine::isValidSymbol(op.getStep()))
89+
return failure();
90+
91+
auto [affineFor, actualIndex] = createAffineFor(op, rewriter);
92+
Block *affineBody = affineFor.getBody();
93+
94+
if (affineBody->mightHaveTerminator())
95+
rewriter.eraseOp(affineBody->getTerminator());
96+
97+
SmallVector<Value> argValues;
98+
argValues.push_back(actualIndex);
99+
llvm::append_range(argValues, affineFor.getRegionIterArgs());
100+
rewriter.inlineBlockBefore(op.getBody(), affineBody, affineBody->end(),
101+
argValues);
102+
103+
auto scfYieldOp = cast<scf::YieldOp>(affineBody->getTerminator());
104+
rewriter.setInsertionPointToEnd(affineBody);
105+
rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(
106+
scfYieldOp, scfYieldOp->getOperands());
107+
108+
rewriter.replaceOp(op, affineFor);
109+
return success();
110+
}
111+
};
112+
113+
} // namespace
114+
115+
void mlir::populateSCFToAffineConversionPatterns(RewritePatternSet &patterns) {
116+
patterns.add<ForOpRewrite>(patterns.getContext());
117+
}
118+
119+
void SCFToAffinePass::runOnOperation() {
120+
RewritePatternSet patterns(&getContext());
121+
populateSCFToAffineConversionPatterns(patterns);
122+
123+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
124+
signalPassFailure();
125+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: mlir-opt -raise-scf-to-affine -split-input-file %s | FileCheck %s
2+
3+
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 - s1 + s2 - 1) floordiv s0)>
4+
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
5+
// CHECK-LABEL: func.func @simple_loop(
6+
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi32>,
7+
// CHECK-SAME: %[[ARG1:.*]]: memref<3xindex>) {
8+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
9+
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
10+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
11+
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
12+
// CHECK: %[[VAL_4:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_1]]] : memref<3xindex>
13+
// CHECK: %[[VAL_5:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_2]]] : memref<3xindex>
14+
// CHECK: %[[VAL_6:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_3]]] : memref<3xindex>
15+
// CHECK: affine.for %[[VAL_7:.*]] = 0 to #[[$ATTR_0]](){{\[}}%[[VAL_6]], %[[VAL_4]], %[[VAL_5]]] {
16+
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_4]], %[[VAL_7]]){{\[}}%[[VAL_6]]]
17+
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
18+
// CHECK: }
19+
// CHECK: return
20+
// CHECK: }
21+
22+
func.func @simple_loop(%arg0: memref<?xi32>, %arg1: memref<3xindex>) {
23+
%c0_i32 = arith.constant 0 : i32
24+
%c0 = arith.constant 0 : index
25+
%c1 = arith.constant 1 : index
26+
%c2 = arith.constant 2 : index
27+
%0 = memref.load %arg1[%c0] : memref<3xindex>
28+
%1 = memref.load %arg1[%c1] : memref<3xindex>
29+
%2 = memref.load %arg1[%c2] : memref<3xindex>
30+
scf.for %arg2 = %0 to %1 step %2 {
31+
memref.store %c0_i32, %arg0[%arg2] : memref<?xi32>
32+
}
33+
return
34+
}
35+
36+
// CHECK-LABEL: func.func @loop_with_constant_step(
37+
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi32>,
38+
// CHECK-SAME: %[[ARG1:.*]]: index,
39+
// CHECK-SAME: %[[ARG2:.*]]: index) {
40+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
41+
// CHECK: affine.for %[[VAL_1:.*]] = %[[ARG1]] to %[[ARG2]] step 3 {
42+
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_1]]] : memref<?xi32>
43+
// CHECK: }
44+
// CHECK: return
45+
// CHECK: }
46+
47+
func.func @loop_with_constant_step(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
48+
%c0_i32 = arith.constant 0 : i32
49+
%c3 = arith.constant 3 : index
50+
scf.for %arg3 = %arg1 to %arg2 step %c3 {
51+
memref.store %c0_i32, %arg0[%arg3] : memref<?xi32>
52+
}
53+
return
54+
}
55+

0 commit comments

Comments
 (0)