Skip to content

[SCFToAffine] Add a pass to raise scf to affine ops. #152925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,18 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
}];
}

//===----------------------------------------------------------------------===//
// SCFToAffine
//===----------------------------------------------------------------------===//

// Declares the `raise-scf-to-affine` pass. The description documents the
// preconditions and the shape of the produced IR so that generated pass
// documentation is self-explanatory.
def RaiseSCFToAffinePass : Pass<"raise-scf-to-affine"> {
  let summary = "Raise SCF to affine ops";
  let description = [{
    Raises `scf.for` operations to `affine.for` operations. A loop is
    converted only when its lower and upper bounds are valid affine
    dimensions and its step is a valid affine symbol; other loops are left
    untouched.

    The resulting `affine.for` is normalized: it iterates from 0 with unit
    step up to `(upper - lower + step - 1) floordiv step`, and the original
    induction variable is recomputed inside the loop body with an
    `affine.apply` as `lower + iv * step`.
  }];
  let dependentDialects = [
    "affine::AffineDialect",
    "scf::SCFDialect",
  ];
}

//===----------------------------------------------------------------------===//
// SCFToControlFlow
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- SCFToAffine.h - SCF to Affine Pass entrypoint ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_
#define MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_

#include <memory>

namespace mlir {
// Forward declarations; this header only needs these types by name.
class Pass;
class RewritePatternSet;

// Selects the RaiseSCFToAffinePass declarations from the generated
// conversion pass include file.
#define GEN_PASS_DECL_RAISESCFTOAFFINEPASS
#include "mlir/Conversion/Passes.h.inc"

/// Adds to `patterns` the rewrite patterns that raise SCF operations to
/// Affine operations. At present a single pattern, for `scf.for`, is
/// registered.
void populateSCFToAffineConversionPatterns(RewritePatternSet &patterns);

} // namespace mlir

#endif // MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToAffine)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToEmitC)
add_subdirectory(SCFToGPU)
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/SCFToAffine/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Library implementing the SCF-to-Affine raising pass.
# The LLVM dialect is intentionally not linked: SCFToAffine.cpp neither
# includes nor creates anything from it.
add_mlir_conversion_library(MLIRSCFToAffine
  SCFToAffine.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToAffine

  DEPENDS
  MLIRConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRArithDialect
  MLIRAffineDialect
  MLIRSCFDialect
  MLIRSCFTransforms
  MLIRTransforms
  )
100 changes: 100 additions & 0 deletions mlir/lib/Conversion/SCFToAffine/SCFToAffine.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
//===- SCFToAffine.cpp - SCF to Affine conversion -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to raise scf.for ops into normalized affine.for
// ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"

namespace mlir {
#define GEN_PASS_DEF_RAISESCFTOAFFINEPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

/// Pass object behind `raise-scf-to-affine`. It derives from the generated
/// pass base class; the actual work happens in runOnOperation(), which is
/// defined out of line.
struct SCFToAffinePass : impl::RaiseSCFToAffinePassBase<SCFToAffinePass> {
  void runOnOperation() override;
};

/// Rewrites an `scf.for` into a normalized `affine.for` (lower bound 0, step 1)
/// and recomputes the original induction variable inside the body with an
/// `affine.apply`. The pattern fails to match unless both bounds are valid
/// affine dims and the step is a valid affine symbol.
struct ForOpRewrite : public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;

LogicalResult matchAndRewrite(scf::ForOp op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: please only use auto when it improves readability. E.g., the type is long (iterators) or impossible (lambdas) to spell.

auto lower = op.getLowerBound();
auto upper = op.getUpperBound();
auto step = op.getStep();

// Bail out unless the bounds can be used as affine dims and the step as an
// affine symbol.
if (!affine::isValidDim(lower) || !affine::isValidDim(upper) ||
!affine::isValidSymbol(step))
return llvm::failure();

// In the upper-bound map below, d0 is bound to `lower`, d1 to `upper` and s0
// to `step` (matching the operand order {lower, upper, step}).
auto lowerDim = rewriter.getAffineDimExpr(0);
auto upperDim = rewriter.getAffineDimExpr(1);
auto stepSym = rewriter.getAffineSymbolExpr(0);
// The new loop starts at constant 0, uses step 1, and carries the original
// init values. Its upper bound is (upper - lower + step - 1) floordiv step,
// which for a positive step is the ceiling of (upper - lower) / step.
auto affineFor = affine::AffineForOp::create(
rewriter, loc, ValueRange(), rewriter.getConstantAffineMap(0),
ValueRange({lower, upper, step}),
AffineMap::get(2, 1,
(upperDim - lowerDim + stepSym - 1).floorDiv(stepSym)),
1, op.getInits());
Comment on lines +55 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why all this complexity? Affine fors support having values as lower and upper bounds (by using a 1D identity map), and a non-unit step. We could then "normalize" loops in a separate pass when desired, and we may already have such a pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The step of affine.for must be a constant integer. To ensure the generality of the conversion, I have normalized the loop. It may be necessary to consider performing an additional check on the step in scf to determine if it is a constant integer and handle it accordingly.

auto affineBody = affineFor.getBody();

// Remove any terminator already present in the new body; the inlined scf.for
// body supplies its own terminator, which is converted further down.
if (affineBody->mightHaveTerminator())
rewriter.eraseOp(affineBody->getTerminator());

// Recompute the original induction variable as lower + iv * step. In this
// map d0 is `lower`, d1 is the affine.for induction variable, s0 is `step`.
rewriter.setInsertionPointToStart(affineBody);
auto actualIndexMap =
AffineMap::get(2, 1, lowerDim + rewriter.getAffineDimExpr(1) * stepSym);
Value newIndVar =
affine::AffineApplyOp::create(
rewriter, op.getLoc(), actualIndexMap,
ValueRange({lower, affineFor.getInductionVar(), step}))
.getResult();

// The scf.for body's block arguments are replaced by the recomputed
// induction variable followed by the affine.for iter args, and the body is
// moved to the end of the new loop body.
SmallVector<Value> argValues;
argValues.push_back(newIndVar);
llvm::append_range(argValues, affineFor.getRegionIterArgs());
rewriter.inlineBlockBefore(op.getBody(), affineBody, affineBody->end(),
argValues);

// The inlined body ends in scf.yield; swap it for an affine.yield with the
// same operands.
auto scfYieldOp = cast<scf::YieldOp>(affineBody->getTerminator());
rewriter.setInsertionPointToEnd(affineBody);
rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(
scfYieldOp, scfYieldOp->getOperands());

// Finally replace the scf.for and its results with the new affine.for.
rewriter.replaceOp(op, affineFor);
return success();
}
};

} // namespace

/// Registers the SCF-to-Affine raising patterns (currently only ForOpRewrite)
/// in `patterns`, constructing them with the context owned by the set.
void mlir::populateSCFToAffineConversionPatterns(RewritePatternSet &patterns) {
  auto *context = patterns.getContext();
  patterns.add<ForOpRewrite>(context);
}

// Pass entry point: collects the SCF-to-Affine patterns and applies them
// greedily to the operation the pass runs on. The pass is marked as failed if
// the greedy driver reports failure.
void SCFToAffinePass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateSCFToAffineConversionPatterns(patterns);

if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
Comment on lines +96 to +99
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need the greedy rewriter with its overhead here or would a simple walk calling a function suffice?

}
34 changes: 34 additions & 0 deletions mlir/test/Conversion/SCFToAffine/scf-to-affine.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: mlir-opt -raise-scf-to-affine -split-input-file %s | FileCheck %s

// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 - s1 + s2 - 1) floordiv s0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
// CHECK-LABEL: func.func @simple_loop(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: memref<3xindex>) {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_4:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_1]]] : memref<3xindex>
// CHECK: %[[VAL_5:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_2]]] : memref<3xindex>
// CHECK: %[[VAL_6:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_3]]] : memref<3xindex>
// CHECK: affine.for %[[VAL_7:.*]] = 0 to #[[$ATTR_0]](){{\[}}%[[VAL_6]], %[[VAL_4]], %[[VAL_5]]] {
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_4]], %[[VAL_7]]){{\[}}%[[VAL_6]]]
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
// CHECK: }
// CHECK: return
// CHECK: }

// Test input: the lower bound, upper bound and step of the loop are all loaded
// from a memref, so none of them is a constant in the IR. The loop body stores
// a zero at the position given by the induction variable.
func.func @simple_loop(%arg0: memref<?xi32>, %arg1: memref<3xindex>) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%0 = memref.load %arg1[%c0] : memref<3xindex>
%1 = memref.load %arg1[%c1] : memref<3xindex>
%2 = memref.load %arg1[%c2] : memref<3xindex>
scf.for %arg2 = %0 to %1 step %2 {
memref.store %c0_i32, %arg0[%arg2] : memref<?xi32>
}
return
}