[MLIR] Add shape propagation through tensor.pad #136681

Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from c9ef18a to 321aea7.
Max191
left a comment
I left some comments to help with supporting dynamic cases. Lmk if you have more questions!
Apologies for the delay — I’ve been recovering from a medical issue. I’ll resume this soon.
Force-pushed from 4ab75e1 to 0d8c636.
Hi @Max191, I updated this PR to support dynamic cases too, following your review. Sorry it took me a while to get back from my hiatus. I think this is better now; what do you think?
@llvm/pr-subscribers-mlir-linalg

Author: Hyunsung Lee (ita9naiwa)

Changes: I’ve implemented fusion for tensor.expand_shape → tensor.pad, but two gaps remain:
Before (expand then pad):

func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
%c0 = arith.constant 0.0 : f32
%producer = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
%pad = tensor.pad %producer low[0, 1, 1] high[0, 1, 1] {
^bb0(%i: index, %j: index, %k: index):
tensor.yield %c0 : f32
} : tensor<512x256x256xf32> to tensor<512x258x258xf32>
%reshape = tensor.expand_shape %pad [[0, 1], [2], [3]]
output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
return %reshape : tensor<32x16x258x258xf32>
}

After (reshape then pad):

func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
%c0 = arith.constant 0.0 : f32
%producer = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
%reshape = tensor.expand_shape %producer [[0, 1], [2], [3]]
output_shape [32, 16, 258, 258] : tensor<512x256x256xf32> into tensor<32x16x256x256xf32>
%pad = tensor.pad %reshape low[0, 0, 1, 1] high[0, 0, 1, 1] {
^bb0(%i0: index, %i1: index, %i2: index, %i3: index):
tensor.yield %c0 : f32
} : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
return %pad : tensor<32x16x258x258xf32>
}

Next steps: CC @Max191 for awareness—would love any pointers on the collapse‑side implementation or dynamic‑shape handling!

Full diff: https://github.com/llvm/llvm-project/pull/136681.diff

2 Files Affected:
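The pad-bound remapping shown in the before/after example can be modeled outside MLIR in a few lines. This is a minimal standalone sketch of the idea, not the patch itself: `remapPad` and `Group` are invented names, and it assumes only zero-padded dims may be expanded into multiple dims (the condition the pattern checks).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Each reassociation group of the expand_shape inherits the padding of
// the collapsed dim it came from. A group that expands into several
// dims must be unpadded, because padding cannot be split across the
// new dims.
using Group = std::vector<int64_t>; // output dims produced by one input dim

std::optional<std::vector<int64_t>>
remapPad(const std::vector<Group> &reassociations,
         const std::vector<int64_t> &pad) {
  std::vector<int64_t> newPad;
  for (size_t i = 0; i < reassociations.size(); ++i) {
    if (reassociations[i].size() > 1 && pad[i] != 0)
      return std::nullopt; // non-zero pad on an expanded dim: give up
    // Repeat the group's pad once per output dim (0 for expanded groups).
    newPad.insert(newPad.end(), reassociations[i].size(), pad[i]);
  }
  return newPad;
}
```

With the reassociation `[[0, 1], [2], [3]]` from the example, `remapPad({{0, 1}, {2}, {3}}, {0, 1, 1})` yields `{0, 0, 1, 1}`, matching `low[0, 1, 1]` becoming `low[0, 0, 1, 1]` above.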
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 9c0f6e5d6469e..39eed6dd4cba4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1100,6 +1100,146 @@ class FoldPadWithProducerReshapeOpByExpansion
ControlFusionFn controlFoldingReshapes;
};
+/// Pattern to fold a tensor.expand_shape op with its producer tensor.pad op
+/// by bubbling the expand_shape before the pad.
+struct FoldReshapeWithProducerPadOpByExpansion
+ : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+ FoldReshapeWithProducerPadOpByExpansion(MLIRContext *context,
+ ControlFusionFn foldReshapes,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+ controlFoldingReshapes(std::move(foldReshapes)) {}
+
+ LogicalResult matchAndRewrite(tensor::ExpandShapeOp expandOp,
+ PatternRewriter &rewriter) const override {
+ tensor::PadOp padOp = expandOp.getSrc().getDefiningOp<tensor::PadOp>();
+ if (!padOp)
+ return failure();
+
+ if (!padOp->hasOneUse())
+ return failure();
+
+ if (!controlFoldingReshapes(&expandOp.getSrcMutable())) {
+ return rewriter.notifyMatchFailure(expandOp,
+ "fusion blocked by control function");
+ }
+
+ SmallVector<ReassociationIndices> reassociations =
+ expandOp.getReassociationIndices();
+ SmallVector<OpFoldResult> low = padOp.getMixedLowPad();
+ SmallVector<OpFoldResult> high = padOp.getMixedHighPad();
+
+ auto isZeroPadding = [](OpFoldResult padValue) -> bool {
+ if (auto attr = dyn_cast<Attribute>(padValue)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr))
+ return intAttr.getInt() == 0;
+ }
+
+ if (auto val = dyn_cast<Value>(padValue)) {
+ if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
+ if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
+ return attr.getInt() == 0;
+ }
+ }
+
+ // When the padding is dynamic rather than a constant, we cannot prove
+ // it is zero, so conservatively return false.
+ return false;
+ };
+
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = low[idx];
+ OpFoldResult h = high[idx];
+ if (reInd.size() != 1 && (!isZeroPadding(l) || !isZeroPadding(h)))
+ return failure();
+ }
+
+ SmallVector<OpFoldResult> newLow, newHigh;
+ for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
+ for (size_t i = 0; i < reInd.size(); ++i) {
+ newLow.push_back(padOp.getMixedLowPad()[idx]);
+ newHigh.push_back(padOp.getMixedHighPad()[idx]);
+ }
+ }
+
+ Location loc = expandOp.getLoc();
+ auto finalType = cast<RankedTensorType>(expandOp.getType());
+ ArrayRef<int64_t> finalShape = finalType.getShape();
+
+ SmallVector<OpFoldResult> expandedShape;
+ for (int64_t dimSize : finalShape) {
+ if (dimSize == ShapedType::kDynamic) {
+ expandedShape.push_back(OpFoldResult{});
+ } else {
+ expandedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
+ }
+ }
+
+ for (auto [inDimIdx, outGroup] : llvm::enumerate(reassociations)) {
+ OpFoldResult l = low[inDimIdx];
+ OpFoldResult h = high[inDimIdx];
+
+ if (!isZeroPadding(l) || !isZeroPadding(h)) {
+ auto srcType = cast<RankedTensorType>(padOp.getSource().getType());
+ int64_t originalSize = srcType.getDimSize(inDimIdx);
+
+ OpFoldResult originalSizeOFR;
+ if (originalSize == ShapedType::kDynamic) {
+ Value orgSizeVal =
+ rewriter.create<tensor::DimOp>(loc, padOp.getSource(), inDimIdx);
+ originalSizeOFR = orgSizeVal;
+ } else {
+ originalSizeOFR = rewriter.getI64IntegerAttr(originalSize);
+ }
+
+ for (auto outDimIdx : outGroup) {
+ expandedShape[outDimIdx] = originalSizeOFR;
+ }
+ }
+ }
+
+ for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
+ if (dimSize == ShapedType::kDynamic &&
+ !isa<Value>(expandedShape[outDimIdx]) &&
+ !isa<Attribute>(expandedShape[outDimIdx])) {
+ Value actualSize =
+ rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
+ expandedShape[outDimIdx] = actualSize;
+ }
+ }
+
+ SmallVector<int64_t> staticExpandedShape;
+ for (OpFoldResult dim : expandedShape) {
+ if (auto attr = dyn_cast<Attribute>(dim)) {
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ staticExpandedShape.push_back(intAttr.getInt());
+ } else {
+ staticExpandedShape.push_back(ShapedType::kDynamic);
+ }
+ } else {
+ staticExpandedShape.push_back(ShapedType::kDynamic);
+ }
+ }
+
+ auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc,
+ RankedTensorType::get(staticExpandedShape,
+ padOp.getSource().getType().getElementType()),
+ padOp.getSource(), reassociations);
+
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
+ padOp.getConstantPaddingValue(), padOp.getNofold());
+
+ rewriter.replaceOp(expandOp, newPadOp.getResult());
+ return success();
+ }
+
+private:
+ ControlFusionFn controlFoldingReshapes;
+};
+
/// Pattern to fold a tensor.expand_shape op with its producer generic op
/// by expanding the dimensionality of the loop in the producer op.
struct FoldReshapeWithGenericOpByExpansion
@@ -2235,6 +2375,8 @@ void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns(
controlFoldingReshapes);
patterns.add<FoldPadWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
+ patterns.add<FoldReshapeWithProducerPadOpByExpansion>(patterns.getContext(),
+ controlFoldingReshapes);
patterns.add<FoldWithProducerReshapeOpByExpansion>(patterns.getContext(),
controlFoldingReshapes);
}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index 67b4f2b32bad5..3ea0babfa3b9d 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -247,7 +247,7 @@ func.func @indexed_consumer_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
#map0 = affine_map<(d0, d1) -> (d0, d1)>
func.func @indexed_producer_reshape_consumer_fusion(%arg0 : tensor<?x?xi32>,
- %arg1 : tensor<?x?xi32>,
+ %arg1 : tensor<?x?xi32>,
%sz0: index, %sz1: index) ->
tensor<?x?x4x5xi32>
{
@@ -515,7 +515,7 @@ func.func @fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
// -----
func.func @reshape_as_consumer_permutation_with_multiple_results
- (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
+ (%a : tensor<?x?x?xf32>, %b : tensor<?x?xf32>, %sz0: index,
%sz1: index, %sz2: index, %sz3: index, %sz4: index)
-> (tensor<?x2x?x3x4x?xf32>, tensor<?x?x2x3x4x?xf32>) {
%c:2 = linalg.generic {
@@ -893,3 +893,50 @@ func.func @move_operand_deps(%arg0 : tensor<?x128xf16>,
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[EXPANDED]] :
// CHECK: return %[[GENERIC]]
+
+// -----
+
+func.func @fold_tensor_pad_with_expand(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+ %padded = tensor.pad %0 low[0, 1, 1] high[0, 1, 1] {
+ ^bb0(%i: index, %j: index, %k: index):
+ tensor.yield %cst : f32
+ } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+ %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+ return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_expand(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]] : tensor<32x16x256x256xf32>)
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: } : tensor<32x16x256x256xf32> to tensor<32x16x258x258xf32>
+// CHECK: return %[[PADDED]] : tensor<32x16x258x258xf32>
+
+// -----
+
+func.func @fold_tensor_pad_with_expand_dynamic_pad_zero(%arg0: tensor<512x256x256xf32>) -> tensor<32x16x258x258xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<512x256x256xf32>) -> tensor<512x256x256xf32>
+ %padded = tensor.pad %0 low[%c0, %c1, %c1] high[%c0, %c1, %c1] {
+ ^bb0(%i: index, %j: index, %k: index):
+ tensor.yield %cst : f32
+ } : tensor<512x256x256xf32> to tensor<512x258x258xf32>
+ %expanded = tensor.expand_shape %padded [[0, 1], [2], [3]] output_shape [32, 16, 258, 258] : tensor<512x258x258xf32> into tensor<32x16x258x258xf32>
+ return %expanded : tensor<32x16x258x258xf32>
+}
+// CHECK: func @fold_tensor_pad_with_expand_dynamic_pad_zero(
+// CHECK-SAME: %[[ARG0:[^:]+]]: tensor<512x256x256xf32>
+// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
+// CHECK: %[[FILLED:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EXPANDED]]
+// CHECK: %[[PADDED:.*]] = tensor.pad %[[FILLED]] low[0, 0, 1, 1] high[0, 0, 1, 1]
+// CHECK: ^bb0(
+// CHECK: tensor.yield %[[CST]] : f32
+// CHECK: return %[[PADDED]]
@llvm/pr-subscribers-mlir
collapse_shape support added. It works for both the static and dynamic cases; two limitations are:
Max191
left a comment
Reviewing the FoldReshapeWithProducerPadOpByExpansion for the first round of comments. I think a lot of the cleanups can apply to both patterns, though. Nice work so far!
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/LogicalResult.h"
nit: I don't think these includes are needed
Done
Seems like these are still here?
Sorry, done.
bool isZero(OpFoldResult value) {
  if (auto attr = dyn_cast<Attribute>(value)) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr))
      return intAttr.getInt() == 0;
  }
  if (auto val = dyn_cast<Value>(value)) {
    if (auto constOp = val.getDefiningOp<arith::ConstantOp>()) {
      if (auto attr = dyn_cast<IntegerAttr>(constOp.getValue()))
        return attr.getInt() == 0;
    }
  }
  return false;
}
You can use isConstantIntValue(value, 0) for this:
llvm-project/mlir/lib/Dialect/Utils/StaticValueUtils.cpp, lines 144 to 146 in 88a498c:

bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
  return getConstantIntValue(ofr) == value;
}
Done
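For readers outside MLIR, the semantics of `isConstantIntValue` can be mimicked with a plain variant. This is a sketch under stated assumptions: `FoldResult` and `DynamicValue` are invented stand-ins for `OpFoldResult`'s attribute/value cases, not real MLIR types.

```cpp
#include <cstdint>
#include <optional>
#include <variant>

// Invented stand-in for OpFoldResult: either a known constant integer
// (the Attribute case) or an opaque SSA value (the Value case).
struct DynamicValue {};
using FoldResult = std::variant<int64_t, DynamicValue>;

std::optional<int64_t> getConstantInt(const FoldResult &ofr) {
  if (const auto *c = std::get_if<int64_t>(&ofr))
    return *c;
  return std::nullopt; // dynamic: value unknown until runtime
}

// Mirrors the isConstantIntValue shape: comparing an empty optional
// against an int64_t is false, which is exactly the conservative
// "dynamic padding might not be zero" behavior the pattern needs.
bool isConstantInt(const FoldResult &ofr, int64_t value) {
  return getConstantInt(ofr) == value;
}
```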
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
  OpFoldResult l = low[idx];
  OpFoldResult h = high[idx];
nit: you can use llvm::zip_equal
-  for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
-    OpFoldResult l = low[idx];
-    OpFoldResult h = high[idx];
+  for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
Done
for (size_t i = 0; i < reInd.size(); ++i) {
  newLow.push_back(low[idx]);
  newHigh.push_back(high[idx]);
}
nit: you can use append
-    for (size_t i = 0; i < reInd.size(); ++i) {
-      newLow.push_back(low[idx]);
-      newHigh.push_back(high[idx]);
-    }
+    newLow.append(reInd.size(), low[idx]);
+    newHigh.append(reInd.size(), high[idx]);
Done
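The `append(count, value)` suggestion has a direct `std::vector` analogue. `repeatAppend` below is an illustrative helper of mine, not part of the patch:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// SmallVector::append(n, v) pushes n copies of v; std::vector spells
// the same idiom as a count-insert at the end.
std::vector<int64_t> repeatAppend(std::vector<int64_t> v, size_t n,
                                  int64_t value) {
  v.insert(v.end(), n, value); // replaces the hand-written push_back loop
  return v;
}
```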
OpFoldResult l = low[idx];
OpFoldResult h = high[idx];
if (reInd.size() > 1 && (!isZero(l) || !isZero(h)))
  return failure();
nit: Use rewriter.notifyMatchFailure() like above?
Done
for (auto outDimIdx : reInd) {
  expandedShape[outDimIdx] = originalSizeOFR;
}
I know that the reInd should have a size of 1 from the previous matching, but I think the logic is more clear if you add an assert here that reInd.size() == 1, and then just do expandedShape[reInd[0]] = originalSizeOFR;
Done
for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
  if (dimSize == ShapedType::kDynamic &&
      !isa<Value>(expandedShape[outDimIdx]) &&
      !isa<Attribute>(expandedShape[outDimIdx])) {
    Value actualSize =
        rewriter.create<tensor::DimOp>(loc, expandOp.getSrc(), outDimIdx);
    expandedShape[outDimIdx] = actualSize;
  }
}
I think this was necessary because some of the expandedShape were null right? I'm pretty sure this shouldn't be necessary if you use getMixedOutputShape as per my above comment.
Did I understand your comment correctly? This can be reduced to:
for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
if (dimSize == ShapedType::kDynamic &&
!isa<Value>(expandedShape[outDimIdx]) &&
!isa<Attribute>(expandedShape[outDimIdx])) {
expandedShape[outDimIdx] =
tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
}
}
SmallVector<int64_t> staticExpandedShape;
for (OpFoldResult dim : expandedShape) {
  if (auto attr = dyn_cast<Attribute>(dim)) {
    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
      staticExpandedShape.push_back(intAttr.getInt());
    } else {
      staticExpandedShape.push_back(ShapedType::kDynamic);
    }
  } else {
    staticExpandedShape.push_back(ShapedType::kDynamic);
  }
}
You can use decomposeMixedValues here:
-  SmallVector<int64_t> staticExpandedShape;
-  for (OpFoldResult dim : expandedShape) {
-    if (auto attr = dyn_cast<Attribute>(dim)) {
-      if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-        staticExpandedShape.push_back(intAttr.getInt());
-      } else {
-        staticExpandedShape.push_back(ShapedType::kDynamic);
-      }
-    } else {
-      staticExpandedShape.push_back(ShapedType::kDynamic);
-    }
-  }
+  SmallVector<int64_t> staticExpandedShape;
+  std::tie(staticExpandedShape, std::ignore) = decomposeMixedValues(expandedShape);
Done
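For intuition, `decomposeMixedValues` splits a list of mixed static/dynamic sizes into a static shape plus the dynamic values. The sketch below models only the static half, assuming `std::optional` in place of `OpFoldResult`; `kDynamicSentinel` and `toStaticShape` are invented names (MLIR's `ShapedType::kDynamic` is the minimum `int64_t`).

```cpp
#include <cstdint>
#include <limits>
#include <optional>
#include <vector>

// Assumed stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamicSentinel = std::numeric_limits<int64_t>::min();

// Model of the static half of decomposeMixedValues: every known size
// is kept, every dynamic one collapses to the sentinel.
std::vector<int64_t>
toStaticShape(const std::vector<std::optional<int64_t>> &mixedSizes) {
  std::vector<int64_t> staticShape;
  staticShape.reserve(mixedSizes.size());
  for (const auto &dim : mixedSizes)
    staticShape.push_back(dim ? *dim : kDynamicSentinel);
  return staticShape;
}
```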
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
    loc,
    RankedTensorType::get(staticExpandedShape,
                          padOp.getSource().getType().getElementType()),
    padOp.getSource(), reassociations);
I think you also want to pass the mixed output shape here to use the correct builder:
-  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
-      loc,
-      RankedTensorType::get(staticExpandedShape,
-                            padOp.getSource().getType().getElementType()),
-      padOp.getSource(), reassociations);
+  auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+      loc,
+      RankedTensorType::get(staticExpandedShape,
+                            padOp.getSource().getType().getElementType()),
+      padOp.getSource(), reassociations, expandedShape);
Done
auto newPadOp = rewriter.create<tensor::PadOp>(
    loc, expandOp.getType(), newExpandOp.getResult(), newLow, newHigh,
    padOp.getConstantPaddingValue(), padOp.getNofold());

rewriter.replaceOp(expandOp, newPadOp.getResult());
nit: use rewriter.replaceOpWithNewOp<tensor::PadOp>?
Done
0496fed to 737d4a4
Thanks for contributing! I have two high-level asks. ASK 1: Please use the summary to describe and to justify your change (i.e. provide rationale). The summary should be self-contained and shouldn't refer to any external projects (within reason, but in this case I don't find the reference helpful). You could, for example, explain why "bubbling up" tensor.pad is beneficial.
ASK 2: "Before" and "after" what? Could you clarify? Thanks!
Thanks @banach-space for the thoughtful comments! Regarding ASK 2: "Before" and "After" refer to the IR before and after applying the pass introduced in this PR.
for (auto [reInd, l, h] : llvm::zip_equal(reassociations, low, high)) {
  if (reInd.size() > 1 &&
      (!isConstantIntValue(l, 0) || !isConstantIntValue(h, 0)))
    return rewriter.notifyMatchFailure(
        expandOp, "fusion blocked by non-zero padding");
}

SmallVector<OpFoldResult> newLow, newHigh;
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
  newLow.append(reInd.size(), low[idx]);
  newHigh.append(reInd.size(), high[idx]);
}
nit: combine these 2 loops.
for (auto [outDimIdx, dimSize] : llvm::enumerate(finalShape)) {
  if (dimSize == ShapedType::kDynamic &&
      !isa<Value>(expandedShape[outDimIdx]) &&
      !isa<Attribute>(expandedShape[outDimIdx])) {
    expandedShape[outDimIdx] =
        tensor::getMixedSize(rewriter, loc, expandOp.getSrc(), outDimIdx);
  }
}
You can delete this loop because expandOp.getMixedOutputShape() will already populate the expandedShape with the right dynamic values.
SmallVector<OpFoldResult> collapsedShape;
for (int64_t dimSize : finalShape) {
  if (dimSize == ShapedType::kDynamic) {
    collapsedShape.push_back(OpFoldResult{});
  } else {
    collapsedShape.push_back(rewriter.getI64IntegerAttr(dimSize));
  }
}
You shouldn't need the mixed sizes for the collapsed shape. It is only used for getting the new type, so you can just collect static sizes instead (SmallVector<int64_t>).
Thanks @Max191 for the thoughtful review!
Max191
left a comment
Looks pretty good now, just a few more small comments!
SmallVector<OpFoldResult> newLow, newHigh;
for (auto [idx, reInd] : llvm::enumerate(reassociations)) {
  newLow.push_back(low[reInd[0]]);
  newHigh.push_back(high[reInd[0]]);
}
nit: combine this loop with the loop above.
@Max191 Thanks, all addressed!
Max191
left a comment
LGTM, thanks for addressing all my comments, nice work!
Hi @banach-space, could you please review this PR?
To me, "folding" would mean that tensor.expand_shape disappears (i.e. is folded away), but that's not what is happening here, is it? This is merely "bubbling up".
Please update the description accordingly and add some example IR before and after. As an example: https://github.com/banach-space/llvm-project/blob/7d35eb58959c0ab398a9739f38bfb9754c5ba5e5/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp#L305-L317
banach-space
left a comment
Thanks for updating the summary!
tensor.{expand_shape,collapse_shape} sitting between a producer and tensor.pad blocks pad–producer fusion and other canonicalizations.
OK, then there should be a test demonstrating that this new transformation is unblocking fusion. Otherwise, it feels a bit ad-hoc. Do you have an example that we could turn into a test?
Thanks!
// RUN: mlir-opt %s -tensor-reshape-propagation -linalg-elementwise-fusion -canonicalize | FileCheck %s
func @pad_fuse(%arg0: tensor<2x3xf32>) -> tensor<2x4xf32> {
%generic = linalg.generic
{ indexing_maps = [affine_map<(d0, d1)->(d0, d1)>],
iterator_types = ["parallel", "parallel"] }
ins(%arg0 : tensor<2x3xf32>) outs(%arg0 : tensor<2x3xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
} -> tensor<2x3xf32>
%collapsed = tensor.collapse_shape %generic [[0, 1]]
: tensor<2x3xf32> into tensor<6xf32>
%cst = arith.constant 1.0 : f32
%padded = tensor.pad %collapsed low[1] high[1] {
^bb0(%idx: index):
tensor.yield %cst : f32
} : tensor<6xf32> to tensor<8xf32>
%expanded = tensor.expand_shape %padded [2, 4]
: tensor<8xf32> into tensor<2x4xf32>
return %expanded : tensor<2x4xf32>
}
// CHECK-LABEL: func @pad_fuse
// Function result type must be tensor<2x4xf32> (reshape + pad fused)
// CHECK-SAME: (%{{.*}}: tensor<2x3xf32>) -> tensor<2x4xf32>
// No standalone reshapes or pad should survive
// CHECK-NOT: tensor.collapse_shape
// CHECK-NOT: tensor.expand_shape
// CHECK-NOT: tensor.pad
// A single linalg.generic should produce the final tensor<2x4xf32>
// CHECK: linalg.generic
// CHECK-SAME: outs(%{{.*}} : tensor<2x4xf32>)

This would be an example for this purpose, but I am not sure where to put it. Could you recommend an appropriate place?
1af7a3c to 9cbd032
The flags that you used in your example don't exist:

$ bin/mlir-opt -tensor-reshape-propagation -linalg-elementwise-fusion -canonicalize eample.mlir
mlir-opt: Unknown command line argument '-tensor-reshape-propagation'. Try: 'bin/mlir-opt --help'
mlir-opt: Did you mean '--sharding-propagation'?
mlir-opt: Unknown command line argument '-linalg-elementwise-fusion'. Try: 'bin/mlir-opt --help'
mlir-opt: Did you mean '--linalg-fuse-elementwise-ops'?

Where did you take them from? Also, the MLIR file is "broken":

$ bin/mlir-opt bad.mlir
bad.mlir:1:1: error: custom op 'func' is unknown (tried 'builtin.func' as well)
func @pad_fuse(%arg0: tensor<2x3xf32>) -> tensor<2x4xf32> {
^

There are more issues than just this one.
Please provide a working example first. Sharing broken examples is a bad use of reviewers' time.
I sincerely apologize: while cleaning up the MLIR code snippet, I relied on an AI assistant and inadvertently broke it.
This transformation is common in ML workloads such as CNNs, where input tensors are padded (e.g., to match convolution kernel sizes) and then packed into smaller tiles for efficient Tensor Core or SIMD execution.
I really apologize for making a comment without testing it comprehensively. @banach-space
Motivation / Rationale
Why:
tensor.{expand_shape,collapse_shape} sitting between a producer and tensor.pad blocks pad–producer fusion and other canonicalizations. This leaves extra ops (dim, alloc, reshapes) and complicates bufferization/codegen.
What this patch enables:
• Move or remove those reshapes so tensor.pad directly sees the final shape.
• Result: simpler IR and downstream passes (fusion, folding, hoisting) apply cleanly.
Changes
Before applying this pass: {collapse,expand}_reshape then pad
After applying this pass: pad then {collapse,expand}_reshape
CC @Max191 for awareness. I would love any pointers on the collapse-side implementation or dynamic-shape handling!