From ba4ae62b1ab03531b4bc94d08b7c5579d8fdf5c2 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Wed, 13 Nov 2024 23:44:29 +0000 Subject: [PATCH 1/2] [mlir][Affine] Expand affine.[de]linearize_index without affine maps As the documentation for -affine-expand-index-ops says, affine.delinearize_index and affine.linearize_index don't need to be expanded into the affine dialect. Expanding these operations into affine.apply operations can introduce unwanted "simplifications", mainly translations of `(dN mod C + ...)` to `(dN + ... - (dN floordiv C) * C)` and similar, which create worse generated code. This commit resolves this issue by expanding out affine.delanierize_index directly. In addition, the lowering of affine.linearize_index now sorts the operands by loop-independence, allowing an increased amount of loop-invariant code motion after lowering. The old behavior is preserved as -expand-affine-index-ops-as-affine but is no longer the default --- mlir/include/mlir/Dialect/Affine/LoopUtils.h | 5 + mlir/include/mlir/Dialect/Affine/Passes.h | 4 + mlir/include/mlir/Dialect/Affine/Passes.td | 5 + .../Dialect/Affine/Transforms/Transforms.h | 4 + .../Transforms/AffineExpandIndexOps.cpp | 148 ++++++++++++++++-- .../AffineExpandIndexOpsAsAffine.cpp | 98 ++++++++++++ .../Dialect/Affine/Transforms/CMakeLists.txt | 1 + mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 12 ++ .../AffineToStandard/lower-affine.mlir | 66 -------- .../affine-expand-index-ops-as-affine.mlir | 70 +++++++++ .../Affine/affine-expand-index-ops.mlir | 101 ++++++++---- 11 files changed, 405 insertions(+), 109 deletions(-) create mode 100644 mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp create mode 100644 mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h index 380c742b5224c..7fe1f6d48ceeb 100644 --- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h +++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h @@ -301,6 +301,11 @@ separateFullTiles(MutableArrayRef nest, /// Walk an affine.for to find a band to coalesce. LogicalResult coalescePerfectlyNestedAffineLoops(AffineForOp op); +/// Count the number of loops surrounding `operand` such that operand could be +/// hoisted above. +/// Stop counting at the first loop over which the operand cannot be hoisted. +/// This counts any LoopLikeOpInterface, not just affine.for. +int64_t numEnclosingInvariantLoops(OpOperand &operand); } // namespace affine } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h index 61f24255f305f..e152101236dc7 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -116,6 +116,10 @@ std::unique_ptr> createPipelineDataTransferPass(); /// operations (not necessarily restricted to Affine dialect). std::unique_ptr createAffineExpandIndexOpsPass(); +/// Creates a pass to expand affine index operations into affine.apply +/// operations. +std::unique_ptr createAffineExpandIndexOpsAsAffinePass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td index b08e803345f76..77073aa29da73 100644 --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -408,4 +408,9 @@ def AffineExpandIndexOps : Pass<"affine-expand-index-ops"> { let constructor = "mlir::affine::createAffineExpandIndexOpsPass()"; } +def AffineExpandIndexOpsAsAffine : Pass<"affine-expand-index-ops-as-affine"> { + let summary = "Lower affine operations operating on indices into affine.apply operations"; + let constructor = "mlir::affine::createAffineExpandIndexOpsAsAffinePass()"; +} + #endif // MLIR_DIALECT_AFFINE_PASSES diff --git a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h index b244d37c0707f..bf830a29613fd 100644 --- a/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Affine/Transforms/Transforms.h @@ -37,6 +37,10 @@ class AffineApplyOp; /// operations (not necessarily restricted to Affine dialect). void populateAffineExpandIndexOpsPatterns(RewritePatternSet &patterns); +/// Populate patterns that expand affine index operations into their equivalent +/// `affine.apply` representations. +void populateAffineExpandIndexOpsAsAffinePatterns(RewritePatternSet &patterns); + /// Helper function to rewrite `op`'s affine map and reorder its operands such /// that they are in increasing order of hoistability (i.e. the least hoistable) /// operands come first in the operand list. diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp index 15478e0e1e3a5..d7b218225bc9a 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -10,6 +10,7 @@ // fundamental operations. //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Affine/Passes.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -28,6 +29,50 @@ namespace affine { using namespace mlir; using namespace mlir::affine; +/// Given a basis (in static and dynamic components), return the sequence of +/// suffix products of the basis, including the product of the entire basis, +/// which must **not** contain an outer bound. +/// +/// If excess dynamic values are provided, the values at the beginning +/// will be ignored. This allows for dropping the outer bound without +/// needing to manipulate the dynamic value array. +static SmallVector computeStrides(Location loc, RewriterBase &rewriter, + ValueRange dynamicBasis, + ArrayRef staticBasis) { + if (staticBasis.empty()) + return {}; + + SmallVector result; + result.reserve(staticBasis.size()); + size_t dynamicIndex = dynamicBasis.size(); + Value dynamicPart = nullptr; + int64_t staticPart = 1; + for (int64_t elem : llvm::reverse(staticBasis)) { + if (ShapedType::isDynamic(elem)) { + if (dynamicPart) + dynamicPart = rewriter.create( + loc, dynamicPart, dynamicBasis[dynamicIndex - 1]); + else + dynamicPart = dynamicBasis[dynamicIndex - 1]; + --dynamicIndex; + } else { + staticPart *= elem; + } + + if (dynamicPart && staticPart == 1) { + result.push_back(dynamicPart); + } else { + Value stride = + rewriter.createOrFold(loc, staticPart); + if (dynamicPart) + stride = rewriter.create(loc, dynamicPart, stride); + result.push_back(stride); + } + } + std::reverse(result.begin(), result.end()); + return result; +} + namespace { /// Lowers `affine.delinearize_index` into a sequence of division and remainder /// operations. @@ -36,18 +81,62 @@ struct LowerDelinearizeIndexOps using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, PatternRewriter &rewriter) const override { - FailureOr> multiIndex = - delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), - op.getEffectiveBasis(), /*hasOuterBound=*/false); - if (failed(multiIndex)) - return failure(); - rewriter.replaceOp(op, *multiIndex); + Location loc = op.getLoc(); + Value linearIdx = op.getLinearIndex(); + unsigned numResults = op.getNumResults(); + ArrayRef staticBasis = op.getStaticBasis(); + if (numResults == staticBasis.size()) + staticBasis = staticBasis.drop_front(); + + if (numResults == 1) { + rewriter.replaceOp(op, linearIdx); + return success(); + } + + SmallVector results; + results.reserve(numResults); + SmallVector strides = + computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis); + + Value zero = rewriter.createOrFold(loc, 0); + + Value initialPart = + rewriter.create(loc, linearIdx, strides.front()); + results.push_back(initialPart); + + auto emitModTerm = [&](Value stride) -> Value { + Value remainder = rewriter.create(loc, linearIdx, stride); + Value remainderNegative = rewriter.create( + loc, arith::CmpIPredicate::slt, remainder, zero); + Value corrected = rewriter.create(loc, remainder, stride); + Value mod = rewriter.create(loc, remainderNegative, + corrected, remainder); + return mod; + }; + + // Generate all the intermediate parts + for (size_t i = 0, e = strides.size() - 1; i < e; ++i) { + Value thisStride = strides[i]; + Value nextStride = strides[i + 1]; + Value modulus = emitModTerm(thisStride); + // We know both inputs are positive, so floorDiv == div. + // This could potentially be a divui, but it's not clear if that would + // cause issues. + Value divided = rewriter.create(loc, modulus, nextStride); + results.push_back(divided); + } + + results.push_back(emitModTerm(strides.back())); + + rewriter.replaceOp(op, results); return success(); } }; /// Lowers `affine.linearize_index` into a sequence of multiplications and -/// additions. +/// additions. Make a best effort to sort the input indices so that +/// the most loop-invariant terms are at the left of the additions +/// to enable loop-invariant code motion. struct LowerLinearizeIndexOps final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, @@ -58,13 +147,44 @@ struct LowerLinearizeIndexOps final : OpRewritePattern { return success(); } - SmallVector multiIndex = - getAsOpFoldResult(op.getMultiIndex()); - OpFoldResult linearIndex = - linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis()); - Value linearIndexValue = - getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex); - rewriter.replaceOp(op, linearIndexValue); + Location loc = op.getLoc(); + ValueRange multiIndex = op.getMultiIndex(); + size_t numIndexes = multiIndex.size(); + ArrayRef staticBasis = op.getStaticBasis(); + if (numIndexes == staticBasis.size()) + staticBasis = staticBasis.drop_front(); + + SmallVector strides = + computeStrides(loc, rewriter, op.getDynamicBasis(), staticBasis); + SmallVector> scaledValues; + scaledValues.reserve(numIndexes); + + // Note: strides doesn't contain a value for the final element (stride 1) + // and everything else lines up. We use the "mutable" accessor so we can get + // our hands on an `OpOperand&` for the loop invariant counting function. + for (auto [stride, idxOp] : + llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) { + Value scaledIdx = + rewriter.create(loc, idxOp.get(), stride); + int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp); + scaledValues.emplace_back(scaledIdx, numHoistableLoops); + } + scaledValues.emplace_back( + multiIndex.back(), + numEnclosingInvariantLoops(op.getMultiIndexMutable()[numIndexes - 1])); + + // Sort by how many enclosing loops there are, ties implicitly broken by + // size of the stride. + llvm::stable_sort(scaledValues, + [&](auto l, auto r) { return l.second > r.second; }); + + Value result = scaledValues.front().first; + for (auto [scaledValue, numHoistableLoops] : + llvm::drop_begin(scaledValues)) { + std::ignore = numHoistableLoops; + result = rewriter.create(loc, result, scaledValue); + } + rewriter.replaceOp(op, result); return success(); } }; diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp new file mode 100644 index 0000000000000..bfcc1ddf91653 --- /dev/null +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOpsAsAffine.cpp @@ -0,0 +1,98 @@ +//===- AffineExpandIndexOpsAsAffine.cpp - Expand index ops to apply pass --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to expand affine index ops into one or more more +// fundamental operations. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/Passes.h" + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/Transforms/Transforms.h" +#include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +namespace affine { +#define GEN_PASS_DEF_AFFINEEXPANDINDEXOPSASAFFINE +#include "mlir/Dialect/Affine/Passes.h.inc" +} // namespace affine +} // namespace mlir + +using namespace mlir; +using namespace mlir::affine; + +namespace { +/// Lowers `affine.delinearize_index` into a sequence of division and remainder +/// operations. +struct LowerDelinearizeIndexOps + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AffineDelinearizeIndexOp op, + PatternRewriter &rewriter) const override { + FailureOr> multiIndex = + delinearizeIndex(rewriter, op->getLoc(), op.getLinearIndex(), + op.getEffectiveBasis(), /*hasOuterBound=*/false); + if (failed(multiIndex)) + return failure(); + rewriter.replaceOp(op, *multiIndex); + return success(); + } +}; + +/// Lowers `affine.linearize_index` into a sequence of multiplications and +/// additions. +struct LowerLinearizeIndexOps final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, + PatternRewriter &rewriter) const override { + // Should be folded away, included here for safety. + if (op.getMultiIndex().empty()) { + rewriter.replaceOpWithNewOp(op, 0); + return success(); + } + + SmallVector multiIndex = + getAsOpFoldResult(op.getMultiIndex()); + OpFoldResult linearIndex = + linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis()); + Value linearIndexValue = + getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex); + rewriter.replaceOp(op, linearIndexValue); + return success(); + } +}; + +class ExpandAffineIndexOpsAsAffinePass + : public affine::impl::AffineExpandIndexOpsAsAffineBase< + ExpandAffineIndexOpsAsAffinePass> { +public: + ExpandAffineIndexOpsAsAffinePass() = default; + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + populateAffineExpandIndexOpsAsAffinePatterns(patterns); + if (failed( + applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void mlir::affine::populateAffineExpandIndexOpsAsAffinePatterns( + RewritePatternSet &patterns) { + patterns.insert( + patterns.getContext()); +} + +std::unique_ptr mlir::affine::createAffineExpandIndexOpsAsAffinePass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt index 772f15335d907..c42789b01bc9f 100644 --- a/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Affine/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRAffineTransforms AffineDataCopyGeneration.cpp AffineExpandIndexOps.cpp + AffineExpandIndexOpsAsAffine.cpp AffineLoopInvariantCodeMotion.cpp AffineLoopNormalize.cpp AffineParallelize.cpp diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index d6fc4ed07bfab..e75d1c571d08c 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -2772,3 +2772,15 @@ LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) { } return result; } + +int64_t mlir::affine::numEnclosingInvariantLoops(OpOperand &operand) { + int64_t count = 0; + Operation *currentOp = operand.getOwner(); + while (auto loopOp = currentOp->getParentOfType()) { + if (!loopOp.isDefinedOutsideOfLoop(operand.get())) + break; + currentOp = loopOp; + count++; + } + return count; +} diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 3be42661f63ee..00d7b6b8d65f6 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -927,69 +927,3 @@ func.func @affine_parallel_with_reductions_i64(%arg0: memref<3x3xi64>, %arg1: me // CHECK: scf.reduce.return %[[RES]] : i64 // CHECK: } // CHECK: } - -/////////////////////////////////////////////////////////////////////// - -func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) { - %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index - return %1#0, %1#1, %1#2 : index, index, index -} -// CHECK-LABEL: func.func @test_dilinearize_index( -// CHECK-SAME: %[[VAL_0:.*]]: index) -> (index, index, index) { -// CHECK: %[[VAL_1:.*]] = arith.constant 224 : index -// CHECK: %[[VAL_2:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_3:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_5:.*]] = arith.constant -1 : index -// CHECK: %[[VAL_6:.*]] = arith.cmpi slt, %[[VAL_0]], %[[VAL_4]] : index -// CHECK: %[[VAL_7:.*]] = arith.subi %[[VAL_5]], %[[VAL_0]] : index -// CHECK: %[[VAL_8:.*]] = arith.select %[[VAL_6]], %[[VAL_7]], %[[VAL_0]] : index -// CHECK: %[[VAL_9:.*]] = arith.divsi %[[VAL_8]], %[[VAL_3]] : index -// CHECK: %[[VAL_10:.*]] = arith.subi %[[VAL_5]], %[[VAL_9]] : index -// CHECK: %[[VAL_11:.*]] = arith.select %[[VAL_6]], %[[VAL_10]], %[[VAL_9]] : index -// CHECK: %[[VAL_12:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_13:.*]] = arith.remsi %[[VAL_0]], %[[VAL_12]] : index -// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_15:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_14]] : index -// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_13]], %[[VAL_12]] : index -// CHECK: %[[VAL_17:.*]] = arith.select %[[VAL_15]], %[[VAL_16]], %[[VAL_13]] : index -// CHECK: %[[VAL_18:.*]] = arith.constant 50176 : index -// CHECK: %[[VAL_19:.*]] = arith.remsi %[[VAL_0]], %[[VAL_18]] : index -// CHECK: %[[VAL_20:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_21:.*]] = arith.cmpi slt, %[[VAL_19]], %[[VAL_20]] : index -// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_18]] : index -// CHECK: %[[VAL_23:.*]] = arith.select %[[VAL_21]], %[[VAL_22]], %[[VAL_19]] : index -// CHECK: %[[VAL_24:.*]] = arith.constant 224 : index -// CHECK: %[[VAL_25:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_26:.*]] = arith.constant -1 : index -// CHECK: %[[VAL_27:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_25]] : index -// CHECK: %[[VAL_28:.*]] = arith.subi %[[VAL_26]], %[[VAL_23]] : index -// CHECK: %[[VAL_29:.*]] = arith.select %[[VAL_27]], %[[VAL_28]], %[[VAL_23]] : index -// CHECK: %[[VAL_30:.*]] = arith.divsi %[[VAL_29]], %[[VAL_24]] : index -// CHECK: %[[VAL_31:.*]] = arith.subi %[[VAL_26]], %[[VAL_30]] : index -// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_27]], %[[VAL_31]], %[[VAL_30]] : index -// CHECK: %[[VAL_33:.*]] = arith.constant 224 : index -// CHECK: %[[VAL_34:.*]] = arith.remsi %[[VAL_0]], %[[VAL_33]] : index -// CHECK: %[[VAL_35:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_36:.*]] = arith.cmpi slt, %[[VAL_34]], %[[VAL_35]] : index -// CHECK: %[[VAL_37:.*]] = arith.addi %[[VAL_34]], %[[VAL_33]] : index -// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_34]] : index -// CHECK: return %[[VAL_11]], %[[VAL_32]], %[[VAL_38]] : index, index, index -// CHECK: } - -///////////////////////////////////////////////////////////////////// - -func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index { - %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index - return %ret : index -} - -// CHECK-LABEL: @test_linearize_index -// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index) -// CHECK: %[[c15:.+]] = arith.constant 15 : index -// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index -// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index -// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index -// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index -// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index -// CHECK-NEXT: return %[[ret]] diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir new file mode 100644 index 0000000000000..bf9f00da5793a --- /dev/null +++ b/mlir/test/Dialect/Affine/affine-expand-index-ops-as-affine.mlir @@ -0,0 +1,70 @@ +// RUN: mlir-opt %s -affine-expand-index-ops-as-affine -split-input-file | FileCheck %s + +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 floordiv 50176)> +// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> (s0 mod 224)> + +// CHECK-LABEL: @static_basis +// CHECK-SAME: (%[[IDX:.+]]: index) +// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[IDX]]] +// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[IDX]]] +// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]] +// CHECK: return %[[N]], %[[P]], %[[Q]] +func.func @static_basis(%linear_index: index) -> (index, index, index) { + %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index + return %1#0, %1#1, %1#2 : index, index, index +} + +// ----- + +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s2 floordiv (s0 * s1))> +// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) floordiv s1)> +// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) mod s1)> + +// CHECK-LABEL: @dynamic_basis +// CHECK-SAME: (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] : +// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] : +// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] +// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] +// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] +// CHECK: return %[[N]], %[[P]], %[[Q]] +func.func @dynamic_basis(%linear_index: index, %src: memref) -> (index, index, index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %b1 = memref.dim %src, %c1 : memref + %b2 = memref.dim %src, %c2 : memref + // Note: no outer bound. + %1:3 = affine.delinearize_index %linear_index into (%b1, %b2) : index, index, index + return %1#0, %1#1, %1#2 : index, index, index +} + +// ----- + +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)> + +// CHECK-LABEL: @linearize_static +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index) +// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]] +// CHECK: return %[[val_0]] +func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index { + %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index + func.return %0 : index +} + +// ----- + +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))> + +// CHECK-LABEL: @linearize_dynamic +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index) +// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg4]], %[[arg2]], %[[arg3]]] +// CHECK: return %[[val_0]] +func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index { + // Note: no outer bounds + %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index + func.return %0 : index +} diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir index 650555cfb5fe1..e4b1b98d1893d 100644 --- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir +++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir @@ -1,38 +1,48 @@ // RUN: mlir-opt %s -affine-expand-index-ops -split-input-file | FileCheck %s -// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 floordiv 50176)> -// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> ((s0 mod 50176) floordiv 224)> -// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0] -> (s0 mod 224)> - -// CHECK-LABEL: @static_basis +// CHECK-LABEL: @delinearize_static_basis // CHECK-SAME: (%[[IDX:.+]]: index) -// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[IDX]]] -// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[IDX]]] -// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[IDX]]] +// CHECK-DAG: %[[C224:.+]] = arith.constant 224 : index +// CHECK-DAG: %[[C50176:.+]] = arith.constant 50176 : index +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[C50176]] +// CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[C50176]] +// CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]] +// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[C50176]] +// CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]] +// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[C224]] +// CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[C224]] +// CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]] +// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[C224]] +// CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]] // CHECK: return %[[N]], %[[P]], %[[Q]] -func.func @static_basis(%linear_index: index) -> (index, index, index) { +func.func @delinearize_static_basis(%linear_index: index) -> (index, index, index) { %1:3 = affine.delinearize_index %linear_index into (16, 224, 224) : index, index, index return %1#0, %1#1, %1#2 : index, index, index } // ----- -// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s2 floordiv (s0 * s1))> -// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) floordiv s1)> -// CHECK-DAG: #[[$map2:.+]] = affine_map<()[s0, s1, s2] -> ((s2 mod (s0 * s1)) mod s1)> - -// CHECK-LABEL: @dynamic_basis +// CHECK-LABEL: @delinearize_dynamic_basis // CHECK-SAME: (%[[IDX:.+]]: index, %[[MEMREF:.+]]: memref +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] : -// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] : -// CHECK: %[[N:.+]] = affine.apply #[[$map0]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] -// CHECK: %[[P:.+]] = affine.apply #[[$map1]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] -// CHECK: %[[Q:.+]] = affine.apply #[[$map2]]()[%[[DIM1]], %[[DIM2]], %[[IDX]]] +// CHECK: %[[DIM1:.+]] = memref.dim %[[MEMREF]], %[[C1]] : +// CHECK: %[[DIM2:.+]] = memref.dim %[[MEMREF]], %[[C2]] : +// CHECK: %[[STRIDE1:.+]] = arith.muli %[[DIM2]], %[[DIM1]] +// CHECK: %[[N:.+]] = arith.floordivsi %[[IDX]], %[[STRIDE1]] +// CHECK-DAG: %[[P_REM:.+]] = arith.remsi %[[IDX]], %[[STRIDE1]] +// CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]] +// CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]] +// CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]] +// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[D2]] +// CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[DIM2]] +// CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]] +// CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]] +// CHECK: %[[Q:.+]] = arith.select %[[Q_NEG]], %[[Q_SHIFTED]], %[[Q_REM]] // CHECK: return %[[N]], %[[P]], %[[Q]] -func.func @dynamic_basis(%linear_index: index, %src: memref) -> (index, index, index) { - %c0 = arith.constant 0 : index +func.func @delinearize_dynamic_basis(%linear_index: index, %src: memref) -> (index, index, index) { %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %b1 = memref.dim %src, %c1 : memref @@ -44,12 +54,15 @@ func.func @dynamic_basis(%linear_index: index, %src: memref) -> (inde // ----- -// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)> - // CHECK-LABEL: @linearize_static // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index) -// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]] -// CHECK: return %[[val_0]] +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[C15:.+]] = arith.constant 15 : index +// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[C15]] +// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[C5]] +// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] +// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] +// CHECK: return %[[val_1]] func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index { %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index func.return %0 : index @@ -57,14 +70,44 @@ func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index { // ----- -// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))> - // CHECK-LABEL: @linearize_dynamic // CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index) -// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg4]], %[[arg2]], %[[arg3]]] -// CHECK: return %[[val_0]] +// CHECK: %[[stride_0:.+]] = arith.muli %[[arg4]], %[[arg3]] +// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg0]], %[[stride_0]] +// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg1]], %[[arg4]] +// CHECK: %[[val_0:.+]] = arith.addi %[[scaled_0]], %[[scaled_1]] +// CHECK: %[[val_1:.+]] = arith.addi %[[val_0]], %[[arg2]] +// CHECK: return %[[val_1]] func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) -> index { // Note: no outer bounds %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4) : index func.return %0 : index } + +// ----- + +// CHECK-LABEL: @linearize_sort_adds +// CHECK-SAME: (%[[arg0:.+]]: memref, %[[arg1:.+]]: index, %[[arg2:.+]]: index) +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK: scf.for %[[ARG3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} { +// CHECK: scf.for %[[ARG4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} { +// CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]] +// CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]] +// CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]] +// Note: even though %arg3 has a lower stride, we add it first +// CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]] +// CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]] +// CHECK: return %[[val_1]] +func.func @linearize_sort_adds(%arg0: memref, %arg1: index, %arg2: index) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c0_i32 = arith.constant 0 : i32 + scf.for %arg3 = %c0 to %arg2 step %c1 { + scf.for %arg4 = %c0 to %c4 step %c1 { + %idx = affine.linearize_index disjoint [%arg1, %arg4, %arg3] by (4, %arg2) : index + memref.store %c0_i32, %arg0[%idx] : memref + } + } + return +} From 0d15eec842c56b182ff9cefb2a0caf53216fa294 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 19 Nov 2024 19:20:07 +0000 Subject: [PATCH 2/2] Fix tests --- mlir/test/Dialect/Affine/affine-expand-index-ops.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir index e4b1b98d1893d..9bfaafb8c2468 100644 --- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir +++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir @@ -36,7 +36,7 @@ func.func @delinearize_static_basis(%linear_index: index) -> (index, index, inde // CHECK-DAG: %[[P_NEG:.+]] = arith.cmpi slt, %[[P_REM]], %[[C0]] // CHECK-DAG: %[[P_SHIFTED:.+]] = arith.addi %[[P_REM]], %[[STRIDE1]] // CHECK-DAG: %[[P_MOD:.+]] = arith.select %[[P_NEG]], %[[P_SHIFTED]], %[[P_REM]] -// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[D2]] +// CHECK: %[[P:.+]] = arith.divsi %[[P_MOD]], %[[DIM2]] // CHECK-DAG: %[[Q_REM:.+]] = arith.remsi %[[IDX]], %[[DIM2]] // CHECK-DAG: %[[Q_NEG:.+]] = arith.cmpi slt, %[[Q_REM]], %[[C0]] // CHECK-DAG: %[[Q_SHIFTED:.+]] = arith.addi %[[Q_REM]], %[[DIM2]] @@ -89,15 +89,15 @@ func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: in // CHECK-LABEL: @linearize_sort_adds // CHECK-SAME: (%[[arg0:.+]]: memref, %[[arg1:.+]]: index, %[[arg2:.+]]: index) // CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index -// CHECK: scf.for %[[ARG3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} { -// CHECK: scf.for %[[ARG4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} { +// CHECK: scf.for %[[arg3:.+]] = %{{.*}} to %[[arg2]] step %{{.*}} { +// CHECK: scf.for %[[arg4:.+]] = %{{.*}} to %[[C4]] step %{{.*}} { // CHECK: %[[stride_0:.+]] = arith.muli %[[arg2]], %[[C4]] // CHECK: %[[scaled_0:.+]] = arith.muli %[[arg1]], %[[stride_0]] // CHECK: %[[scaled_1:.+]] = arith.muli %[[arg4]], %[[arg2]] // Note: even though %arg3 has a lower stride, we add it first // CHECK: %[[val_0_2:.+]] = arith.addi %[[scaled_0]], %[[arg3]] // CHECK: %[[val_1:.+]] = arith.addi %[[val_0_2]], %[[scaled_1]] -// CHECK: return %[[val_1]] +// CHECK: memref.store %{{.*}}, %[[arg0]][%[[val_1]]] func.func @linearize_sort_adds(%arg0: memref, %arg1: index, %arg2: index) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index