
Commit 6764919

[mlir] Fold expand of cast
Sink tensor.cast ops through tensor.expand_shape ops when doing so makes the expand op more static. This allows other ops further down the use-def chain to infer their shapes.
1 parent 8bb100b commit 6764919
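
A minimal before/after sketch of the rewrite, adapted from the test added in this commit (the %c1/%c10 constants and SSA names are illustrative only):

// Before: the tensor.cast hides the static 10x10 shape, so the expand_shape is fully dynamic.
%0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
%1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
    : tensor<?x?xf32> into tensor<?x?x?xf32>

// After: the expand_shape consumes the static source directly and the cast is
// sunk below it, so users of %1 still see the original (dynamic) type.
%expanded = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [10, 1, 10]
    : tensor<10x10xf32> into tensor<10x1x10xf32>
%1 = tensor.cast %expanded : tensor<10x1x10xf32> to tensor<?x?x?xf32>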

2 files changed: +44 -1 lines


mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 30 additions & 1 deletion
@@ -1982,14 +1982,43 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
     return success();
   }
 };
+
+struct FoldExpandOfCast : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
+    if (!canFoldIntoConsumerOp(castOp))
+      return failure();
+
+    SmallVector<OpFoldResult> outputOfr =
+        getMixedValues(expandOp.getResultType().getShape(),
+                       expandOp.getOutputShape(), rewriter);
+    std::optional<SmallVector<int64_t>> constantOutputShape =
+        getConstantIntValues(outputOfr);
+    if (!constantOutputShape.has_value()) {
+      return failure();
+    }
+    auto newType = RankedTensorType::get(
+        constantOutputShape.value(), expandOp.getSrcType().getElementType());
+
+    auto newExpand = rewriter.create<ExpandShapeOp>(
+        castOp.getLoc(), newType, castOp.getSource(),
+        expandOp.getReassociationIndices());
+    rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
+                                        newExpand.getResult());
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
   results.add<
       ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
       ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
-      FoldReshapeWithConstant<ExpandShapeOp>,
+      FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
       FoldReshapeWithSplat<ExpandShapeOp>,
       FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
       FoldDimOfCollapseShape>(context);

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
@@ -2718,3 +2718,17 @@ func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128
   %pack = tensor.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor<?x?x?xf16> -> tensor<128x?x100x16x1xf16>
   return %pack : tensor<128x?x100x16x1xf16>
 }
+
+// -----
+
+func.func @fold_expand_of_cast(%arg0 : tensor<10x10xf32>)
+    -> tensor<?x?x?xf32> {
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  %0 = tensor.cast %arg0 : tensor<10x10xf32> to tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%c10, %c1, %c10]
+      : tensor<?x?xf32> into tensor<?x?x?xf32>
+  return %1 : tensor<?x?x?xf32>
+}
+// CHECK-LABEL: func.func @fold_expand_of_cast
+// CHECK: tensor.expand_shape %{{.*}} {{\[}}[0, 1], [2]] output_shape [10, 1, 10]
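
To illustrate the "other ops further down" part of the commit message, a hedged sketch of a hypothetical consumer (not part of the test above), where %1 is the cast produced by the rewrite:

// Hypothetical use of %1 after canonicalization, where %1 is now
// tensor.cast %expanded : tensor<10x1x10xf32> to tensor<?x?x?xf32>.
%c0 = arith.constant 0 : index
%d0 = tensor.dim %1, %c0 : tensor<?x?x?xf32>
// Existing dim-of-cast canonicalizations can now look through the cast to the
// static tensor<10x1x10xf32> result and resolve %d0 to a constant 10 (assuming
// those patterns fire), which was not possible while the expand_shape result
// itself was dynamic.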
