Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
Expand Down
12 changes: 12 additions & 0 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,18 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> {
}];
}

//===----------------------------------------------------------------------===//
// SCFToAffine
//===----------------------------------------------------------------------===//

def RaiseSCFToAffinePass : Pass<"raise-scf-to-affine"> {
let summary = "Raise SCF to affine ops";
let dependentDialects = [
"affine::AffineDialect",
"scf::SCFDialect",
];
}

//===----------------------------------------------------------------------===//
// SCFToControlFlow
//===----------------------------------------------------------------------===//
Expand Down
26 changes: 26 additions & 0 deletions mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
//===- SCFToAffine.h - SCF to Affine Pass entrypoint ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_
#define MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_

#include <memory>

namespace mlir {
class Pass;
class RewritePatternSet;

#define GEN_PASS_DECL_RAISESCFTOAFFINEPASS
#include "mlir/Conversion/Passes.h.inc"

/// Collect a set of patterns to convert SCF operations to Affine operations.
void populateSCFToAffineConversionPatterns(RewritePatternSet &patterns);

} // namespace mlir

#endif // MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ add_subdirectory(OpenACCToSCF)
add_subdirectory(OpenMPToLLVM)
add_subdirectory(PDLToPDLInterp)
add_subdirectory(ReconcileUnrealizedCasts)
add_subdirectory(SCFToAffine)
add_subdirectory(SCFToControlFlow)
add_subdirectory(SCFToEmitC)
add_subdirectory(SCFToGPU)
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/SCFToAffine/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_mlir_conversion_library(MLIRSCFToAffine
SCFToAffine.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToAffine

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRAffineDialect
MLIRLLVMDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRTransforms
)
133 changes: 133 additions & 0 deletions mlir/lib/Conversion/SCFToAffine/SCFToAffine.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
//===- SCFToAffine.cpp - SCF to Affine conversion -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to raise scf.for, scf.if and loop.terminator
// ops into affine ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/SCFToAffine/SCFToAffine.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

namespace mlir {
#define GEN_PASS_DEF_RAISESCFTOAFFINEPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "raise-scf-to-affine"

using namespace mlir;

namespace {

struct SCFToAffinePass
: public impl::RaiseSCFToAffinePassBase<SCFToAffinePass> {
void runOnOperation() override;
};

bool canRaiseToAffine(scf::ForOp op) {
return affine::isValidDim(op.getLowerBound()) &&
affine::isValidDim(op.getUpperBound()) &&
affine::isValidSymbol(op.getStep());
}

struct ForOpRewrite : public OpRewritePattern<scf::ForOp> {
using OpRewritePattern<scf::ForOp>::OpRewritePattern;

std::pair<affine::AffineForOp, Value>
createAffineFor(scf::ForOp op, PatternRewriter &rewriter) const {
if (auto constantStep = op.getStep().getDefiningOp<arith::ConstantOp>()) {
int64_t step = cast<IntegerAttr>(constantStep.getValue()).getInt();
if (step > 0)
return positiveConstantStep(op, step, rewriter);
}
return genericBounds(op, rewriter);
}

std::pair<affine::AffineForOp, Value>
positiveConstantStep(scf::ForOp op, int64_t step,
PatternRewriter &rewriter) const {
auto affineFor = affine::AffineForOp::create(
rewriter, op.getLoc(), ValueRange(op.getLowerBound()),
AffineMap::get(1, 0, rewriter.getAffineDimExpr(0)),
ValueRange(op.getUpperBound()),
AffineMap::get(1, 0, rewriter.getAffineDimExpr(0)), step,
op.getInits());
return std::make_pair(affineFor, affineFor.getInductionVar());
}

std::pair<affine::AffineForOp, Value>
genericBounds(scf::ForOp op, PatternRewriter &rewriter) const {
Value lower = op.getLowerBound();
Value upper = op.getUpperBound();
Value step = op.getStep();
AffineExpr lowerExpr = rewriter.getAffineDimExpr(0);
AffineExpr upperExpr = rewriter.getAffineDimExpr(1);
AffineExpr stepExpr = rewriter.getAffineSymbolExpr(0);
auto affineFor = affine::AffineForOp::create(
rewriter, op.getLoc(), ValueRange(), rewriter.getConstantAffineMap(0),
ValueRange({lower, upper, step}),
AffineMap::get(
2, 1, (upperExpr - lowerExpr + stepExpr - 1).floorDiv(stepExpr)),
1, op.getInits());

rewriter.setInsertionPointToStart(affineFor.getBody());
auto actualIndexMap = AffineMap::get(
2, 1, lowerExpr + rewriter.getAffineDimExpr(1) * stepExpr);
auto actualIndex = affine::AffineApplyOp::create(
rewriter, op.getLoc(), actualIndexMap,
ValueRange({lower, affineFor.getInductionVar(), step}));
return std::make_pair(affineFor, actualIndex.getResult());
}

LogicalResult matchAndRewrite(scf::ForOp op,
PatternRewriter &rewriter) const override {
if (!canRaiseToAffine(op)) {
LLVM_DEBUG(llvm::dbgs()
<< "[affine] Cannot raise scf op: " << op << "\n");
return failure();
}

auto [affineFor, actualIndex] = createAffineFor(op, rewriter);
Block *affineBody = affineFor.getBody();

if (affineBody->mightHaveTerminator())
rewriter.eraseOp(affineBody->getTerminator());

SmallVector<Value> argValues;
argValues.push_back(actualIndex);
llvm::append_range(argValues, affineFor.getRegionIterArgs());
rewriter.inlineBlockBefore(op.getBody(), affineBody, affineBody->end(),
argValues);

auto scfYieldOp = cast<scf::YieldOp>(affineBody->getTerminator());
rewriter.setInsertionPointToEnd(affineBody);
rewriter.replaceOpWithNewOp<affine::AffineYieldOp>(
scfYieldOp, scfYieldOp->getOperands());

rewriter.replaceOp(op, affineFor);
return success();
}
};

} // namespace

void mlir::populateSCFToAffineConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ForOpRewrite>(patterns.getContext());
}

void SCFToAffinePass::runOnOperation() {
MLIRContext &ctx = getContext();
RewritePatternSet patterns(&ctx);
populateSCFToAffineConversionPatterns(patterns);
walkAndApplyPatterns(getOperation(), std::move(patterns));
}
88 changes: 88 additions & 0 deletions mlir/test/Conversion/SCFToAffine/scf-to-affine.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// RUN: mlir-opt -raise-scf-to-affine -split-input-file %s | FileCheck %s

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1)[s0] -> ((d1 - d0 + s0 - 1) floordiv s0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)>
// CHECK-LABEL: func.func @simple_loop(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: memref<3xindex>) {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
// CHECK: %[[VAL_4:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_1]]] : memref<3xindex>
// CHECK: %[[VAL_5:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_2]]] : memref<3xindex>
// CHECK: %[[VAL_6:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_3]]] : memref<3xindex>
// CHECK: affine.for %[[VAL_7:.*]] = 0 to #[[$ATTR_0]](%[[VAL_4]], %[[VAL_5]]){{\[}}%[[VAL_6]]] {
// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_4]], %[[VAL_7]]){{\[}}%[[VAL_6]]]
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_8]]] : memref<?xi32>
// CHECK: }
// CHECK: return
// CHECK: }

func.func @simple_loop(%arg0: memref<?xi32>, %arg1: memref<3xindex>) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2 = arith.constant 2 : index
%0 = memref.load %arg1[%c0] : memref<3xindex>
%1 = memref.load %arg1[%c1] : memref<3xindex>
%2 = memref.load %arg1[%c2] : memref<3xindex>
scf.for %arg2 = %0 to %1 step %2 {
memref.store %c0_i32, %arg0[%arg2] : memref<?xi32>
}
return
}

// -----

// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @loop_with_constant_step(
// CHECK-SAME: %[[ARG0:.*]]: memref<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: index) {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
// CHECK: affine.for %[[VAL_2:.*]] = #[[$ATTR_2]](%[[ARG1]]) to #[[$ATTR_2]](%[[ARG2]]) step 3 {
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_2]]] : memref<?xi32>
// CHECK: }
// CHECK: return
// CHECK: }

func.func @loop_with_constant_step(%arg0: memref<?xi32>, %arg1: index, %arg2: index) {
%c0_i32 = arith.constant 0 : i32
%c3 = arith.constant 3 : index
scf.for %arg3 = %arg1 to %arg2 step %c3 {
memref.store %c0_i32, %arg0[%arg3] : memref<?xi32>
}
return
}

// -----

// CHECK: #[[$ATTR_3:.+]] = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func.func @nested_loop(
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: index,
// CHECK-SAME: %[[ARG2:.*]]: index) {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
// CHECK: affine.for %[[VAL_3:.*]] = #[[$ATTR_3]](%[[VAL_1]]) to #[[$ATTR_3]](%[[ARG1]]) {
// CHECK: affine.for %[[VAL_4:.*]] = #[[$ATTR_3]](%[[VAL_1]]) to #[[$ATTR_3]](%[[ARG2]]) {
// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_3]], %[[VAL_4]]] : memref<?x?xi32>
// CHECK: }
// CHECK: }
// CHECK: return
// CHECK: }

func.func @nested_loop(%arg0: memref<?x?xi32>, %arg1: index, %arg2: index) {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
scf.for %arg3 = %c0 to %arg1 step %c1 {
scf.for %arg4 = %c0 to %arg2 step %c1 {
memref.store %c0_i32, %arg0[%arg3, %arg4] : memref<?x?xi32>
}
}
return
}