@@ -1982,14 +1982,43 @@ struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
19821982 return success ();
19831983 }
19841984};
1985+
1986+ struct FoldExpandOfCast : public OpRewritePattern <ExpandShapeOp> {
1987+ using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
1988+
1989+ LogicalResult matchAndRewrite (ExpandShapeOp expandOp,
1990+ PatternRewriter &rewriter) const override {
1991+ auto castOp = expandOp.getSrc ().getDefiningOp <CastOp>();
1992+ if (!canFoldIntoConsumerOp (castOp))
1993+ return failure ();
1994+
1995+ SmallVector<OpFoldResult> outputOfr =
1996+ getMixedValues (expandOp.getResultType ().getShape (),
1997+ expandOp.getOutputShape (), rewriter);
1998+ std::optional<SmallVector<int64_t >> constantOutputShape =
1999+ getConstantIntValues (outputOfr);
2000+ if (!constantOutputShape.has_value ()) {
2001+ return failure ();
2002+ }
2003+ auto newType = RankedTensorType::get (
2004+ constantOutputShape.value (), expandOp.getSrcType ().getElementType ());
2005+
2006+ auto newExpand = rewriter.create <ExpandShapeOp>(
2007+ castOp.getLoc (), newType, castOp.getSource (),
2008+ expandOp.getReassociationIndices ());
2009+ rewriter.replaceOpWithNewOp <CastOp>(expandOp, expandOp.getType (),
2010+ newExpand.getResult ());
2011+ return success ();
2012+ }
2013+ };
19852014} // namespace
19862015
19872016void ExpandShapeOp::getCanonicalizationPatterns (RewritePatternSet &results,
19882017 MLIRContext *context) {
19892018 results.add <
19902019 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand >,
19912020 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
1992- FoldReshapeWithConstant<ExpandShapeOp>,
2021+ FoldExpandOfCast, FoldReshapeWithConstant<ExpandShapeOp>,
19932022 FoldReshapeWithSplat<ExpandShapeOp>,
19942023 FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
19952024 FoldDimOfCollapseShape>(context);
0 commit comments