diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h
index 96bd3c6a9a7bc..a592114db4325 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.h
+++ b/mlir/include/mlir/Dialect/Affine/Passes.h
@@ -117,6 +117,11 @@ std::unique_ptr<Pass> createAffineExpandIndexOpsPass();
 /// operations.
 std::unique_ptr<Pass> createAffineExpandIndexOpsAsAffinePass();
 
+/// Creates a pass that removes loops that never execute and replaces loops
+/// that execute exactly once with their body.
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createRemoveSingleIterationLoopPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td
index 728b8d25efcf2..f715ff6839fd5 100644
--- a/mlir/include/mlir/Dialect/Affine/Passes.td
+++ b/mlir/include/mlir/Dialect/Affine/Passes.td
@@ -229,6 +229,11 @@ def AffineLoopUnrollAndJam : InterfacePass<"affine-loop-unroll-jam", "FunctionOp
   ];
 }
 
+def RemoveSingleIterationLoop : InterfacePass<"remove-single-iteration-loop", "FunctionOpInterface"> {
+  let summary = "Remove affine.for loops with a single iteration.";
+  let constructor = "mlir::affine::createRemoveSingleIterationLoopPass()";
+}
+
 def AffinePipelineDataTransfer
     : Pass<"affine-pipeline-data-transfer", "func::FuncOp"> {
   let summary = "Pipeline non-blocking data transfers between explicitly "
diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
index bf830a29613fd..fb61f9333fac6 100644
--- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h
@@ -41,6 +41,10 @@ void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns);
 /// `affine.apply` representations.
 void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns);
 
+/// Populates `patterns` with a pattern that removes single-iteration loops.
+/// Such loops are detected using bounds proven via ValueBoundsOpInterface.
+void populateRemoveSingleIterationLoopPattern(RewritePatternSet &patterns);
+
 /// Helper function to rewrite `op`'s affine map and reorder its operands such
 /// that they are in increasing order of hoistability (i.e. the least hoistable)
 /// operands come first in the operand list.
diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
index c42789b01bc9f..317e28b542565 100644
--- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRAffineTransforms
   ReifyValueBounds.cpp
   SuperVectorize.cpp
   SimplifyAffineStructures.cpp
+  RemoveSingleIterationLoop.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Affine
diff --git a/mlir/lib/Dialect/Affine/Transforms/RemoveSingleIterationLoop.cpp b/mlir/lib/Dialect/Affine/Transforms/RemoveSingleIterationLoop.cpp
new file mode 100644
index 0000000000000..e8bbc2b6d8948
--- /dev/null
+++ b/mlir/lib/Dialect/Affine/Transforms/RemoveSingleIterationLoop.cpp
@@ -0,0 +1,158 @@
+//===- RemoveSingleIterationLoop.cpp --- remove single iteration loop ---*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace affine {
+#define GEN_PASS_DEF_REMOVESINGLEITERATIONLOOP
+#include "mlir/Dialect/Affine/Passes.h.inc"
+} // namespace affine
+} // namespace mlir
+
+#define DEBUG_TYPE "affine-remove-single-iteration"
+
+using namespace mlir;
+using namespace affine;
+
+/// Replaces `op` with the contents of the given single-block region. The block
+/// is inlined in front of `op` with its arguments replaced by `blockArgs`, and
+/// the operands of the block terminator replace the results of `op`.
+static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
+                                Region &region, ValueRange blockArgs = {}) {
+  assert(llvm::hasSingleElement(region) && "expected single-block region");
+  Block *block = &region.front();
+  Operation *terminator = block->getTerminator();
+  ValueRange results = terminator->getOperands();
+  rewriter.inlineBlockBefore(block, op, blockArgs);
+  rewriter.replaceOp(op, results);
+  rewriter.eraseOp(terminator);
+}
+
+/// Returns true if both bound maps of `op` have exactly one result. Loops with
+/// multi-result (`max`/`min`) bounds are conservatively skipped because the
+/// analysis below compares one lower-bound expression against one upper-bound
+/// expression and materializes the lower bound with a single `affine.apply`.
+static bool hasSingleResultBounds(AffineForOp op) {
+  return op.getLowerBoundMap().getNumResults() == 1 &&
+         op.getUpperBoundMap().getNumResults() == 1;
+}
+
+/// Returns true if we can prove that the first iteration of `op` always runs,
+/// i.e. that `lowerBound < upperBound`.
+static bool alwaysRunsFirstIteration(AffineForOp op) {
+  // Can't perform the analysis if the loop's bounds aren't index-typed.
+  if (!op.getInductionVar().getType().isIndex())
+    return false;
+  if (!hasSingleResultBounds(op))
+    return false;
+  SmallVector<Value> lowerMapOperands = op.getLowerBoundOperands();
+  SmallVector<Value> upperMapOperands = op.getUpperBoundOperands();
+  ValueBoundsConstraintSet::Variable lower(op.getLowerBoundMap(),
+                                           lowerMapOperands);
+  ValueBoundsConstraintSet::Variable upper(op.getUpperBoundMap(),
+                                           upperMapOperands);
+  FailureOr<bool> isLb = ValueBoundsConstraintSet::compare(
+      lower, ValueBoundsConstraintSet::LT, upper);
+  return isLb.value_or(false);
+}
+
+/// Returns true if we can prove that `op` never runs a second iteration, i.e.
+/// that `upperBound <= lowerBound + step`.
+static bool neverRunsSecondIteration(AffineForOp op) {
+  // Can't perform the analysis if the loop's bounds aren't index-typed.
+  if (!op.getInductionVar().getType().isIndex())
+    return false;
+  if (!hasSingleResultBounds(op))
+    return false;
+
+  // The induction variable of a second iteration would be `lowerBound + step`;
+  // express that value as an affine map over the lower-bound operands.
+  AffineMap lowerMap = op.getLowerBoundMap();
+  AffineExpr nextItExpr = lowerMap.getResult(0) + op.getStep().getSExtValue();
+  AffineMap nextItMap =
+      AffineMap::get(lowerMap.getNumDims(), lowerMap.getNumSymbols(),
+                     {nextItExpr}, op.getContext());
+  SmallVector<Value> lowerMapOperands = op.getLowerBoundOperands();
+  SmallVector<Value> upperMapOperands = op.getUpperBoundOperands();
+  ValueBoundsConstraintSet::Variable nextItVar(nextItMap, lowerMapOperands);
+  ValueBoundsConstraintSet::Variable upperVar(op.getUpperBoundMap(),
+                                              upperMapOperands);
+
+  // The second iteration is skipped iff `upperBound <= lowerBound + step`.
+  // Proving the reverse relation (`lowerBound + step <= upperBound`) would
+  // instead match every loop that runs two or more iterations.
+  FailureOr<bool> isUpperUnderNextIter = ValueBoundsConstraintSet::compare(
+      upperVar, ValueBoundsConstraintSet::LE, nextItVar);
+  return isUpperUnderNextIter.value_or(false);
+}
+
+namespace {
+
+/// Rewrite pattern that replaces an `affine.for` proven to execute exactly once
+/// with its body.
+struct SimplifyTrivialLoops : public OpRewritePattern<AffineForOp> {
+  using OpRewritePattern<AffineForOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(AffineForOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!alwaysRunsFirstIteration(op) || !neverRunsSecondIteration(op))
+      return failure();
+
+    // The first iteration always runs and the second one never does, so the
+    // loop has exactly one iteration. Materialize the lower bound, which is
+    // the value of the induction variable in that iteration, before the loop.
+    rewriter.setInsertionPoint(op);
+    Value lower = rewriter.create<AffineApplyOp>(
+        op.getLoc(), op.getLowerBoundMap(), op.getLowerBoundOperands());
+
+    // Inline the body in place of the loop: the induction variable maps to the
+    // lower bound and the iter_args map to the init values. The block
+    // arguments are replaced through the rewriter rather than by mutating
+    // uses directly.
+    SmallVector<Value> blockArgs;
+    blockArgs.reserve(op.getInits().size() + 1);
+    blockArgs.push_back(lower);
+    llvm::append_range(blockArgs, op.getInits());
+    replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
+    return success();
+  }
+};
+
+/// Pass that greedily applies the single-iteration loop removal pattern to the
+/// operation it runs on.
+struct RemoveSingleIterationLoop
+    : public affine::impl::RemoveSingleIterationLoopBase<
+          RemoveSingleIterationLoop> {
+  void runOnOperation() override {
+    auto funcOp = getOperation();
+    RewritePatternSet patterns(funcOp.getContext());
+    populateRemoveSingleIterationLoopPattern(patterns);
+    if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+
+} // namespace
+
+void mlir::affine::populateRemoveSingleIterationLoopPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<SimplifyTrivialLoops>(patterns.getContext());
+}
+
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+mlir::affine::createRemoveSingleIterationLoopPass() {
+  return std::make_unique<RemoveSingleIterationLoop>();
+}
diff --git a/mlir/test/Dialect/Affine/remove-single-iteration-loop.mlir b/mlir/test/Dialect/Affine/remove-single-iteration-loop.mlir
new file mode 100644
index 0000000000000..c8c321242f626
--- /dev/null
+++ b/mlir/test/Dialect/Affine/remove-single-iteration-loop.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(remove-single-iteration-loop))' -split-input-file | FileCheck %s
+
+// Loops with more than one iteration must be left untouched.
+// CHECK-LABEL: func @loop_many_iterations(
+func.func @loop_many_iterations(%arg : index) -> index {
+  %0 = affine.for %iv = 0 to 10 step 1 iter_args(%arg1 = %arg) -> index {
+    %sum = arith.addi %arg1, %iv : index
+    affine.yield %sum : index
+  }
+  return %0 : index
+}
+// CHECK: affine.for
+
+// -----
+
+// CHECK-LABEL: func @loop_once(
+func.func @loop_once(%arg : index) -> index {
+  %0 = affine.for %iv = 2 to 3 step 1 iter_args(%arg1 = %arg) -> index {
+    %sum = arith.addi %arg1, %iv : index
+    affine.yield %sum : index
+  }
+  return %0 : index
+}
+// CHECK-SAME: %[[ARG:.*]]: index)
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[SUM:.*]] = arith.addi %[[ARG]], %[[C2]] : index
+// CHECK: return %[[SUM]] : index
+
+// -----
+
+// CHECK-LABEL: func @invalid_loop(
+func.func @invalid_loop(%arg : index) -> index {
+  %0 = affine.for %iv = 4 to 3 step 1 iter_args(%arg1 = %arg) -> index {
+    %sum = arith.addi %arg1, %iv : index
+    affine.yield %sum : index
+  }
+  return %0 : index
+}
+// CHECK-SAME: %[[ARG:.*]]: index)
+// CHECK: return %[[ARG]] : index
+
+// -----
+
+// CHECK-LABEL: func @gpu_invalid_loop
+func.func @gpu_invalid_loop() {
+  %0 = arith.constant 0 : index
+  %1 = arith.constant 2 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%sz_bx = %1, %sz_by = %1, %sz_bz = %1)
+             threads(%tx, %ty, %tz) in (%sz_tx = %1, %sz_ty = %1, %sz_tz = %1) {
+    %threadid = gpu.thread_id x
+    affine.for %iv = %tx to 0 step 2 iter_args(%arg = %0) -> index {
+      %3 = arith.addi %arg, %0 : index
+      affine.yield %3 : index
+    }
+    gpu.terminator
+  }
+  // CHECK-NEXT: return
+  return
+}