Skip to content

Commit 7c15e0f

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Add canonicalization for init_tensor -> subtensor op.
Differential Revision: https://reviews.llvm.org/D95305
1 parent 48bdd67 commit 7c15e0f

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,29 @@ static Value getExpandedInitTensor(OpBuilder &builder,
896896
}
897897

898898
namespace {
899-
struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
899+
/// Since `init_tensor` operation creates a tensor needed only for its shape, a
900+
/// subtensor of this is also needed only for its shape. The result can be
901+
/// replaced by a new init_tensor operation of the same size as the subtensor
902+
/// op.
903+
struct FoldInitTensorWithSubTensorOp : public OpRewritePattern<SubTensorOp> {
904+
using OpRewritePattern<SubTensorOp>::OpRewritePattern;
905+
906+
LogicalResult matchAndRewrite(SubTensorOp subtensorOp,
907+
PatternRewriter &rewriter) const override {
908+
if (!subtensorOp.source().getDefiningOp<linalg::InitTensorOp>())
909+
return failure();
910+
rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
911+
subtensorOp, subtensorOp.sizes(),
912+
llvm::to_vector<4>(llvm::map_range(
913+
subtensorOp.static_sizes(),
914+
[](Attribute attr) { return attr.cast<IntegerAttr>().getInt(); })),
915+
subtensorOp.getSourceType().getElementType());
916+
return success();
917+
}
918+
};
919+
920+
struct FoldInitTensorWithTensorReshapeOp
921+
: public OpRewritePattern<TensorReshapeOp> {
900922
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
901923

902924
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
@@ -921,8 +943,9 @@ struct FoldWithTensorReshapeOp : public OpRewritePattern<TensorReshapeOp> {
921943

922944
void InitTensorOp::getCanonicalizationPatterns(
923945
OwningRewritePatternList &results, MLIRContext *context) {
924-
results.insert<FoldWithTensorReshapeOp, ReplaceDimOfInitTensorOp,
925-
ReplaceStaticShapeDims>(context);
946+
results
947+
.insert<FoldInitTensorWithSubTensorOp, FoldInitTensorWithTensorReshapeOp,
948+
ReplaceDimOfInitTensorOp, ReplaceStaticShapeDims>(context);
926949
}
927950

928951
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,3 +668,19 @@ func @keep_not_noop(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>)
668668
// CHECK-LABEL: func @keep_not_noop
669669
// CHECK: %[[RESULT:.+]]:2 = linalg.generic
670670
// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1
671+
672+
// -----
673+
674+
func @fold_init_tensor_with_subtensor
675+
(%arg0 : index, %arg1 : index) -> tensor<5x?x20xf32>
676+
{
677+
%0 = linalg.init_tensor[%arg0, 10, 40] : tensor<?x10x40xf32>
678+
%1 = subtensor %0[0, 0, 0] [5, %arg1, 20] [1, 1, 1]
679+
: tensor<?x10x40xf32> to tensor<5x?x20xf32>
680+
return %1 : tensor<5x?x20xf32>
681+
}
682+
// CHECK: func @fold_init_tensor_with_subtensor
683+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
684+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
685+
// CHECK: %[[T0:.+]] = linalg.init_tensor [5, %[[ARG1]], 20]
686+
// CHECK: return %[[T0]]

0 commit comments

Comments
 (0)