-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[SCFToAffine] Add a pass to raise scf to affine ops. #152925
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
//===- SCFToAffine.h - SCF to Affine Pass entrypoint ------------*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_ | ||
#define MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_ | ||
|
||
#include <memory> | ||
|
||
namespace mlir { | ||
class Pass; | ||
class RewritePatternSet; | ||
|
||
#define GEN_PASS_DECL_RAISESCFTOAFFINEPASS | ||
#include "mlir/Conversion/Passes.h.inc" | ||
|
||
/// Collect a set of patterns to convert SCF operations to Affine operations. | ||
void populateSCFToAffineConversionPatterns(RewritePatternSet &patterns); | ||
|
||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
add_mlir_conversion_library(MLIRSCFToAffine | ||
SCFToAffine.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToAffine | ||
|
||
DEPENDS | ||
MLIRConversionPassIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRArithDialect | ||
MLIRAffineDialect | ||
MLIRLLVMDialect | ||
MLIRSCFDialect | ||
MLIRSCFTransforms | ||
MLIRTransforms | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
//===- SCFToAffine.cpp - SCF to Affine conversion -------------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file implements a pass to raise scf.for, scf.if and loop.terminator | ||
// ops into affine ops. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/SCFToAffine/SCFToAffine.h" | ||
#include "mlir/Dialect/Affine/IR/AffineOps.h" | ||
#include "mlir/Dialect/SCF/IR/SCF.h" | ||
#include "mlir/IR/Verifier.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/Passes.h" | ||
|
||
namespace mlir { | ||
#define GEN_PASS_DEF_RAISESCFTOAFFINEPASS | ||
#include "mlir/Conversion/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
struct SCFToAffinePass | ||
: public impl::RaiseSCFToAffinePassBase<SCFToAffinePass> { | ||
void runOnOperation() override; | ||
}; | ||
|
||
struct ForOpRewrite : public OpRewritePattern<scf::ForOp> { | ||
using OpRewritePattern<scf::ForOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(scf::ForOp op, | ||
PatternRewriter &rewriter) const override { | ||
auto loc = op.getLoc(); | ||
auto lower = op.getLowerBound(); | ||
auto upper = op.getUpperBound(); | ||
auto step = op.getStep(); | ||
|
||
if (!affine::isValidDim(lower) || !affine::isValidDim(upper) || | ||
!affine::isValidSymbol(step)) | ||
return llvm::failure(); | ||
|
||
auto lowerDim = rewriter.getAffineDimExpr(0); | ||
auto upperDim = rewriter.getAffineDimExpr(1); | ||
auto stepSym = rewriter.getAffineSymbolExpr(0); | ||
auto affineFor = affine::AffineForOp::create( | ||
rewriter, loc, ValueRange(), rewriter.getConstantAffineMap(0), | ||
ValueRange({lower, upper, step}), | ||
AffineMap::get(2, 1, | ||
(upperDim - lowerDim + stepSym - 1).floorDiv(stepSym)), | ||
1, op.getInits()); | ||
Comment on lines
+55
to
+57
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why all this complexity? Affine There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
auto affineBody = affineFor.getBody(); | ||
|
||
if (affineBody->mightHaveTerminator()) | ||
rewriter.eraseOp(affineBody->getTerminator()); | ||
|
||
rewriter.setInsertionPointToStart(affineBody); | ||
auto actualIndexMap = | ||
AffineMap::get(2, 1, lowerDim + rewriter.getAffineDimExpr(1) * stepSym); | ||
Value newIndVar = | ||
affine::AffineApplyOp::create( | ||
rewriter, op.getLoc(), actualIndexMap, | ||
ValueRange({lower, affineFor.getInductionVar(), step})) | ||
.getResult(); | ||
|
||
SmallVector<Value> argValues; | ||
argValues.push_back(newIndVar); | ||
llvm::append_range(argValues, affineFor.getRegionIterArgs()); | ||
rewriter.inlineBlockBefore(op.getBody(), affineBody, affineBody->end(), | ||
argValues); | ||
|
||
auto scfYieldOp = cast<scf::YieldOp>(affineBody->getTerminator()); | ||
rewriter.setInsertionPointToEnd(affineBody); | ||
rewriter.replaceOpWithNewOp<affine::AffineYieldOp>( | ||
scfYieldOp, scfYieldOp->getOperands()); | ||
|
||
rewriter.replaceOp(op, affineFor); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::populateSCFToAffineConversionPatterns(RewritePatternSet &patterns) { | ||
patterns.add<ForOpRewrite>(patterns.getContext()); | ||
} | ||
|
||
void SCFToAffinePass::runOnOperation() { | ||
RewritePatternSet patterns(&getContext()); | ||
populateSCFToAffineConversionPatterns(patterns); | ||
|
||
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) | ||
signalPassFailure(); | ||
Comment on lines
+96
to
+99
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we actually need the greedy rewriter with its overhead here or would a simple walk calling a function suffice? |
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
// RUN: mlir-opt -raise-scf-to-affine -split-input-file %s | FileCheck %s | ||
|
||
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 - s1 + s2 - 1) floordiv s0)> | ||
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)> | ||
// CHECK-LABEL: func.func @simple_loop( | ||
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi32>, | ||
// CHECK-SAME: %[[ARG1:.*]]: memref<3xindex>) { | ||
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32 | ||
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index | ||
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index | ||
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index | ||
// CHECK: %[[VAL_4:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_1]]] : memref<3xindex> | ||
// CHECK: %[[VAL_5:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_2]]] : memref<3xindex> | ||
// CHECK: %[[VAL_6:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_3]]] : memref<3xindex> | ||
// CHECK: affine.for %[[VAL_7:.*]] = 0 to #[[$ATTR_0]](){{\[}}%[[VAL_6]], %[[VAL_4]], %[[VAL_5]]] { | ||
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_4]], %[[VAL_7]]){{\[}}%[[VAL_6]]] | ||
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_8]]] : memref<?xi32> | ||
// CHECK: } | ||
// CHECK: return | ||
// CHECK: } | ||
|
||
func.func @simple_loop(%arg0: memref<?xi32>, %arg1: memref<3xindex>) { | ||
%c0_i32 = arith.constant 0 : i32 | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c2 = arith.constant 2 : index | ||
%0 = memref.load %arg1[%c0] : memref<3xindex> | ||
%1 = memref.load %arg1[%c1] : memref<3xindex> | ||
%2 = memref.load %arg1[%c2] : memref<3xindex> | ||
scf.for %arg2 = %0 to %1 step %2 { | ||
memref.store %c0_i32, %arg0[%arg2] : memref<?xi32> | ||
} | ||
return | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: please only use
auto
when it improves readability. E.g., the type is long (iterators) or impossible (lambdas) to spell.