@@ -1986,90 +1986,6 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
   }
 };
 
-struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
-    if (!expandShapeOp)
-      return failure();
-
-    // Only constant dimension values are supported.
-    std::optional<int64_t> dim = dimOp.getConstantIndex();
-    if (!dim.has_value())
-      return failure();
-
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = expandShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Find reassociation group that contains this result dimension.
-    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);
-
-    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
-    // ExpandShapeOp would be ambiguous.)
-    int64_t product = 1;
-    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
-    for (int64_t d : grp) {
-      if (d != dim) {
-        assert(!resultType.isDynamicDim(d) && "expected static dim");
-        product *= resultType.getDimSize(d);
-      }
-    }
-
-    // result dim size = src dim size / (product(other dims in reassoc group))
-    Value srcDimSz =
-        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
-    AffineExpr expr;
-    bindSymbols(dimOp.getContext(), expr);
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
-        dimOp, expr.floorDiv(product), srcDimSz);
-    return success();
-  }
-};
-
-struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
-  using OpRewritePattern<DimOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(DimOp dimOp,
-                                PatternRewriter &rewriter) const override {
-    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
-    if (!collapseShapeOp)
-      return failure();
-
-    // Only constant dimension values are supported.
-    std::optional<int64_t> dim = dimOp.getConstantIndex();
-    if (!dim.has_value() ||
-        dim.value() >= collapseShapeOp.getResultType().getRank())
-      return failure();
-
-    // Skip static dims. These are folded to constant ops.
-    RankedTensorType resultType = collapseShapeOp.getResultType();
-    if (!resultType.isDynamicDim(*dim))
-      return failure();
-
-    // Get reassociation group of the result dimension.
-    ReassociationIndices group =
-        collapseShapeOp.getReassociationIndices()[*dim];
-
-    // result dim size = product(dims in reassoc group)
-    SmallVector<Value> srcDimSizes;
-    SmallVector<AffineExpr> syms;
-    AffineExpr product;
-    for (const auto &it : llvm::enumerate(group)) {
-      srcDimSizes.push_back(rewriter.create<DimOp>(
-          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
-      syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
-      product = product ? product * syms.back() : syms.back();
-    }
-    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
-                                                       srcDimSizes);
-    return success();
-  }
-};
-
 /// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
 /// matching constant output_shape operands of the expand. This makes the
 /// `tensor.expand_shape` more static and creates a consumer cast that can be
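For reference, the removed FoldDimOfCollapseShape pattern rewrote a dynamic `tensor.dim` of a `tensor.collapse_shape` into an `affine.apply` that multiplies the source dimensions of the corresponding reassociation group. A minimal sketch of that rewrite (the tensor types and SSA names below are illustrative, not taken from this patch; %c1, %c2, %c3 stand for arith.constant index values):

    // Before: dim 1 of the collapsed result groups source dims [2, 3].
    %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3]]
        : tensor<?x4x?x8xf32> into tensor<?x?xf32>
    %d = tensor.dim %collapsed, %c1 : tensor<?x?xf32>

    // After: product of the source dims in the group; the tensor.dim of
    // the static source dim folds to a constant separately.
    %s2 = tensor.dim %src, %c2 : tensor<?x4x?x8xf32>
    %s3 = tensor.dim %src, %c3 : tensor<?x4x?x8xf32>
    %d = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%s2, %s3]

FoldDimOfExpandShape handled the inverse direction, dividing the source dimension size by the product of the static sibling dims in the group (the expr.floorDiv(product) expression above).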