Commit 10d26d9

fixup! [mlir][tensor] Introduce FoldTensorCastUnPackOp
Address PR comments
1 parent 1bc2d8e

2 files changed: +47 -53 lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 43 additions & 49 deletions
@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
   return newOperands;
 }
 
+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+//  * a dim from newPackedTy is static, and
+//  * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+                     SmallVector<OpFoldResult> mixedTiles) {
+  SmallVector<OpFoldResult> newMixedTileSizes;
+  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+                               .getShape()
+                               .take_back(mixedTiles.size()),
+                           mixedTiles)) {
+    int64_t shape = std::get<0>(it);
+    if (shape == ShapedType::kDynamic) {
+      newMixedTileSizes.push_back(std::get<1>(it));
+      continue;
+    }
+
+    // If the current result dim is static, update the dynamic mixed-size
+    // (provided the original value is dynamic).
+    OpFoldResult tile = std::get<1>(it);
+    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+      // Already a constant
+      newMixedTileSizes.push_back(tile);
+    } else {
+      assert(getConstantIntValue(tile).value() == shape &&
+             "tile size and dim size don't match!");
+      newMixedTileSizes.push_back(
+          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+    }
+  }
+
+  return newMixedTileSizes;
+}
+
 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
 /// `tensor.cast` has source that is more static than the consuming op.
 ///
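(In IR terms, this helper is what lets the patterns below turn a dynamic `inner_tiles` entry back into a static one once the packed type is known. A sketch mirroring the `fold_cast_pack_dynamic_tile_size` test further down, with `%c8 = arith.constant 8 : index`; the SSA names are illustrative:)

```mlir
// Before: the cast makes the packed dim dynamic, and the tile size is %c8.
%cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%pack = tensor.pack %src padding_value(%pad : i32)
    inner_dims_pos = [0, 1] inner_tiles = [%c8, 1]
    into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32>

// After: the packed dim is static again, so the dynamic tile %c8 is
// replaced by the constant 8.
%pack = tensor.pack %src padding_value(%pad : i32)
    inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dest : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
```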
@@ -4821,28 +4859,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes;
-    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
-                                 .getShape()
-                                 .take_back(op.getMixedTiles().size()),
-                             op.getMixedTiles())) {
-      int64_t shape = std::get<0>(it);
-      if (shape == ShapedType::kDynamic) {
-        newMixedTileSizes.push_back(std::get<1>(it));
-        continue;
-      }
-
-      if (Attribute attr =
-              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
-        // Already a constant
-        newMixedTileSizes.push_back(std::get<1>(it));
-      } else {
-        assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
-               "tile size and dim size don't match!");
-        newMixedTileSizes.push_back(
-            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
-      }
-    }
+    SmallVector<OpFoldResult> newMixedTileSizes =
+        getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
 
     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
@@ -4873,7 +4891,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   /// Example:
   /// ```mlir
   /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
-  /// %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+  /// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
   /// ```
   ///
   /// folds into:
@@ -4894,32 +4912,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     Value sourceTensor = newOperands[0];
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes;
-    for (auto it : llvm::zip(cast<ShapedType>(sourceTensor.getType())
-                                 .getShape()
-                                 .take_back(op.getMixedTiles().size()),
-                             op.getMixedTiles())) {
-      int64_t shape = std::get<0>(it);
-      // If the current source shape is dynamic, just preserve this mixed
-      // size.
-      if (shape == ShapedType::kDynamic) {
-        newMixedTileSizes.push_back(std::get<1>(it));
-        continue;
-      }
-
-      // If the current source is static, update the dynamic mixed-size
-      // (provided the original value is dynamic).
-      if (Attribute attr =
-              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
-        // Already a constant
-        newMixedTileSizes.push_back(std::get<1>(it));
-      } else {
-        assert(getConstantIntValue(std::get<1>(it)).value() == shape &&
-               "tile size and dim size don't match!");
-        newMixedTileSizes.push_back(
-            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
-      }
-    }
+    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+        rewriter, sourceTensor.getType(), op.getMixedTiles());
 
     // Clone op.
     // TODO: Strictly speaking, discardable attributes should be _discarded_ at
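(End-to-end, the `FoldTensorCastUnPackOp` pattern rewrites IR like the following; a sketch mirroring the `fold_cast_unpack_dynamic_tile_size` test below, with `%c8 = arith.constant 8 : index` and illustrative SSA names:)

```mlir
// Before: the cast hides the static inner dim and the tile size is dynamic.
%cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
%unpack = tensor.unpack %cast
    inner_dims_pos = [0, 1] inner_tiles = [%c8, 1]
    into %dest : tensor<1x1x?x1xi32> -> tensor<7x?xi32>

// After: the cast is folded into the unpack and the tile size is static.
%unpack = tensor.unpack %src
    inner_dims_pos = [0, 1] inner_tiles = [8, 1]
    into %dest : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
```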

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 4 additions & 4 deletions
@@ -2795,7 +2795,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
 // CHECK-SAME:      %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
 // CHECK:           %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
 // CHECK-SAME:        inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
-// CHECK-SAME:        some_attr
+// CHECK-SAME:        test_attr
 // CHECK-SAME:        : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
 // CHECK:           return %[[PACK]] : tensor<1x1x8x1xi32>
 func.func @fold_cast_pack_dynamic_tile_size(
@@ -2808,7 +2808,7 @@ func.func @fold_cast_pack_dynamic_tile_size(
   %pack = tensor.pack %src padding_value(%pad : i32)
       inner_dims_pos = [0, 1]
       inner_tiles = [%c8, 1]
-      into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+      into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
   %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
   return %res : tensor<1x1x8x1xi32>
 }
@@ -2818,7 +2818,7 @@ func.func @fold_cast_pack_dynamic_tile_size(
 // CHECK-LABEL:   func.func @fold_cast_unpack_dynamic_tile_size(
 // CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x8x1xi32>,
 // CHECK-SAME:      %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
-// CHECK:           %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK:           %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
 // CHECK:           return %[[RES]] : tensor<7x?xi32>
 func.func @fold_cast_unpack_dynamic_tile_size(
     %src: tensor<1x1x8x1xi32>,
@@ -2829,7 +2829,7 @@ func.func @fold_cast_unpack_dynamic_tile_size(
   %unpack = tensor.unpack %cast
       inner_dims_pos = [0, 1]
       inner_tiles = [%c8, 1]
-      into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+      into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
   return %unpack : tensor<7x?xi32>
 }
