From 4eebe2174cc773b213a2f512b7405e14174c4714 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Fri, 8 Aug 2025 14:44:54 -0700 Subject: [PATCH 1/3] [Linalg] Add pattern to push down extract slice through generic Signed-off-by: Nirvedh Meshram --- .../Dialect/Linalg/Transforms/Transforms.h | 5 + .../Transforms/DataLayoutPropagation.cpp | 272 ++++++++++++++++++ .../Linalg/data-layout-propagation.mlir | 110 +++++++ .../Linalg/TestDataLayoutPropagation.cpp | 2 + 4 files changed, 389 insertions(+) diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 8d5306dca43e3..680fdffa9e587 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1918,6 +1918,11 @@ void populateDataLayoutPropagationPatterns( RewritePatternSet &patterns, const ControlPropagationFn &controlPackUnPackPropagation); +/// Patterns to sink extract slice across other operations. +void populateExtractSliceSinkingPatterns( + RewritePatternSet &patterns, + const ControlPropagationFn &controlPackUnPackPropagation); + /// Pattern to remove dead operands and results of `linalg.generic` operations. /// This is a pattern wrapper for `deduplicateOperandsAndRemoveDeadResults`. void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 0a9c1766425bd..d50ab8cf03714 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -6,10 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Dominance.h" #include "llvm/ADT/SetOperations.h" @@ -1236,6 +1238,269 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { ControlPropagationFn controlFn; }; +// This struct contains infomation about extract_slice dims. +struct SliceDimInfo { + OpFoldResult offset; + OpFoldResult sliceSize; + OpFoldResult outputSize; +}; + +/// Return the first input extract slice operand, if present, for the current +/// generic op. +static FailureOr> +getSliceOperandAndIndex(GenericOp genericOp) { + OpOperand *sliceOperand = nullptr; + unsigned operandIndex; + for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { + auto extractOp = operand->get().getDefiningOp(); + if (!extractOp) + continue; + sliceOperand = operand; + operandIndex = idx; + break; + } + if (!sliceOperand) { + return failure(); + } + return std::make_tuple(sliceOperand, operandIndex); +} + +// Return a map of dims that have non full slices on them so that other operands +// can use this information. Also return a bool mentioning if a reduction dim +// has a non full slice as that can be used to fold the original extract slice. +static FailureOr, bool>> +getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand, + tensor::ExtractSliceOp producerSliceOp) { + llvm::DenseMap nonZeroSliceDimMap; + bool hasNonZeroReductionDimSlice = false; + SmallVector iterators = + genericOp.getIteratorTypesArray(); + SmallVector offsets = producerSliceOp.getMixedOffsets(); + SmallVector sizes = producerSliceOp.getMixedSizes(); + + SmallVector shape = llvm::map_to_vector( + producerSliceOp.getSourceType().getShape(), + [&](int64_t sz) -> OpFoldResult { + return getAsIndexOpFoldResult(genericOp.getContext(), sz); + }); + + for (auto [idx, expr] : llvm::enumerate( + genericOp.getMatchingIndexingMap(sliceOperand).getResults())) { + if (isConstantIntValue(offsets[idx], 0) && + isEqualConstantIntOrValue(sizes[idx], shape[idx])) { + continue; + } + if (!isa(expr)) { + return failure(); + } + SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]}; + int64_t dimPos = cast(expr).getPosition(); + nonZeroSliceDimMap[dimPos] = sliceDimInfo; + if (iterators[dimPos] == utils::IteratorType::reduction) { + hasNonZeroReductionDimSlice = true; + } + } + // Next check if the dims with non zero slice info are used as non + // AffineDimExpr and if they are then bail-out. + for (OpOperand &operand : genericOp->getOpOperands()) { + if (operand == *sliceOperand) { + continue; + } + AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand); + if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) { + if (isa(expr)) { + return false; + } + WalkResult status = expr.walk([&](AffineExpr expr) { + if (auto dimExpr = dyn_cast(expr)) { + if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + if (status.wasInterrupted()) { + return true; + } + return false; + })) { + return failure(); + } + } + return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice); +} + +static FailureOr> +pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, + GenericOp genericOp, + ControlPropagationFn controlFn) { + if (genericOp.getNumResults() != 1) + return failure(); + if (hasGatherSemantics(genericOp)) + return failure(); + // Collect the unPacked operand, if present. + auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp); + if (failed(maybeSliceOperandAndIndex)) + return failure(); + OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex); + unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex); + + if (!controlFn(sliceOperand)) + return failure(); + + tensor::ExtractSliceOp producerSliceOp = + sliceOperand->get().getDefiningOp(); + assert(producerSliceOp && "expect a valid UnPackOp"); + + if (producerSliceOp.getSource().getType().getRank() != + producerSliceOp.getResult().getType().getRank()) { + return failure(); + } + + SmallVector strides = producerSliceOp.getMixedStrides(); + if (!areAllConstantIntValue(strides, 1)) + return failure(); + + SmallVector offsets = producerSliceOp.getMixedOffsets(); + SmallVector sizes = producerSliceOp.getMixedSizes(); + + // check if we can support the propagation of this extractSlice + // through the generic op and if so return the dimensions that + + auto maybeNonZeroSliceDimMap = + getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp); + + if (failed(maybeNonZeroSliceDimMap)) { + return failure(); + } + + auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap); + bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap); + + // Store the padding information as (dimPos, lowPad, highPad, PaddedShape). + Location loc = genericOp->getLoc(); + AffineExpr dim0, dim1; + bindDims(rewriter.getContext(), dim0, dim1); + auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); + auto sub = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap, + {v1, v2}); + }; + + MLIRContext *ctx = genericOp.getContext(); + SmallVector paddedInputs; + for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { + if (idx == OperandIndex && !hasNonZeroReductionDimSlice) { + paddedInputs.push_back(producerSliceOp.getSource()); + continue; + } + AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand); + SmallVector operandLowPads(IndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector operandHighPads(IndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) { + if (!isa(expr)) { + continue; + } + AffineDimExpr dimExpr = cast(expr); + if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) { + SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()]; + operandLowPads[idx] = sliceDimInfo.offset; + operandHighPads[idx] = + sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + } + } + auto paddingValue = ub::PoisonOp::create( + rewriter, loc, getElementTypeOrSelf(operand->get().getType())); + auto paddedOperand = tensor::PadOp::create( + rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads, + paddingValue, /*nofold=*/false); + paddedInputs.push_back(paddedOperand); + } + AffineMap outputIndexingMap = + genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0)); + + auto outputShapeType = + llvm::cast(genericOp.getDpsInitOperand(0)->get().getType()); + SmallVector OutputShape = llvm::map_to_vector( + outputShapeType.getShape(), + [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); }); + SmallVector newSizes = OutputShape; + SmallVector outputLowPads(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector outputHighPads(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector newStrides(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 1)); + for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) { + if (!isa(expr)) { + continue; + } + AffineDimExpr dimExpr = cast(expr); + if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) { + SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()]; + outputLowPads[idx] = sliceDimInfo.offset; + outputHighPads[idx] = + sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + OutputShape[idx] = sliceDimInfo.outputSize; + newSizes[idx] = sliceDimInfo.sliceSize; + } + } + Value newPadOutput; + auto outputElType = + getElementTypeOrSelf(genericOp.getDpsInits()[0].getType()); + if (isGenericOutsNotUsed(genericOp)) { + newPadOutput = + tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType); + + } else { + + auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType); + newPadOutput = tensor::PadOp::create( + rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads, + outputHighPads, paddingValue, /*nofold=*/false); + } + + auto newGenericOp = linalg::GenericOp::create( + rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput}, + genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(), + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); + rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), + newGenericOp.getRegion().begin()); + + auto extractOp = tensor::ExtractSliceOp::create( + rewriter, loc, + newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)), + outputLowPads, newSizes, newStrides); + Value extractRes = extractOp.getResult(); + + return std::make_tuple(newGenericOp, extractRes); +} + +class PushDownExtractSliceOpThroughGenericOp final + : public OpRewritePattern { +public: + PushDownExtractSliceOpThroughGenericOp(MLIRContext *context, + ControlPropagationFn fun) + : OpRewritePattern(context), controlFn(std::move(fun)) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + auto genericAndRepl = + pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn); + if (failed(genericAndRepl)) + return failure(); + rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); + return success(); + } + +private: + ControlPropagationFn controlFn; +}; + } // namespace void mlir::linalg::populateDataLayoutPropagationPatterns( @@ -1247,3 +1512,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns( PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( patterns.getContext(), controlPackUnPackPropagation); } + +void mlir::linalg::populateExtractSliceSinkingPatterns( + RewritePatternSet &patterns, + const ControlPropagationFn &controlPackUnPackPropagation) { + patterns.insert( + patterns.getContext(), controlPackUnPackPropagation); +} diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index cc26fa48abf4b..723eecb52351b 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1447,3 +1447,113 @@ func.func @push_unpack_in_padded_domain_out_used(%arg0: tensor<8x8x4x8xf32>, %ar // CHECK: %[[UNPACK2:.+]] = linalg.unpack %[[GENERIC]] // CHECK-SAME: into %[[ARG1]] // CHECK: return %[[UNPACK2]] : tensor + +// ----- + +module { + func.func @push_extract_through_generic(%arg0: tensor<128x7x128xf32>, %arg1: tensor, %arg2: tensor, %arg3: index) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[0, 0, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32> + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f32, %in_0: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } -> tensor + return %0 : tensor + } +} + +// CHECK-LABEL: func.func @push_extract_through_generic +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]] +// CHECK: %[[POISON:.+]] = ub.poison : f32 +// CHECK: %[[PADDED:.+]] = tensor.pad %arg1 +// CHECK: tensor.yield %[[POISON]] : f32 +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]], %[[PADDED]] +// CHECK-SAME: outs(%[[EMPTY]] +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %3[%[[ARG3]], 0, 0] [%[[ARG3]], 5, 128] [1, 1, 1] : tensor<128x5x128xbf16> to tensor +// CHECK: return %[[EXTRACT]] + +// ----- + +func.func @nopush_extract_through_generic_nodimexpr1(%arg0: tensor<128x7x128xf32>, %arg1: tensor, %arg2: tensor, %arg3: index) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[0, %arg3, %arg3] [128, 7, %arg3] [1, 1, 1] : tensor<128x7x128xf32> to tensor<128x7x?xf32> + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%extracted_slice, %arg1 : tensor<128x7x?xf32>, tensor) outs(%arg2 : tensor) { + ^bb0(%in: f32, %in_0: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr1 +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: return %[[GENERIC]] + +// ----- + +func.func @nopush_extract_through_generic_nodimexpr2(%arg0: tensor<128x?x128xf32>, %arg1: tensor<128x5x3x128xf32>, %arg2: tensor<128x?x128xbf16>, %arg3: index) -> tensor<128x?x128xbf16> { + %extracted_slice = tensor.extract_slice %arg1[0, %arg3, 0, 0] [128, %arg3, 3, 128] [1, 1, 1, 1] : tensor<128x5x3x128xf32> to tensor<128x?x3x128xf32> + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d1, d2 + d3, d0)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %extracted_slice : tensor<128x?x128xf32>, tensor<128x?x3x128xf32>) outs(%arg2 : tensor<128x?x128xbf16>) { + ^bb0(%in: f32, %in_0: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + linalg.yield %1 : bf16 + } -> tensor<128x?x128xbf16> + return %0 : tensor<128x?x128xbf16> +} + +// CHECK-LABEL: func.func @nopush_extract_through_generic_nodimexpr2 +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: return %[[GENERIC]] + +// ----- + +func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<128x128xf32>, %arg1: tensor, %arg2: index) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[%arg2, %arg2] [%arg2, %arg2] [1, 1] : tensor<128x128xf32> to tensor + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + %2 = arith.addf %1, %out : bf16 + linalg.yield %2 : bf16 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @push_redcutionextract_through_generic_withoutsused_2 +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]] +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]] +// CHECK: %[[POISON_BF16:.+]] = ub.poison : bf16 +// CHECK: %[[POISON_F32:.+]] = ub.poison : f32 +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor +// CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]] +// CHECK: tensor.yield %[[POISON_F32]] : f32 +// CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]] +// CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]] +// CHECK: tensor.yield %[[POISON_BF16]] : bf16 +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[PADDED]] +// CHECK-SAME: outs(%[[PADDED1]] +// CHECK: %[[EXTRACT1:.+]] = tensor.extract_slice %[[GENERIC]][%[[ARG2]]] [%[[ARG2]]] [1] : tensor to tensor +// CHECK: return %[[EXTRACT1]] + + +// ----- + +func.func @nopush_rankreducingextract(%arg0: tensor<128x128x128xf32>, %arg1: tensor, %arg2: index) -> tensor { + %extracted_slice = tensor.extract_slice %arg0[0, %arg2, %arg2] [1, %arg2, %arg2] [1, 1, 1] : tensor<128x128x128xf32> to tensor + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: bf16): + %1 = arith.truncf %in : f32 to bf16 + %2 = arith.addf %1, %out : bf16 + linalg.yield %2 : bf16 + } -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func.func @nopush_rankreducingextract +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK: return %[[GENERIC]] diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp index d0700f9a4f1a4..2cf25d8fc8c19 100644 --- a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp @@ -34,6 +34,8 @@ struct TestDataLayoutPropagationPass RewritePatternSet patterns(context); linalg::populateDataLayoutPropagationPatterns( patterns, [](OpOperand *opOperand) { return true; }); + linalg::populateExtractSliceSinkingPatterns( + patterns, [](OpOperand *opOperand) { return true; }); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } From 1493d56583ee5f5149a4157561486966f74faeaa Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Mon, 25 Aug 2025 15:25:01 -0500 Subject: [PATCH 2/3] address reviwer comments Signed-off-by: Nirvedh Meshram --- .../Transforms/DataLayoutPropagation.cpp | 127 +++++++++--------- 1 file changed, 65 insertions(+), 62 deletions(-) diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index d50ab8cf03714..40085a2368009 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -1247,61 +1247,55 @@ struct SliceDimInfo { /// Return the first input extract slice operand, if present, for the current /// generic op. -static FailureOr> -getSliceOperandAndIndex(GenericOp genericOp) { +static FailureOr getSliceOperand(GenericOp genericOp) { OpOperand *sliceOperand = nullptr; - unsigned operandIndex; - for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { + for (auto operand : genericOp.getDpsInputOperands()) { auto extractOp = operand->get().getDefiningOp(); if (!extractOp) continue; sliceOperand = operand; - operandIndex = idx; break; } if (!sliceOperand) { return failure(); } - return std::make_tuple(sliceOperand, operandIndex); + return sliceOperand; } -// Return a map of dims that have non full slices on them so that other operands +// Return a map of dims that have partial slices on them so that other operands // can use this information. Also return a bool mentioning if a reduction dim // has a non full slice as that can be used to fold the original extract slice. -static FailureOr, bool>> -getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand, - tensor::ExtractSliceOp producerSliceOp) { - llvm::DenseMap nonZeroSliceDimMap; - bool hasNonZeroReductionDimSlice = false; - SmallVector iterators = - genericOp.getIteratorTypesArray(); +static FailureOr> +getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) { + tensor::ExtractSliceOp producerSliceOp = + sliceOperand->get().getDefiningOp(); + assert(producerSliceOp && "expect a valid ExtractSliceOp"); + llvm::DenseMap partialSliceDimMap; SmallVector offsets = producerSliceOp.getMixedOffsets(); SmallVector sizes = producerSliceOp.getMixedSizes(); - SmallVector shape = llvm::map_to_vector( - producerSliceOp.getSourceType().getShape(), - [&](int64_t sz) -> OpFoldResult { - return getAsIndexOpFoldResult(genericOp.getContext(), sz); - }); + SmallVector shape = getAsIndexOpFoldResult( + genericOp.getContext(), producerSliceOp.getSourceType().getShape()); for (auto [idx, expr] : llvm::enumerate( genericOp.getMatchingIndexingMap(sliceOperand).getResults())) { + // If we have a full slice in a dimension then we dont need to add it to + // the partial slice map. if (isConstantIntValue(offsets[idx], 0) && isEqualConstantIntOrValue(sizes[idx], shape[idx])) { continue; } + // We only support partial slices of AffineDimExprs so bail-out if thats not + // the case. if (!isa(expr)) { return failure(); } SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]}; int64_t dimPos = cast(expr).getPosition(); - nonZeroSliceDimMap[dimPos] = sliceDimInfo; - if (iterators[dimPos] == utils::IteratorType::reduction) { - hasNonZeroReductionDimSlice = true; - } + partialSliceDimMap[dimPos] = sliceDimInfo; } - // Next check if the dims with non zero slice info are used as non - // AffineDimExpr and if they are then bail-out. + // Next check if the dims with partial slice info are used in non + // AffineDimExpr in other operands and if they are then bail-out. for (OpOperand &operand : genericOp->getOpOperands()) { if (operand == *sliceOperand) { continue; @@ -1313,7 +1307,7 @@ getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand, } WalkResult status = expr.walk([&](AffineExpr expr) { if (auto dimExpr = dyn_cast(expr)) { - if (nonZeroSliceDimMap.count(dimExpr.getPosition()) != 0) { + if (partialSliceDimMap.contains(dimExpr.getPosition())) { return WalkResult::interrupt(); } } @@ -1327,7 +1321,7 @@ getNonFullSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand, return failure(); } } - return std::make_tuple(nonZeroSliceDimMap, hasNonZeroReductionDimSlice); + return partialSliceDimMap; } static FailureOr> @@ -1335,47 +1329,57 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, ControlPropagationFn controlFn) { if (genericOp.getNumResults() != 1) - return failure(); + return rewriter.notifyMatchFailure( + genericOp, "propagation through multi-result generic is unsupported."); if (hasGatherSemantics(genericOp)) + return rewriter.notifyMatchFailure( + genericOp, + "propagation through generic with gather semantics is unsupported."); + // Collect the sliced operand, if present. + auto maybeSliceOperand = getSliceOperand(genericOp); + if (failed(maybeSliceOperand)) return failure(); - // Collect the unPacked operand, if present. - auto maybeSliceOperandAndIndex = getSliceOperandAndIndex(genericOp); - if (failed(maybeSliceOperandAndIndex)) - return failure(); - OpOperand *sliceOperand = std::get<0>(*maybeSliceOperandAndIndex); - unsigned OperandIndex = std::get<1>(*maybeSliceOperandAndIndex); + OpOperand *sliceOperand = *maybeSliceOperand; + unsigned OperandIndex = sliceOperand->getOperandNumber(); if (!controlFn(sliceOperand)) return failure(); tensor::ExtractSliceOp producerSliceOp = sliceOperand->get().getDefiningOp(); - assert(producerSliceOp && "expect a valid UnPackOp"); + assert(producerSliceOp && "expect a valid ExtractSliceOp"); if (producerSliceOp.getSource().getType().getRank() != producerSliceOp.getResult().getType().getRank()) { - return failure(); + return rewriter.notifyMatchFailure( + genericOp, + "propagation of rank-reducing extract slice is unsupported."); } SmallVector strides = producerSliceOp.getMixedStrides(); if (!areAllConstantIntValue(strides, 1)) - return failure(); - - SmallVector offsets = producerSliceOp.getMixedOffsets(); - SmallVector sizes = producerSliceOp.getMixedSizes(); + return rewriter.notifyMatchFailure( + genericOp, "propagation of strided extract slice is unsupported."); // check if we can support the propagation of this extractSlice // through the generic op and if so return the dimensions that - auto maybeNonZeroSliceDimMap = - getNonFullSliceDimInfo(genericOp, sliceOperand, producerSliceOp); + auto maybePartialSliceDimMap = + getPartialSliceDimInfo(genericOp, sliceOperand); - if (failed(maybeNonZeroSliceDimMap)) { + if (failed(maybePartialSliceDimMap)) { return failure(); } - auto nonZeroSliceDimMap = std::get<0>(*maybeNonZeroSliceDimMap); - bool hasNonZeroReductionDimSlice = std::get<1>(*maybeNonZeroSliceDimMap); + auto partialSliceDimMap = *maybePartialSliceDimMap; + + SmallVector iterators = + genericOp.getIteratorTypesArray(); + bool hasPartialReductionDimSlice = + llvm::any_of(partialSliceDimMap, [&](const auto &slice) { + int64_t sliceDim = slice.first; + return iterators[sliceDim] == utils::IteratorType::reduction; + }); // Store the padding information as (dimPos, lowPad, highPad, PaddedShape). Location loc = genericOp->getLoc(); @@ -1390,7 +1394,7 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, MLIRContext *ctx = genericOp.getContext(); SmallVector paddedInputs; for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { - if (idx == OperandIndex && !hasNonZeroReductionDimSlice) { + if (idx == OperandIndex && !hasPartialReductionDimSlice) { paddedInputs.push_back(producerSliceOp.getSource()); continue; } @@ -1404,13 +1408,14 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, continue; } AffineDimExpr dimExpr = cast(expr); - if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) { - SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()]; - operandLowPads[idx] = sliceDimInfo.offset; - operandHighPads[idx] = - sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), - sliceDimInfo.sliceSize); + if (!partialSliceDimMap.contains(dimExpr.getPosition())) { + continue; } + SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()]; + operandLowPads[idx] = sliceDimInfo.offset; + operandHighPads[idx] = + sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); } auto paddingValue = ub::PoisonOp::create( rewriter, loc, getElementTypeOrSelf(operand->get().getType())); @@ -1439,15 +1444,15 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, continue; } AffineDimExpr dimExpr = cast(expr); - if (nonZeroSliceDimMap.contains(dimExpr.getPosition())) { - SliceDimInfo sliceDimInfo = nonZeroSliceDimMap[dimExpr.getPosition()]; - outputLowPads[idx] = sliceDimInfo.offset; - outputHighPads[idx] = - sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), - sliceDimInfo.sliceSize); - OutputShape[idx] = sliceDimInfo.outputSize; - newSizes[idx] = sliceDimInfo.sliceSize; + if (!partialSliceDimMap.contains(dimExpr.getPosition())) { + continue; } + SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()]; + outputLowPads[idx] = sliceDimInfo.offset; + outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + OutputShape[idx] = sliceDimInfo.outputSize; + newSizes[idx] = sliceDimInfo.sliceSize; } Value newPadOutput; auto outputElType = @@ -1455,9 +1460,7 @@ pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, if (isGenericOutsNotUsed(genericOp)) { newPadOutput = tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType); - } else { - auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType); newPadOutput = tensor::PadOp::create( rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads, From f08b03cc96077b7c6a7e3a3d20dab4d1bf158f91 Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram Date: Mon, 25 Aug 2025 16:06:58 -0500 Subject: [PATCH 3/3] add shape types for pads Signed-off-by: Nirvedh Meshram --- mlir/test/Dialect/Linalg/data-layout-propagation.mlir | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir index 723eecb52351b..0e42027644797 100644 --- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir +++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir @@ -1470,6 +1470,7 @@ module { // CHECK: %[[POISON:.+]] = ub.poison : f32 // CHECK: %[[PADDED:.+]] = tensor.pad %arg1 // CHECK: tensor.yield %[[POISON]] : f32 +// CHECK: } : tensor to tensor // CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<128x5x128xbf16> // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[ARG0]], %[[PADDED]] @@ -1531,9 +1532,11 @@ func.func @push_redcutionextract_through_generic_withoutsused_2(%arg0: tensor<12 // CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[ARG0]][%[[ARG2]], %[[ARG2]]] [%[[ARG2]], %[[ARG2]]] [1, 1] : tensor<128x128xf32> to tensor // CHECK: %[[PADDED:.+]] = tensor.pad %[[EXTRACT]] // CHECK: tensor.yield %[[POISON_F32]] : f32 +// CHECK: } : tensor to tensor // CHECK: %[[APPLY2:.+]] = affine.apply #map()[%[[ARG2]]] // CHECK: %[[PADDED1:.+]] = tensor.pad %[[ARG1]] low[%[[ARG2]]] high[%[[APPLY2]]] // CHECK: tensor.yield %[[POISON_BF16]] : bf16 +// CHECK: } : tensor to tensor // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[PADDED]] // CHECK-SAME: outs(%[[PADDED1]]