diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h index 3dc48b2201cf2..2507ef2834dc5 100644 --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -58,6 +58,7 @@ #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h" #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" +#include "mlir/Conversion/SCFToAffine/SCFToAffine.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 6e1baaf23fcf7..38f35b2dadd94 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1025,6 +1025,18 @@ def ReconcileUnrealizedCastsPass : Pass<"reconcile-unrealized-casts"> { }]; } +//===----------------------------------------------------------------------===// +// SCFToAffine +//===----------------------------------------------------------------------===// + +def RaiseSCFToAffinePass : Pass<"raise-scf-to-affine"> { + let summary = "Raise SCF to affine ops"; + let dependentDialects = [ + "affine::AffineDialect", + "scf::SCFDialect", + ]; +} + //===----------------------------------------------------------------------===// // SCFToControlFlow //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h b/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h new file mode 100644 index 0000000000000..4f87ef8e6c6e4 --- /dev/null +++ b/mlir/include/mlir/Conversion/SCFToAffine/SCFToAffine.h @@ -0,0 +1,26 @@ +//===- SCFToAffine.h - SCF to Affine Pass entrypoint ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_ +#define MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_ + +#include + +namespace mlir { +class Pass; +class RewritePatternSet; + +#define GEN_PASS_DECL_RAISESCFTOAFFINEPASS +#include "mlir/Conversion/Passes.h.inc" + +/// Collect a set of patterns to convert SCF operations to Affine operations. +void populateSCFToAffineConversionPatterns(RewritePatternSet &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_SCFTOAFFINE_SCFTOAFFINE_H_ diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 785cb8293810c..b8059fcbfb028 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -51,6 +51,7 @@ add_subdirectory(OpenACCToSCF) add_subdirectory(OpenMPToLLVM) add_subdirectory(PDLToPDLInterp) add_subdirectory(ReconcileUnrealizedCasts) +add_subdirectory(SCFToAffine) add_subdirectory(SCFToControlFlow) add_subdirectory(SCFToEmitC) add_subdirectory(SCFToGPU) diff --git a/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt b/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt new file mode 100644 index 0000000000000..bf1494d6f3cf0 --- /dev/null +++ b/mlir/lib/Conversion/SCFToAffine/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRSCFToAffine + SCFToAffine.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToAffine + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArithDialect + MLIRAffineDialect + MLIRLLVMDialect + MLIRSCFDialect + MLIRSCFTransforms + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/SCFToAffine/SCFToAffine.cpp b/mlir/lib/Conversion/SCFToAffine/SCFToAffine.cpp new file mode 100644 index 0000000000000..e68bb2123cadc --- /dev/null +++ b/mlir/lib/Conversion/SCFToAffine/SCFToAffine.cpp @@ -0,0 +1,100 @@ +//===- SCFToAffine.cpp - SCF to Affine conversion -------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to raise scf.for, scf.if and loop.terminator +// ops into affine ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/SCFToAffine/SCFToAffine.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" + +namespace mlir { +#define GEN_PASS_DEF_RAISESCFTOAFFINEPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +struct SCFToAffinePass + : public impl::RaiseSCFToAffinePassBase { + void runOnOperation() override; +}; + +struct ForOpRewrite : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lower = op.getLowerBound(); + auto upper = op.getUpperBound(); + auto step = op.getStep(); + + if (!affine::isValidDim(lower) || !affine::isValidDim(upper) || + !affine::isValidSymbol(step)) + return llvm::failure(); + + auto lowerDim = rewriter.getAffineDimExpr(0); + auto upperDim = rewriter.getAffineDimExpr(1); + auto stepSym = rewriter.getAffineSymbolExpr(0); + auto affineFor = affine::AffineForOp::create( + rewriter, loc, ValueRange(), rewriter.getConstantAffineMap(0), + ValueRange({lower, upper, step}), + AffineMap::get(2, 1, + (upperDim - lowerDim + stepSym - 1).floorDiv(stepSym)), + 1, op.getInits()); + auto affineBody = affineFor.getBody(); + + if (affineBody->mightHaveTerminator()) + rewriter.eraseOp(affineBody->getTerminator()); + + rewriter.setInsertionPointToStart(affineBody); + auto actualIndexMap = + AffineMap::get(2, 1, lowerDim + rewriter.getAffineDimExpr(1) * stepSym); + Value newIndVar = + affine::AffineApplyOp::create( + rewriter, op.getLoc(), actualIndexMap, + ValueRange({lower, affineFor.getInductionVar(), step})) + .getResult(); + + SmallVector argValues; + argValues.push_back(newIndVar); + llvm::append_range(argValues, affineFor.getRegionIterArgs()); + rewriter.inlineBlockBefore(op.getBody(), affineBody, affineBody->end(), + argValues); + + auto scfYieldOp = cast(affineBody->getTerminator()); + rewriter.setInsertionPointToEnd(affineBody); + rewriter.replaceOpWithNewOp( + scfYieldOp, scfYieldOp->getOperands()); + + rewriter.replaceOp(op, affineFor); + return success(); + } +}; + +} // namespace + +void mlir::populateSCFToAffineConversionPatterns(RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + +void SCFToAffinePass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateSCFToAffineConversionPatterns(patterns); + + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/test/Conversion/SCFToAffine/scf-to-affine.mlir b/mlir/test/Conversion/SCFToAffine/scf-to-affine.mlir new file mode 100644 index 0000000000000..6f419bc8ee9ce --- /dev/null +++ b/mlir/test/Conversion/SCFToAffine/scf-to-affine.mlir @@ -0,0 +1,34 @@ +// RUN: mlir-opt -raise-scf-to-affine -split-input-file %s | FileCheck %s + +// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1, s2] -> ((s0 - s1 + s2 - 1) floordiv s0)> +// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1)[s0] -> (d0 + d1 * s0)> +// CHECK-LABEL: func.func @simple_loop( +// CHECK-SAME: %[[ARG0:.*]]: memref, +// CHECK-SAME: %[[ARG1:.*]]: memref<3xindex>) { +// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_4:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_1]]] : memref<3xindex> +// CHECK: %[[VAL_5:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_2]]] : memref<3xindex> +// CHECK: %[[VAL_6:.*]] = memref.load %[[ARG1]]{{\[}}%[[VAL_3]]] : memref<3xindex> +// CHECK: affine.for %[[VAL_7:.*]] = 0 to #[[$ATTR_0]](){{\[}}%[[VAL_6]], %[[VAL_4]], %[[VAL_5]]] { +// CHECK: %[[VAL_8:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_4]], %[[VAL_7]]){{\[}}%[[VAL_6]]] +// CHECK: memref.store %[[VAL_0]], %[[ARG0]]{{\[}}%[[VAL_8]]] : memref +// CHECK: } +// CHECK: return +// CHECK: } + +func.func @simple_loop(%arg0: memref, %arg1: memref<3xindex>) { + %c0_i32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = memref.load %arg1[%c0] : memref<3xindex> + %1 = memref.load %arg1[%c1] : memref<3xindex> + %2 = memref.load %arg1[%c2] : memref<3xindex> + scf.for %arg2 = %0 to %1 step %2 { + memref.store %c0_i32, %arg0[%arg2] : memref + } + return +}