@@ -896,7 +896,29 @@ static Value getExpandedInitTensor(OpBuilder &builder,
896
896
}
897
897
898
898
namespace {
899
- struct FoldWithTensorReshapeOp : public OpRewritePattern <TensorReshapeOp> {
899
+ // / Since `init_tensor` operation creates a tensor needed only for its shape, a
900
+ // / subtensor of this is also needed only for its shape. The result can be
901
+ // / replaced by a new init_tensor operation of the same size as the subtensor
902
+ // / op.
903
+ struct FoldInitTensorWithSubTensorOp : public OpRewritePattern <SubTensorOp> {
904
+ using OpRewritePattern<SubTensorOp>::OpRewritePattern;
905
+
906
+ LogicalResult matchAndRewrite (SubTensorOp subtensorOp,
907
+ PatternRewriter &rewriter) const override {
908
+ if (!subtensorOp.source ().getDefiningOp <linalg::InitTensorOp>())
909
+ return failure ();
910
+ rewriter.replaceOpWithNewOp <linalg::InitTensorOp>(
911
+ subtensorOp, subtensorOp.sizes (),
912
+ llvm::to_vector<4 >(llvm::map_range (
913
+ subtensorOp.static_sizes (),
914
+ [](Attribute attr) { return attr.cast <IntegerAttr>().getInt (); })),
915
+ subtensorOp.getSourceType ().getElementType ());
916
+ return success ();
917
+ }
918
+ };
919
+
920
+ struct FoldInitTensorWithTensorReshapeOp
921
+ : public OpRewritePattern<TensorReshapeOp> {
900
922
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
901
923
902
924
LogicalResult matchAndRewrite (TensorReshapeOp reshapeOp,
@@ -921,8 +943,9 @@ struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
921
943
922
944
void InitTensorOp::getCanonicalizationPatterns (
923
945
OwningRewritePatternList &results, MLIRContext *context) {
924
- results.insert <FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
925
- ReplaceStaticShapeDims>(context);
946
+ results
947
+ .insert <FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
948
+ ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
926
949
}
927
950
928
951
// ===----------------------------------------------------------------------===//
0 commit comments