Skip to content

Commit e8e97aa

Browse files
[Linalg] Fix linalg.pack canonicalization priority issue (#160340)
The current canonicalization prioritizes unpack->pack folder over dropping padding if not needed but that folder fails if there is padding and hence blocks all canonicalizations. We now put the failures in the if statement so that we can proceed if the unpack->pack folder conditions are not met. Signed-off-by: Nirvedh Meshram <[email protected]>
1 parent 81cf8b8 commit e8e97aa

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5583,14 +5583,13 @@ static bool inferStaticShape(PackOp packOp, SmallVectorImpl<int64_t> &srcShape,
55835583
LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
55845584
// Fold an pack(unpack(x)) to x.
55855585
if (auto unPackOp = packOp.getSource().getDefiningOp<UnPackOp>()) {
5586-
if (unPackOp.getSourceType() != packOp.getDestType())
5587-
return failure();
5588-
if (packOp.getPaddingValue() ||
5589-
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
5590-
!haveSameTiles(packOp, unPackOp))
5591-
return failure();
5592-
rewriter.replaceOp(packOp, unPackOp.getSource());
5593-
return success();
5586+
if (unPackOp.getSourceType() == packOp.getDestType() &&
5587+
!packOp.getPaddingValue() &&
5588+
hasSameInnerOuterAttribute(packOp, unPackOp) &&
5589+
haveSameTiles(packOp, unPackOp)) {
5590+
rewriter.replaceOp(packOp, unPackOp.getSource());
5591+
return success();
5592+
}
55945593
}
55955594

55965595
// Fold optional PaddingValue operand away if padding is not needed.

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1756,10 +1756,11 @@ func.func @pack_unpack(%t: tensor<16x16x?x?xf32>, %tile1: index, %tile2: index)
17561756
// CHECK-SAME: %[[T:.+]]: tensor<16x16x8x8xf32>
17571757
// CHECK: return %[[T]] : tensor<16x16x8x8xf32>
17581758
func.func @pack_unpack(%t: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
1759+
%cst = arith.constant 0.000000e+00 : f32
17591760
%tensor_empty = tensor.empty() : tensor<128x128xf32>
17601761
%unpacked = linalg.unpack %t inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty : tensor<16x16x8x8xf32> -> tensor<128x128xf32>
17611762
%tensor_empty1 = tensor.empty() : tensor<16x16x8x8xf32>
1762-
%packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
1763+
%packed = linalg.pack %unpacked padding_value(%cst : f32) inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %tensor_empty1 : tensor<128x128xf32> -> tensor<16x16x8x8xf32>
17631764
return %packed : tensor<16x16x8x8xf32>
17641765
}
17651766

0 commit comments

Comments
 (0)