From 511fa2362c0f4eb4597fa7a190dd337ac69c974c Mon Sep 17 00:00:00 2001
From: hanhanW
Date: Wed, 17 Sep 2025 15:18:21 -0700
Subject: [PATCH 1/6] [mlir][linalg] Use ub.poison in data layout propagation
 if a packed operand requires padding.

In the past, it was hard to set padding values because we did not have
ub.poison, and using zeros as padding values is not always correct. Now
we can use `ub.poison` in this case. The revision adds support for
setting the padding value to `ub.poison` when padding is required in
the propagation. Otherwise, it creates an invalid pack op.

Signed-off-by: hanhanW
---
 .../Transforms/DataLayoutPropagation.cpp      |  9 +++--
 .../Linalg/data-layout-propagation.mlir       | 35 ++++++++++++++++---
 2 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 6c17c3c2d0cab..2d075d92017f2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/SetOperations.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -289,9 +290,11 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
 
   auto empty = linalg::PackOp::createDestinationTensor(
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
-  auto packedOperand = linalg::PackOp::create(
-      b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
-      /*padding=*/std::nullopt, outerDimsPerm);
+  auto poison = ub::PoisonOp::create(
+      b, loc, getElementTypeOrSelf(opOperand->get().getType()));
+  auto packedOperand =
+      linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
+                             innerTileSizes, poison, outerDimsPerm);
 
   return std::make_tuple(packedOperand, indexingMap);
 }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index a5f8d63a3e912..7a16bc0a4faee 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -1450,6 +1450,33 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar
 
 // -----
 
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @push_unpack_in_padded_domain_multiple_inputs(%arg0: tensor<1x4x16x16xf32>, %arg1: tensor<8x64xf32>, %arg2: tensor<8x64xf32>) -> tensor<8x64xf32> {
+  %0 = tensor.empty() : tensor<8x64xf32>
+  %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<1x4x16x16xf32> -> tensor<8x64xf32>
+  %1 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg1, %unpack : tensor<8x64xf32>, tensor<8x64xf32>) outs(%arg2 : tensor<8x64xf32>) {
+  ^bb0(%in: f32, %in_0: f32, %out: f32):
+    %2 = arith.addf %in, %in_0 : f32
+    linalg.yield %2 : f32
+  } -> tensor<8x64xf32>
+  return %1 : tensor<8x64xf32>
+}
+// CHECK-LABEL: func.func @push_unpack_in_padded_domain_multiple_inputs
+// CHECK-SAME:    %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME:    %[[ARG2:[a-zA-Z0-9]+]]
+// CHECK-DAG:     %[[POISON:.+]] = ub.poison : f32
+// CHECK:         %[[PACK:.+]] = linalg.pack %[[ARG1]] padding_value(%[[POISON]] : f32)
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [16, 16]
+// CHECK:         %[[ELEM:.+]] = linalg.generic
+// CHECK:           ins(%[[PACK]], %[[ARG0]]
+// CHECK:         %[[UNPACK:.+]] = linalg.unpack %[[ELEM]]
+// CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [16, 16]
+// CHECK-SAME:      into %[[ARG2]]
+// CHECK:         return %[[UNPACK]]
+
+// -----
+
 module {
   func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor, %arg2: tensor, %arg3: index) -> tensor {
     %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32>
@@ -1473,7 +1500,7 @@ module {
 // CHECK:         } : tensor to tensor
 // CHECK:         %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16>
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
-// CHECK-SAME:      ins(%[[ARG0]], %[[PADDED]] 
+// CHECK-SAME:      ins(%[[ARG0]], %[[PADDED]]
 // CHECK-SAME:      outs(%[[EMPTY]]
 // CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor
 // CHECK:         return %[[EXTRACT]]
@@ -1492,7 +1519,7 @@ func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32
 
 // CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
-// CHECK:         return %[[GENERIC]] 
+// CHECK:         return %[[GENERIC]]
 
 // -----
 
@@ -1508,7 +1535,7 @@ func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32
 
 // CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
-// CHECK:         return %[[GENERIC]] 
+// CHECK:         return %[[GENERIC]]
 
 // -----
 
@@ -1575,7 +1602,7 @@ func.func @push_extract_through_generic_rank0_operand(%arg0: tensor<128x128xf32>
 
 // CHECK-LABEL: func.func @push_extract_through_generic_rank0_operand
 // CHECK:         %[[GENERIC:.+]] = linalg.generic
-// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]] 
+// CHECK:         %[[EXTRACT:.+]] = tensor.extract_slice %[[GENERIC]]
 // CHECK:         return %[[EXTRACT]]
 
 // -----

From d9f526eba1e90eacbc2e2189393d1ea7c7ff0abf Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram
Date: Thu, 18 Sep 2025 12:11:57 -0700
Subject: [PATCH 2/6] add option to control poison padding

Signed-off-by: Nirvedh Meshram
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |   8 ++
 .../Dialect/Linalg/Transforms/Transforms.h    |   5 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |  29 +++++
 .../Transforms/DataLayoutPropagation.cpp      | 110 ++++++++++++------
 .../Linalg/TestDataLayoutPropagation.cpp      |   3 +-
 5 files changed, 119 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index f36b41ccf6745..ff9eccacf6278 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -239,6 +239,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
                                     ArrayRef<int64_t> outerDimsPerm,
                                     ArrayRef<OpFoldResult> innerTiles);
 
+    // Same as the above function, but here dynamic dimensions are assumed
+    // to require padding.
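+    // E.g., packing a dynamic dimension with a static inner tile is not
+    // flagged by `requirePaddingValue` (the runtime size may divide the
+    // tile evenly), but it is flagged here because divisibility cannot be
+    // proven statically.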
+    static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
+                                    ArrayRef<int64_t> innerDimsPos,
+                                    ArrayRef<int64_t> outputShape,
+                                    ArrayRef<int64_t> outerDimsPerm,
+                                    ArrayRef<OpFoldResult> innerTiles);
+
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
         ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 64d3a2448b409..41670249936e6 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1914,9 +1914,12 @@ void populateElementwiseOpsFusionPatterns(
 using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
 
 /// Patterns to bubble up or down data layout ops across other operations.
+/// The `poisonPaddingOk` option additionally allows the patterns to
+/// propagate with poison padding if requested by the caller.
 void populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
-    const ControlPropagationFn &controlPackUnPackPropagation);
+    const ControlPropagationFn &controlPackUnPackPropagation,
+    bool poisonPaddingOk = false);
 
 /// Patterns to sink extract slice across other operations.
 void populateExtractSliceSinkingPatterns(
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 578931e1351c6..0932bfe45916a 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5310,6 +5310,35 @@ bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
   return false;
 }
 
+bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
+                                       ArrayRef<int64_t> innerDimsPos,
+                                       ArrayRef<int64_t> outputShape,
+                                       ArrayRef<int64_t> outerDimsPerm,
+                                       ArrayRef<OpFoldResult> innerTiles) {
+  SmallVector<int64_t> outputTileSizes(
+      outputShape.take_front(inputShape.size()));
+  if (!outerDimsPerm.empty()) {
+    assert(outerDimsPerm.size() == outputTileSizes.size() &&
+           "expected output and outer_dims_perm to have same size");
+    applyPermutationToVector(outputTileSizes,
+                             invertPermutationVector(outerDimsPerm));
+  }
+  for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
+    if (ShapedType::isDynamic(inputShape[pos]))
+      return true;
+    std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
+
+    if (!constantTile) {
+      if (ShapedType::isStatic(outputTileSizes[pos]) &&
+          (inputShape[pos] % outputTileSizes[pos] != 0))
+        return true;
+    } else if (inputShape[pos] % (*constantTile) != 0) {
+      return true;
+    }
+  }
+  return false;
+}
+
 LogicalResult PackOp::verify() {
   if (failed(commonVerifierPackAndUnPackOp(*this)))
     return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 2d075d92017f2..e0926d9a566a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -221,9 +221,10 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
 ///     inner_dims_pos = [0]
 ///     inner_tiles = [8]
 ///     into %init : tensor<?xf32> -> tensor<?x8xf32>
-static std::tuple<Value, AffineMap>
+static FailureOr<std::tuple<Value, AffineMap>>
 getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
-                               GenericOp genericOp, OpOperand *opOperand) {
+                               GenericOp genericOp, OpOperand *opOperand,
+                               bool poisonPaddingOk) {
   int64_t numOrigLoops = genericOp.getNumLoops();
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -287,12 +288,24 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   // The operand does not have dimensions that relate to the pack op.
   if (innerDimsPos.empty() && outerDimsPerm.empty())
     return std::make_tuple(opOperand->get(), indexingMap);
-
+  auto inputType = cast<RankedTensorType>(opOperand->get().getType());
+  auto maybeIntInnerTileSizes = getConstantIntValues(innerTileSizes);
+  if (!maybeIntInnerTileSizes.has_value()) {
+    return failure();
+  }
+  if (!poisonPaddingOk &&
+      linalg::PackOp::requirePaddingValueStrict(
+          inputType.getShape(), innerDimsPos,
+          linalg::PackOp::inferPackedType(inputType, *maybeIntInnerTileSizes,
+                                          innerDimsPos, outerDimsPerm)
+              .getShape(),
+          outerDimsPerm, innerTileSizes))
+    return failure();
   auto empty = linalg::PackOp::createDestinationTensor(
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
   auto poison = ub::PoisonOp::create(
       b, loc, getElementTypeOrSelf(opOperand->get().getType()));
-  auto packedOperand =
+  Value packedOperand =
       linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
                              innerTileSizes, poison, outerDimsPerm);
   return std::make_tuple(packedOperand, indexingMap);
@@ -304,10 +317,10 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
 /// around it. Implicitly this will only work when a packInfo can be obtained.
 /// This makes sure that we are only using this function on parallel permuted
 /// dimensions.
-static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
-                               Value dest, AffineMap packedOutIndexingMap,
-                               const PackInfo &packInfo,
-                               bool isFoldableUnpackPack) {
+static FailureOr<GenericOp>
+packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
+              AffineMap packedOutIndexingMap, const PackInfo &packInfo,
+              bool isFoldableUnpackPack, bool poisonPaddingOk) {
   Location loc = genericOp.getLoc();
   SmallVector<Value> inputOperands;
   SmallVector<Value> inputOperandsFromUnpackedSource;
@@ -318,8 +331,13 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
     llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
   };
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
-    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
-        rewriter, loc, packInfo, genericOp, inputOperand);
+    auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
+        rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
+    if (failed(mayBepackedOperandAndIndexing)) {
+      return failure();
+    }
+    auto packedOperand = std::get<0>(*mayBepackedOperandAndIndexing);
+    auto packedIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
     auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
     auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
     if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
@@ -410,7 +428,8 @@ static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
 ///   } -> tensor<?x?x8x2xf32>
 static FailureOr<GenericOp>
 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
-                               const ControlPropagationFn &controlFn) {
+                               const ControlPropagationFn &controlFn,
+                               bool poisonPaddingOk) {
   auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
   if (!genericOp)
     return failure();
@@ -473,9 +492,14 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
   }
 
   // Rebuild the indexing map for the corresponding init operand.
-  auto [packedOutOperand, packedOutIndexingMap] =
+  auto mayBepackedOperandAndIndexing =
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
-                                     genericOp, opOperand);
+                                     genericOp, opOperand, poisonPaddingOk);
+  if (failed(mayBepackedOperandAndIndexing)) {
+    return failure();
+  }
+  auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
+  auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
 
   // Forward the new tensor.empty as a destination if it is one of the following
   // situations:
@@ -491,7 +515,8 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
   // pack(unpack) isn't naively foldable because the unpack op can be from
   // an arbitrary domain so we need to keep both.
   return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
-                       *packInfo, /*isFoldableUnpackPack=*/false);
+                       *packInfo, /*isFoldableUnpackPack=*/false,
+                       poisonPaddingOk);
 }
 
 /// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
@@ -499,13 +524,15 @@ struct BubbleUpPackOpThroughGenericOpPattern
     : public OpRewritePattern<linalg::PackOp> {
 public:
   BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
-                                        ControlPropagationFn fun)
-      : OpRewritePattern(context), controlFn(std::move(fun)) {}
+                                        ControlPropagationFn fun,
+                                        bool poisonPaddingOk)
+      : OpRewritePattern(context), controlFn(std::move(fun)),
+        poisonPaddingOk(poisonPaddingOk) {}
 
   LogicalResult matchAndRewrite(linalg::PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto genericOp =
-        bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
+    auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
+                                                    poisonPaddingOk);
     if (failed(genericOp))
       return failure();
     rewriter.replaceOp(packOp, genericOp->getResults());
@@ -514,6 +541,7 @@ struct BubbleUpPackOpThroughGenericOpPattern
 
 private:
   ControlPropagationFn controlFn;
+  bool poisonPaddingOk;
 };
 
 /// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
@@ -1083,7 +1111,8 @@ static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
 ///
 static FailureOr<std::tuple<GenericOp, Value>>
 pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
-                                 ControlPropagationFn controlFn) {
+                                 ControlPropagationFn controlFn,
+                                 bool poisonPaddingOk) {
   if (genericOp.getNumResults() != 1)
     return failure();
 
@@ -1110,9 +1139,14 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
     return failure();
 
   // Rebuild the indexing map for the corresponding init operand.
-  auto [packedOutOperand, packedOutIndexingMap] =
-      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
-                                     genericOp, genericOp.getDpsInitOperand(0));
+  auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
+      rewriter, genericOp.getLoc(), *packInfo, genericOp,
+      genericOp.getDpsInitOperand(0), poisonPaddingOk);
+  if (failed(mayBepackedOperandAndIndexing)) {
+    return failure();
+  }
+  auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
+  auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
   auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
 
   // Forward the new tensor.empty as a destination if it is one of the following
@@ -1132,9 +1166,12 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   // pack(unpack) is foldable in this case. This is because in pushing down the
   // unpack, by default we will populate an additional pack op after the unpack.
   // This guarantees them to be foldable.
-  GenericOp newGenericOp =
+  auto maybeGenericOp =
       packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
-                    /*isFoldableUnpackPack=*/true);
+                    /*isFoldableUnpackPack=*/true, poisonPaddingOk);
+  if (failed(maybeGenericOp))
+    return failure();
+  GenericOp newGenericOp = *maybeGenericOp;
   Value newResult =
       newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
 
@@ -1160,13 +1197,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 public:
   PushDownUnPackOpThroughGenericOp(MLIRContext *context,
-                                   ControlPropagationFn fun)
-      : OpRewritePattern(context), controlFn(std::move(fun)) {}
+                                   ControlPropagationFn fun,
+                                   bool poisonPaddingOk)
+      : OpRewritePattern(context), controlFn(std::move(fun)),
+        poisonPaddingOk(poisonPaddingOk) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    auto genericAndRepl =
-        pushDownUnPackOpThroughGenericOp(rewriter, genericOp, controlFn);
+    auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
+        rewriter, genericOp, controlFn, poisonPaddingOk);
     if (failed(genericAndRepl))
       return failure();
     rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
@@ -1175,6 +1214,7 @@ struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
 
 private:
   ControlPropagationFn controlFn;
+  bool poisonPaddingOk;
 };
 
 /// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
@@ -1525,12 +1565,14 @@ class PushDownExtractSliceOpThroughGenericOp final
 
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
-    const ControlPropagationFn &controlPackUnPackPropagation) {
-  patterns
-      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
-              BubbleUpPackOpThroughReshapeOp, PushDownUnPackOpThroughGenericOp,
-              PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
-          patterns.getContext(), controlPackUnPackPropagation);
+    const ControlPropagationFn &controlPackUnPackPropagation,
+    bool poisonPaddingOk) {
+  patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
+                  PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
+      patterns.getContext(), controlPackUnPackPropagation);
+  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
+                  PushDownUnPackOpThroughGenericOp>(
+      patterns.getContext(), controlPackUnPackPropagation, poisonPaddingOk);
 }
 
 void mlir::linalg::populateExtractSliceSinkingPatterns(
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
index d332270468ea8..d45aaf788f9c2 100644
--- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -33,7 +33,8 @@ struct TestDataLayoutPropagationPass
     MLIRContext *context = &getContext();
     RewritePatternSet patterns(context);
     linalg::populateDataLayoutPropagationPatterns(
-        patterns, [](OpOperand *opOperand) { return true; });
+        patterns, [](OpOperand *opOperand) { return true; },
+        /*poisonPaddingOk=*/true);
     linalg::ControlPropagationFn controlExtract =
         [](OpOperand *opOperand) -> bool {
       Operation *producer = opOperand->get().getDefiningOp();

From 5a5c53eb3f1a2791540c5da7647ade4651eab667 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram
Date: Fri, 19 Sep 2025 15:31:03 -0700
Subject: [PATCH 3/6] improve padding check

Signed-off-by: Nirvedh Meshram
---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 0932bfe45916a..49c2b54748c29 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -5324,17 +5324,14 @@ bool PackOp::requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
     applyPermutationToVector(outputTileSizes,
                              invertPermutationVector(outerDimsPerm));
   }
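+  // Padding is conservatively assumed whenever static divisibility cannot
+  // be proven, e.g. for a dynamic input dimension, a dynamic output tile
+  // size, or a non-constant inner tile.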
   for (auto [pos, tileSize] : llvm::zip_equal(innerDimsPos, innerTiles)) {
-    if (ShapedType::isDynamic(inputShape[pos]))
+    if (ShapedType::isDynamic(inputShape[pos]) ||
+        ShapedType::isDynamic(outputTileSizes[pos]))
       return true;
     std::optional<int64_t> constantTile = getConstantIntValue(tileSize);
-
-    if (!constantTile) {
-      if (ShapedType::isStatic(outputTileSizes[pos]) &&
-          (inputShape[pos] % outputTileSizes[pos] != 0))
-        return true;
-    } else if (inputShape[pos] % (*constantTile) != 0) {
+    if (!constantTile)
+      return true;
+    if (inputShape[pos] % (*constantTile) != 0)
       return true;
-    }
   }
   return false;
 }

From 15c9016d93db05fd25588fd2ad8a5b51b80d9acb Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram
Date: Tue, 23 Sep 2025 09:26:14 -0700
Subject: [PATCH 4/6] reviewer comments

Signed-off-by: Nirvedh Meshram
---
 .../Dialect/Linalg/IR/LinalgRelayoutOps.td    |   8 +-
 .../Transforms/DataLayoutPropagation.cpp      | 130 ++++++++++++------
 2 files changed, 92 insertions(+), 46 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
index ff9eccacf6278..3390f380c7eb8 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -242,10 +242,10 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Same as the above function, but here dynamic dimensions are assumed
    // to require padding.
     static bool requirePaddingValueStrict(ArrayRef<int64_t> inputShape,
-                                    ArrayRef<int64_t> innerDimsPos,
-                                    ArrayRef<int64_t> outputShape,
-                                    ArrayRef<int64_t> outerDimsPerm,
-                                    ArrayRef<OpFoldResult> innerTiles);
+                                          ArrayRef<int64_t> innerDimsPos,
+                                          ArrayRef<int64_t> outputShape,
+                                          ArrayRef<int64_t> outerDimsPerm,
+                                          ArrayRef<OpFoldResult> innerTiles);
 
     static Value createDestinationTensor(OpBuilder &b, Location loc,
         Value source, ArrayRef<OpFoldResult> innerTileSizes,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index e0926d9a566a6..ea761599270ec 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -221,10 +221,21 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
 ///     inner_dims_pos = [0]
 ///     inner_tiles = [8]
 ///     into %init : tensor<?xf32> -> tensor<?x8xf32>
-static FailureOr<std::tuple<Value, AffineMap>>
-getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
-                               GenericOp genericOp, OpOperand *opOperand,
-                               bool poisonPaddingOk) {
+
+struct PackedOperandDetails {
+  SmallVector<OpFoldResult> innerTileSizes;
+  SmallVector<int64_t> innerDimsPos;
+  SmallVector<int64_t> outerDimsPerm;
+  AffineMap indexingMap;
+};
+
+/// Helper function for getOrCreatePackedViewOfOperand that populates
+/// the details of the packedOperand that needs to be formed and also
+// returns if the packing would require padding.
+static bool getPackedOperandDetails(
+    OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
+    DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
+  PackedOperandDetails currOperandDetails;
   int64_t numOrigLoops = genericOp.getNumLoops();
   int64_t numInnerLoops = packInfo.getNumTiledLoops();
   int64_t numLoops = numOrigLoops + numInnerLoops;
@@ -233,9 +244,12 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
   SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
 
   // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
-  if (genericOp.isScalar(opOperand) || exprs.empty())
-    return std::make_tuple(opOperand->get(),
-                           AffineMap::get(numLoops, 0, exprs, b.getContext()));
+  if (genericOp.isScalar(opOperand) || exprs.empty()) {
+    currOperandDetails.indexingMap =
+        AffineMap::get(numLoops, 0, exprs, b.getContext());
+    packedOperandMap[opOperand] = currOperandDetails;
+    return false;
+  }
 
   // Step 1. Construct the information of packing data dimensions; append inner
   // dimensions to the indexing maps for the operand.
@@ -288,30 +302,55 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
       exprs = auxVec;
     }
   }
-  auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
+  currOperandDetails.indexingMap =
+      AffineMap::get(numLoops, 0, exprs, b.getContext());
 
   // The operand does not have dimensions that relate to the pack op.
-  if (innerDimsPos.empty() && outerDimsPerm.empty())
-    return std::make_tuple(opOperand->get(), indexingMap);
+  if (innerDimsPos.empty() && outerDimsPerm.empty()) {
+    packedOperandMap[opOperand] = currOperandDetails;
+    return false;
+  }
   auto inputType = cast<RankedTensorType>(opOperand->get().getType());
-  auto maybeIntInnerTileSizes = getConstantIntValues(innerTileSizes);
-  if (!maybeIntInnerTileSizes.has_value()) {
-    return failure();
-  }
-  if (!poisonPaddingOk &&
-      linalg::PackOp::requirePaddingValueStrict(
-          inputType.getShape(), innerDimsPos,
-          linalg::PackOp::inferPackedType(inputType, *maybeIntInnerTileSizes,
-                                          innerDimsPos, outerDimsPerm)
-              .getShape(),
-          outerDimsPerm, innerTileSizes))
-    return failure();
+
+  auto maybeIntInnerTileSizes =
+      llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t {
+        std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+        return maybeCst.value_or(ShapedType::kDynamic);
+      });
+  bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
+      inputType.getShape(), innerDimsPos,
+      linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes,
+                                      innerDimsPos, outerDimsPerm)
+          .getShape(),
+      outerDimsPerm, innerTileSizes);
+  currOperandDetails.innerDimsPos = innerDimsPos;
+  currOperandDetails.innerTileSizes = innerTileSizes;
+  currOperandDetails.outerDimsPerm = outerDimsPerm;
+  packedOperandMap[opOperand] = currOperandDetails;
+
+  if (requirePadding)
+    return true;
+  return false;
+}
+
+static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
+    OpBuilder &b, Location loc, OpOperand *opOperand,
+    DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap) {
+  assert(packedOperandMap.contains(opOperand) &&
+         "packed operand details expected to be populated");
+  auto currOperandDetails = packedOperandMap[opOperand];
+  auto innerDimsPos = currOperandDetails.innerDimsPos;
+  auto outerDimsPerm = currOperandDetails.outerDimsPerm;
+  auto innerTileSizes = currOperandDetails.innerTileSizes;
+  if (innerDimsPos.empty() && outerDimsPerm.empty()) {
+    return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);
+  }
   auto empty = linalg::PackOp::createDestinationTensor(
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
   auto poison = ub::PoisonOp::create(
       b, loc, getElementTypeOrSelf(opOperand->get().getType()));
   Value packedOperand =
       linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
                              innerTileSizes, poison, outerDimsPerm);
-  return std::make_tuple(packedOperand, indexingMap);
+  return std::make_tuple(packedOperand, currOperandDetails.indexingMap);
 }
 
 /// This function is a helper subroutine to pack a genericOp and return it. It
@@ -330,14 +369,18 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
     packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
     llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
   };
+  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
+  bool requiresPadding = false;
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
-    auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
-        rewriter, loc, packInfo, genericOp, inputOperand, poisonPaddingOk);
-    if (failed(mayBepackedOperandAndIndexing)) {
-      return failure();
-    }
-    auto packedOperand = std::get<0>(*mayBepackedOperandAndIndexing);
-    auto packedIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
+    requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
+                                               inputOperand, packedOperandMap);
+  }
+  if (requiresPadding && !poisonPaddingOk) {
+    return failure();
+  }
+  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
+        rewriter, loc, inputOperand, packedOperandMap);
     auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
     auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
     if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
@@ -492,15 +535,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
   }
 
   // Rebuild the indexing map for the corresponding init operand.
-  auto mayBepackedOperandAndIndexing =
-      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
-                                     genericOp, opOperand, poisonPaddingOk);
-  if (failed(mayBepackedOperandAndIndexing)) {
-    return failure();
-  }
-  auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
-  auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
-
+  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
+  bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
+                                                 opOperand, packedOperandMap);
+  if (requiresPadding && !poisonPaddingOk) {
+    return failure();
+  }
+  auto [packedOutOperand, packedOutIndexingMap] =
+      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
+                                     packedOperandMap);
   // Forward the new tensor.empty as a destination if it is one of the following
   // situations:
@@ -1182,14 +1225,17 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
     return failure();
 
   // Rebuild the indexing map for the corresponding init operand.
-  auto mayBepackedOperandAndIndexing = getOrCreatePackedViewOfOperand(
-      rewriter, genericOp.getLoc(), *packInfo, genericOp,
-      genericOp.getDpsInitOperand(0), poisonPaddingOk);
-  if (failed(mayBepackedOperandAndIndexing)) {
-    return failure();
-  }
-  auto packedOutOperand = std::get<0>(*mayBepackedOperandAndIndexing);
-  auto packedOutIndexingMap = std::get<1>(*mayBepackedOperandAndIndexing);
+  DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
+  bool requiresPadding =
+      getPackedOperandDetails(rewriter, *packInfo, genericOp,
+                              genericOp.getDpsInitOperand(0), packedOperandMap);
+  if (requiresPadding && !poisonPaddingOk) {
+    return failure();
+  }
+  auto [packedOutOperand, packedOutIndexingMap] =
+      getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
+                                     genericOp.getDpsInitOperand(0),
+                                     packedOperandMap);
   auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
 
   // Forward the new tensor.empty as a destination if it is one of the following

From 83c477721483b745899a89c00aa9b2f51f32fea4 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram
Date: Tue, 23 Sep 2025 15:16:57 -0700
Subject: [PATCH 5/6] further review comments

Signed-off-by: Nirvedh Meshram
---
 .../Transforms/DataLayoutPropagation.cpp      | 90 +++++++++----------
 1 file changed, 44 insertions(+), 46 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index ea761599270ec..cbb39dc0a4099 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -190,38 +190,6 @@ static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
   return outerDimsPerm;
 }
 
-/// Returns a tuple for packed operand and indexing_map with the assumptions:
-///   1) The generic op is the producer of the pack op.
-///   2) The generic op has only one result.
-/// If the operand is a scalar or packing dimensions are all irrelevant to the
-/// operand, the operand and the updated indexing map will be returned.
-/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
-///
-///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
-///   #map1 = affine_map<(d0, d1) -> (d0)>
-///   #map2 = affine_map<(d0, d1) -> (d1)>
-///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
-///                        iterator_types = ["parallel", "parallel"]}
-///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
-///       outs(%init : tensor<?x?xf32>) {
-///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
-///       %4 = arith.addf %arg3, %arg4 : f32
-///       linalg.yield %4 : f32
-///   } -> tensor<?x?xf32>
-///   %1 = linalg.pack %0
-///     inner_dims_pos = [0, 1]
-///     inner_tiles = [8, 2]
-///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
-///
-/// Taking the first input operand as an example, the inner tile size of d1 is
-/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> ->
-/// affine_map<(d1, d3)>` will be returned.
-///
-///   %pack = linalg.pack %arg0
-///     inner_dims_pos = [0]
-///     inner_tiles = [8]
-///     into %init : tensor<?xf32> -> tensor<?x8xf32>
-
 struct PackedOperandDetails {
   SmallVector<OpFoldResult> innerTileSizes;
   SmallVector<int64_t> innerDimsPos;
@@ -228,7 +196,7 @@ struct PackedOperandDetails {
 
 /// Helper function for getOrCreatePackedViewOfOperand that populates
 /// the details of the packedOperand that needs to be formed and also
-// returns if the packing would require padding.
+/// returns if the packing would require padding.
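+/// The computed details are cached in `packedOperandMap`, keyed by operand,
+/// so that getOrCreatePackedViewOfOperand can later materialize the
+/// `linalg.pack` ops without recomputing the layout information.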
 static bool getPackedOperandDetails(
     OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
     DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
@@ -328,23 +296,53 @@ static bool getPackedOperandDetails(
   currOperandDetails.outerDimsPerm = outerDimsPerm;
   packedOperandMap[opOperand] = currOperandDetails;
 
-  if (requirePadding)
-    return true;
-  return false;
+  return requirePadding;
 }
 
+/// Returns a tuple for packed operand and indexing_map with the assumptions:
+///   1) The generic op is the producer of the pack op.
+///   2) The generic op has only one result.
+/// If the operand is a scalar or packing dimensions are all irrelevant to the
+/// operand, the operand and the updated indexing map will be returned.
+/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
+///
+///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
+///   #map1 = affine_map<(d0, d1) -> (d0)>
+///   #map2 = affine_map<(d0, d1) -> (d1)>
+///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
+///                        iterator_types = ["parallel", "parallel"]}
+///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+///       outs(%init : tensor<?x?xf32>) {
+///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+///       %4 = arith.addf %arg3, %arg4 : f32
+///       linalg.yield %4 : f32
+///   } -> tensor<?x?xf32>
+///   %1 = linalg.pack %0
+///     inner_dims_pos = [0, 1]
+///     inner_tiles = [8, 2]
+///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///
+/// Taking the first input operand as an example, the inner tile size of d1 is
+/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> ->
+/// affine_map<(d1, d3)>` will be returned.
+///
+///   %pack = linalg.pack %arg0
+///     inner_dims_pos = [0]
+///     inner_tiles = [8]
+///     into %init : tensor<?xf32> -> tensor<?x8xf32>
+
 static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
     OpBuilder &b, Location loc, OpOperand *opOperand,
-    DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap) {
+    const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
   assert(packedOperandMap.contains(opOperand) &&
          "packed operand details expected to be populated");
-  auto currOperandDetails = packedOperandMap[opOperand];
+  auto currOperandDetails = packedOperandMap.at(opOperand);
   auto innerDimsPos = currOperandDetails.innerDimsPos;
   auto outerDimsPerm = currOperandDetails.outerDimsPerm;
   auto innerTileSizes = currOperandDetails.innerTileSizes;
-  if (innerDimsPos.empty() && outerDimsPerm.empty()) {
+  if (innerDimsPos.empty() && outerDimsPerm.empty())
     return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);
-  }
+
   auto empty = linalg::PackOp::createDestinationTensor(
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
   auto poison = ub::PoisonOp::create(
@@ -373,9 +371,9 @@ packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
     requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
                                                inputOperand, packedOperandMap);
   }
-  if (requiresPadding && !poisonPaddingOk) {
+  if (requiresPadding && !poisonPaddingOk)
     return failure();
-  }
+
   for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
     auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
         rewriter, loc, inputOperand, packedOperandMap);
@@ -536,9 +534,9 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
   DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
   bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
                                                  opOperand, packedOperandMap);
-  if (requiresPadding && !poisonPaddingOk) {
+  if (requiresPadding && !poisonPaddingOk)
     return failure();
-  }
+
   auto [packedOutOperand, packedOutIndexingMap] =
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
                                      packedOperandMap);
   // Forward the new tensor.empty as a destination if it is one of the following
@@ -1186,9 +1184,9 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   bool requiresPadding =
      getPackedOperandDetails(rewriter, *packInfo, genericOp,
                              genericOp.getDpsInitOperand(0), packedOperandMap);
-  if (requiresPadding && !poisonPaddingOk) {
+  if (requiresPadding && !poisonPaddingOk)
    return failure();
-  }
+
   auto [packedOutOperand, packedOutIndexingMap] =
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
                                      genericOp.getDpsInitOperand(0),
                                      packedOperandMap);

From ada610e6b8e773bdbada9ee1c79f8a8b95be144e Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram
Date: Tue, 23 Sep 2025 15:23:41 -0700
Subject: [PATCH 6/6] further review comments 2

Signed-off-by: Nirvedh Meshram
---
 mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index cbb39dc0a4099..3bb5f8af821c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -325,7 +325,6 @@ static bool getPackedOperandDetails(
 ///     inner_dims_pos = [0]
 ///     inner_tiles = [8]
 ///     into %init : tensor<?xf32> -> tensor<?x8xf32>
-
 static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
     OpBuilder &b, Location loc, OpOperand *opOperand,
     const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {