diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index e9480d30c2d70..92fd6e99338ae 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -1113,4 +1113,73 @@ def AffineDelinearizeIndexOp : Affine_Op<"delinearize_index", let hasCanonicalizer = 1; } +//===----------------------------------------------------------------------===// +// AffineLinearizeIndexOp +//===----------------------------------------------------------------------===// +def AffineLinearizeIndexOp : Affine_Op<"linearize_index", + [Pure, AttrSizedOperandSegments]> { + let summary = "linearize an index"; + let description = [{ + The `affine.linearize_index` operation takes a sequence of index values and a + basis of the same length and linearizes the indices using that basis. + + That is, for indices `%idx_1` through `%idx_N` and basis elements `b_1` through `b_N`, + it computes + + ``` + sum(i = 1 to N) %idx_i * product(j = i + 1 to N) B_j + ``` + + If the `disjoint` property is present, this is an optimization hint that, + for all `i`, `0 <= %idx_i < B_i` - that is, no index affects any other index, + except that `%idx_0` may be negative to make the index as a whole negative. + + Note that the outputs of `affine.delinearize_index` are, by definition, `disjoint`. + + Example: + + ```mlir + %linear_index = affine.linearize_index [%index_0, %index_1, %index_2] (2, 3, 5) : index + ``` + + In the above example, `%linear_index` conceptually holds the following: + + ```mlir + #map = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)> + %linear_index = affine.apply #map()[%index_0, %index_1, %index_2] + ``` + }]; + + let arguments = (ins Variadic:$multi_index, + Variadic:$dynamic_basis, + DenseI64ArrayAttr:$static_basis, + UnitProperty:$disjoint); + let results = (outs Index:$linear_index); + + let assemblyFormat = [{ + (`disjoint` $disjoint^)? ` ` + `[` $multi_index `]` `by` ` ` + custom($dynamic_basis, $static_basis, "::mlir::AsmParser::Delimiter::Paren") + attr-dict `:` type($linear_index) + }]; + + let builders = [ + OpBuilder<(ins "ValueRange":$multi_index, "ValueRange":$basis, CArg<"bool", "false">:$disjoint)>, + OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef":$basis, CArg<"bool", "false">:$disjoint)>, + OpBuilder<(ins "ValueRange":$multi_index, "ArrayRef":$basis, CArg<"bool", "false">:$disjoint)> + ]; + + let extraClassDeclaration = [{ + /// Return a vector with all the static and dynamic basis values. + SmallVector getMixedBasis() { + OpBuilder builder(getContext()); + return ::mlir::getMixedValues(getStaticBasis(), getDynamicBasis(), builder); + } + + }]; + + let hasVerifier = 1; + let hasCanonicalizer = 1; +} + #endif // AFFINE_OPS diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index a2bf92323be01..0e98223969e08 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -315,12 +315,17 @@ FailureOr> delinearizeIndex(OpBuilder &b, Location loc, FailureOr> delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, ArrayRef basis); + // Generate IR that extracts the linear index from a multi-index according to // a basis/shape. OpFoldResult linearizeIndex(ArrayRef multiIndex, ArrayRef basis, ImplicitLocOpBuilder &builder); +OpFoldResult linearizeIndex(OpBuilder &builder, Location loc, + ArrayRef multiIndex, + ArrayRef basis); + /// Ensure that all operations that could be executed after `start` /// (noninclusive) and prior to `memOp` (e.g. on a control flow/op path /// between the operations) do not have the potential memory effect diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index f384f454bc472..ebdce522b05db 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -4684,6 +4684,115 @@ void affine::AffineDelinearizeIndexOp::getCanonicalizationPatterns( patterns.insert(context); } +//===----------------------------------------------------------------------===// +// LinearizeIndexOp +//===----------------------------------------------------------------------===// + +void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder, + OperationState &odsState, + ValueRange multiIndex, ValueRange basis, + bool disjoint) { + SmallVector dynamicBasis; + SmallVector staticBasis; + dispatchIndexOpFoldResults(getAsOpFoldResult(basis), dynamicBasis, + staticBasis); + build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint); +} + +void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder, + OperationState &odsState, + ValueRange multiIndex, + ArrayRef basis, + bool disjoint) { + SmallVector dynamicBasis; + SmallVector staticBasis; + dispatchIndexOpFoldResults(basis, dynamicBasis, staticBasis); + build(odsBuilder, odsState, multiIndex, dynamicBasis, staticBasis, disjoint); +} + +void AffineLinearizeIndexOp::build(OpBuilder &odsBuilder, + OperationState &odsState, + ValueRange multiIndex, + ArrayRef basis, bool disjoint) { + build(odsBuilder, odsState, multiIndex, ValueRange{}, basis, disjoint); +} + +LogicalResult AffineLinearizeIndexOp::verify() { + if (getStaticBasis().empty()) + return emitOpError("basis should not be empty"); + + if (getMultiIndex().size() != getStaticBasis().size()) + return emitOpError("should be passed an index for each basis element"); + + auto dynamicMarkersCount = + llvm::count_if(getStaticBasis(), ShapedType::isDynamic); + if (static_cast(dynamicMarkersCount) != getDynamicBasis().size()) + return emitOpError( + "mismatch between dynamic and static basis (kDynamic marker but no " + "corresponding dynamic basis entry) -- this can only happen due to an " + "incorrect fold/rewrite"); + + return success(); +} + +namespace { +/// Rewrite `affine.linearize_index disjoint [%...a, %x, %...b] by (%...c, 1, +/// %...d)` to `affine.linearize_index disjoint [%...a, %...b] by (%...c, +/// %...d)`. + +/// Note that `disjoint` is required here, because, without it, we could have +/// `affine.linearize_index [%...a, %c64, %...b] by (%...c, 1, %...d)` +/// is a valid operation where the `%c64` cannot be trivially dropped. +/// +/// Alternatively, if `%x` in the above is a known constant 0, remove it even if +/// the operation isn't asserted to be `disjoint`. +struct DropLinearizeUnitComponentsIfDisjointOrZero final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(affine::AffineLinearizeIndexOp op, + PatternRewriter &rewriter) const override { + size_t numIndices = op.getMultiIndex().size(); + SmallVector newIndices; + newIndices.reserve(numIndices); + SmallVector newBasis; + newBasis.reserve(numIndices); + + SmallVector basis = op.getMixedBasis(); + for (auto [index, basisElem] : llvm::zip_equal(op.getMultiIndex(), basis)) { + std::optional basisEntry = getConstantIntValue(basisElem); + if (!basisEntry || *basisEntry != 1) { + newIndices.push_back(index); + newBasis.push_back(basisElem); + continue; + } + + std::optional indexValue = getConstantIntValue(index); + if (!op.getDisjoint() && (!indexValue || *indexValue != 0)) { + newIndices.push_back(index); + newBasis.push_back(basisElem); + continue; + } + } + if (newIndices.size() == numIndices) + return failure(); + + if (newIndices.size() == 0) { + rewriter.replaceOpWithNewOp(op, 0); + return success(); + } + rewriter.replaceOpWithNewOp( + op, newIndices, newBasis, op.getDisjoint()); + return success(); + } +}; +} // namespace + +void affine::AffineLinearizeIndexOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(context); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp index d76968d3a7152..1930e987a33ff 100644 --- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -44,6 +45,23 @@ struct LowerDelinearizeIndexOps } }; +/// Lowers `affine.linearize_index` into a sequence of multiplications and +/// additions. +struct LowerLinearizeIndexOps final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AffineLinearizeIndexOp op, + PatternRewriter &rewriter) const override { + SmallVector multiIndex = + getAsOpFoldResult(op.getMultiIndex()); + OpFoldResult linearIndex = + linearizeIndex(rewriter, op.getLoc(), multiIndex, op.getMixedBasis()); + Value linearIndexValue = + getValueOrCreateConstantIntOp(rewriter, op.getLoc(), linearIndex); + rewriter.replaceOp(op, linearIndexValue); + return success(); + } +}; + class ExpandAffineIndexOpsPass : public affine::impl::AffineExpandIndexOpsBase { public: @@ -63,7 +81,8 @@ class ExpandAffineIndexOpsPass void mlir::affine::populateAffineExpandIndexOpsPatterns( RewritePatternSet &patterns) { - patterns.insert(patterns.getContext()); + patterns.insert( + patterns.getContext()); } std::unique_ptr mlir::affine::createAffineExpandIndexOpsPass() { diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 2680502bb687d..7fe422f75c8fa 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -1999,6 +1999,12 @@ mlir::affine::delinearizeIndex(OpBuilder &b, Location loc, Value linearIndex, OpFoldResult mlir::affine::linearizeIndex(ArrayRef multiIndex, ArrayRef basis, ImplicitLocOpBuilder &builder) { + return linearizeIndex(builder, builder.getLoc(), multiIndex, basis); +} + +OpFoldResult mlir::affine::linearizeIndex(OpBuilder &builder, Location loc, + ArrayRef multiIndex, + ArrayRef basis) { assert(multiIndex.size() == basis.size()); SmallVector basisAffine; for (size_t i = 0; i < basis.size(); ++i) { @@ -2009,13 +2015,13 @@ OpFoldResult mlir::affine::linearizeIndex(ArrayRef multiIndex, SmallVector strides; strides.reserve(stridesAffine.size()); llvm::transform(stridesAffine, std::back_inserter(strides), - [&builder, &basis](AffineExpr strideExpr) { + [&builder, &basis, loc](AffineExpr strideExpr) { return affine::makeComposedFoldedAffineApply( - builder, builder.getLoc(), strideExpr, basis); + builder, loc, strideExpr, basis); }); auto &&[linearIndexExpr, multiIndexAndStrides] = computeLinearIndex( OpFoldResult(builder.getIndexAttr(0)), strides, multiIndex); - return affine::makeComposedFoldedAffineApply( - builder, builder.getLoc(), linearIndexExpr, multiIndexAndStrides); + return affine::makeComposedFoldedAffineApply(builder, loc, linearIndexExpr, + multiIndexAndStrides); } diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 3781d510897f8..3be42661f63ee 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -976,3 +976,20 @@ func.func @test_dilinearize_index(%linear_index: index) -> (index, index, index) // CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[VAL_37]], %[[VAL_34]] : index // CHECK: return %[[VAL_11]], %[[VAL_32]], %[[VAL_38]] : index, index, index // CHECK: } + +///////////////////////////////////////////////////////////////////// + +func.func @test_linearize_index(%arg0: index, %arg1: index, %arg2: index) -> index { + %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (2, 3, 5) : index + return %ret : index +} + +// CHECK-LABEL: @test_linearize_index +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index) +// CHECK: %[[c15:.+]] = arith.constant 15 : index +// CHECK-NEXT: %[[tmp0:.+]] = arith.muli %[[arg0]], %[[c15]] : index +// CHECK-NEXT: %[[c5:.+]] = arith.constant 5 : index +// CHECK-NEXT: %[[tmp1:.+]] = arith.muli %[[arg1]], %[[c5]] : index +// CHECK-NEXT: %[[tmp2:.+]] = arith.addi %[[tmp0]], %[[tmp1]] : index +// CHECK-NEXT: %[[ret:.+]] = arith.addi %[[tmp2]], %[[arg2]] : index +// CHECK-NEXT: return %[[ret]] diff --git a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir index 95773206a521e..ded1687ca560b 100644 --- a/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir +++ b/mlir/test/Dialect/Affine/affine-expand-index-ops.mlir @@ -41,3 +41,29 @@ func.func @dynamic_basis(%linear_index: index, %src: memref) -> (inde %1:3 = affine.delinearize_index %linear_index into (%b0, %b1, %b2) : index, index, index return %1#0, %1#1, %1#2 : index, index, index } + +// ----- + +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2] -> (s0 * 15 + s1 * 5 + s2)> + +// CHECK-LABEL: @linearize_static +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index) +// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg2]]] +// CHECK: return %[[val_0]] +func.func @linearize_static(%arg0: index, %arg1: index, %arg2: index) -> index { + %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (2, 3, 5) : index + func.return %0 : index +} + +// ----- + +// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1, s2, s3, s4] -> (s1 * s2 + s3 + s0 * (s2 * s4))> + +// CHECK-LABEL: @linearize_dynamic +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index, %[[arg4:.+]]: index, %[[arg5:.+]]: index) +// CHECK: %[[val_0:.+]] = affine.apply #[[$map0]]()[%[[arg0]], %[[arg1]], %[[arg5]], %[[arg2]], %[[arg4]]] +// CHECK: return %[[val_0]] +func.func @linearize_dynamic(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> index { + %0 = affine.linearize_index [%arg0, %arg1, %arg2] by (%arg3, %arg4, %arg5) : index + func.return %0 : index +} diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir index d78c3b667589b..aaa14fc873f25 100644 --- a/mlir/test/Dialect/Affine/canonicalize.mlir +++ b/mlir/test/Dialect/Affine/canonicalize.mlir @@ -1530,3 +1530,37 @@ func.func @delinearize_non_loop_like(%arg0: memref, %i : index) -> index %2 = affine.delinearize_index %i into (1024) : index return %2 : index } + +// ----- + +// CHECK-LABEL: @linearize_unit_basis_disjoint +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index, %[[arg3:.+]]: index) +// CHECK: %[[ret:.+]] = affine.linearize_index disjoint [%[[arg0]], %[[arg2]]] by (3, %[[arg3]]) : index +// CHECK: return %[[ret]] +func.func @linearize_unit_basis_disjoint(%arg0: index, %arg1: index, %arg2: index, %arg3: index) -> index { + %ret = affine.linearize_index disjoint [%arg0, %arg1, %arg2] by (3, 1, %arg3) : index + return %ret : index +} + +// ----- + +// CHECK-LABEL: @linearize_unit_basis_zero +// CHECK-SAME: (%[[arg0:.+]]: index, %[[arg1:.+]]: index, %[[arg2:.+]]: index) +// CHECK: %[[ret:.+]] = affine.linearize_index [%[[arg0]], %[[arg1]]] by (3, %[[arg2]]) : index +// CHECK: return %[[ret]] +func.func @linearize_unit_basis_zero(%arg0: index, %arg1: index, %arg2: index) -> index { + %c0 = arith.constant 0 : index + %ret = affine.linearize_index [%arg0, %c0, %arg1] by (3, 1, %arg2) : index + return %ret : index +} + +// ----- + +// CHECK-LABEL: @linearize_all_zero_unit_basis +// CHECK: arith.constant 0 : index +// CHECK-NOT: affine.linearize_index +func.func @linearize_all_zero_unit_basis() -> index { + %c0 = arith.constant 0 : index + %ret = affine.linearize_index [%c0, %c0] by (1, 1) : index + return %ret : index +} diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir index 869ea712bb369..2996194170900 100644 --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -548,6 +548,22 @@ func.func @delinearize(%idx: index, %basis0: index, %basis1 :index) { // ----- +func.func @linearize(%idx: index, %basis0: index, %basis1 :index) -> index { + // expected-error@+1 {{'affine.linearize_index' op should be passed an index for each basis element}} + %0 = affine.linearize_index [%idx] by (%basis0, %basis1) : index + return %0 : index +} + +// ----- + +func.func @linearize_empty() -> index { + // expected-error@+1 {{'affine.linearize_index' op basis should not be empty}} + %0 = affine.linearize_index [] by () : index + return %0 : index +} + +// ----- + func.func @dynamic_dimension_index() { "unknown.region"() ({ %idx = "unknown.test"() : () -> (index) diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir index 52ae53adcea9f..1d1db5f58f54c 100644 --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -282,3 +282,19 @@ func.func @delinearize_mixed(%linear_idx: index, %basis1: index) -> (index, inde %1:3 = affine.delinearize_index %linear_idx into (2, %basis1, 3) : index, index, index return %1#0, %1#1, %1#2 : index, index, index } + +// ----- + +// CHECK-LABEL: func @linearize +func.func @linearize(%index0: index, %index1: index, %basis0: index, %basis1 :index) -> index { + // CHECK: affine.linearize_index [%{{.+}}, %{{.+}}] by (%{{.+}}, %{{.+}}) : index + %1 = affine.linearize_index [%index0, %index1] by (%basis0, %basis1) : index + return %1 : index +} + +// CHECK-LABEL: @linearize_mixed +func.func @linearize_mixed(%index0: index, %index1: index, %index2: index, %basis1: index) -> index { + // CHECK: affine.linearize_index disjoint [%{{.+}}, %{{.+}}, %{{.+}}] by (2, %{{.+}}, 3) : index + %1 = affine.linearize_index disjoint [%index0, %index1, %index2] by (2, %basis1, 3) : index + return %1 : index +}