@@ -1642,69 +1642,18 @@ class ConvertAtenSqueezeDimOp : public OpConversionPattern<AtenSqueezeDimOp> {
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
-    Value input = adaptor.getSelf();
-    auto inputType = cast<RankedTensorType>(input.getType());
-    int64_t inputRank = inputType.getRank();
-
-    if (inputRank == 0) {
-      return rewriter.notifyMatchFailure(
-          op, "zero input rank should have been handled by the folder");
-    }
-
     int64_t dim;
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");
-    dim = toPositiveDim(dim, inputRank);
-    if (!isValidDim(dim, inputRank))
-      return rewriter.notifyMatchFailure(op, "dim is statically invalid");
-
-    // assert dynamic squeeze dim size == 1
-    if (inputType.isDynamicDim(dim)) {
-      Value cstDim = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), dim);
-      Value dimVal = rewriter.create<tensor::DimOp>(op.getLoc(), input, cstDim);
-      Value cstOne = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 1);
-      Value cmp = rewriter.create<arith::CmpIOp>(
-          op.getLoc(), arith::CmpIPredicate::eq, dimVal, cstOne);
-      rewriter.create<cf::AssertOp>(
-          op.getLoc(), cmp,
-          rewriter.getStringAttr(
-              "Expected dynamic squeeze dim size to be statically 1"));
-    }
-
-    const TypeConverter *typeConverter = getTypeConverter();
-    auto resultType =
-        cast<RankedTensorType>(typeConverter->convertType(op.getType()));
-    int64_t resultRank = resultType.getRank();

-    // If the dim(th) dimension of operand tensor type is not statically unit,
-    // `aten.squeeze` will behave as an identity operation.
-    if (inputType.getDimSize(dim) != 1 && !inputType.isDynamicDim(dim)) {
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, input);
-      return success();
+    auto squeezeTensorInfo =
+        squeezeTensor(rewriter, op, adaptor.getSelf(), dim);
+    if (failed(squeezeTensorInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate squeeze tensor");
     }

-    SmallVector<ReassociationIndices> reassociationMap(resultRank);
-    bool alreadyCrossedSqueezedDim = false;
-    for (int i = 0; i != resultRank; i++) {
-      if (alreadyCrossedSqueezedDim) {
-        reassociationMap[i].push_back(i + 1);
-      } else {
-        reassociationMap[i].push_back(i);
-        if (dim != 0 && i != dim - 1)
-          continue;
-
-        alreadyCrossedSqueezedDim = true;
-        if (dim == 0)
-          reassociationMap[0].push_back(1);
-        if (i == dim - 1)
-          reassociationMap[i].push_back(dim);
-      }
-    }
-    // Note: In case the operand tensor type is of unit rank and is statically
-    // shaped with unit dimension, the `reassociationMap` will be empty and the
-    // input will be collapsed to a 0-D tensor.
-    rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(op, resultType, input,
-                                                         reassociationMap);
+    rewriter.replaceOp(op, squeezeTensorInfo.value());
     return success();
   }
 };
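Refactor note: the hand-written logic deleted above (normalizing `dim` to a positive index, the runtime `cf.assert` that a dynamic squeeze dim equals 1, the identity `tensor.cast` fast path, and the reassociation map driving `tensor.collapse_shape`) now lives behind the shared `squeezeTensor` helper, which this diff does not show. The declaration below is only a sketch inferred from the call site: `failed(...)` and `.value()` suggest a `FailureOr<Value>` result, and the parameter list mirrors the arguments passed above; both are assumptions, not code from this patch.

    // Hypothetical declaration, inferred from the call site above; the real
    // helper is defined in the TorchToLinalg utilities, outside this hunk.
    // On success it yields the squeezed replacement value; on failure the
    // pattern reports a match failure instead of asserting.
    FailureOr<Value> squeezeTensor(PatternRewriter &rewriter, Operation *op,
                                   Value input, int64_t dim);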
@@ -1722,36 +1671,15 @@ class ConvertAtenUnsqueezeOp : public OpConversionPattern<AtenUnsqueezeOp> {
     int64_t dim;
     if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
       return rewriter.notifyMatchFailure(op, "dim must be constant");
-    auto inputRank =
-        cast<RankedTensorType>(adaptor.getSelf().getType()).getRank();
-    dim = toPositiveDim(dim, inputRank + 1);
-    if (!isValidDim(dim, inputRank + 1))
-      return rewriter.notifyMatchFailure(op, "dim is statically invalid");

-    SmallVector<ReassociationIndices> reassociationMap(inputRank);
-    // From the perspective of the reassociation map, the situation of
-    // unsqueezing before or after the last dimension is symmetrical.
-    // Normalize it to the "before" case.
-    // The 0 case is special here, since there is no last dimension to insert
-    // before -- we simply rely on the loop below iterating 0 times.
-    if (dim == inputRank && inputRank != 0)
-      dim = inputRank - 1;
-    bool alreadyCrossedExpandedDim = false;
-    for (int i = 0; i != inputRank; i++) {
-      if (alreadyCrossedExpandedDim) {
-        reassociationMap[i].push_back(i + 1);
-      } else {
-        reassociationMap[i].push_back(i);
-        if (i == dim) {
-          reassociationMap[i].push_back(i + 1);
-          alreadyCrossedExpandedDim = true;
-        }
-      }
+    auto unsqueezeTensorInfo =
+        unsqueezeTensor(rewriter, op, adaptor.getSelf(), dim);
+    if (failed(unsqueezeTensorInfo)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "cannot generate unsqueeze tensor");
     }
-    auto resultType = cast<RankedTensorType>(
-        getTypeConverter()->convertType(op->getResult(0).getType()));
-    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-        op, resultType, adaptor.getSelf(), reassociationMap);
+
+    rewriter.replaceOp(op, unsqueezeTensorInfo.value());
     return success();
   }
 };
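The unsqueeze pattern gets the same treatment: the reassociation-map bookkeeping moves behind an `unsqueezeTensor` helper and the pattern shrinks to one call plus a `replaceOp`. As above, the declaration below is a sketch inferred from the call site, not a signature shown in this diff.

    // Hypothetical counterpart to squeezeTensor, inferred from the call site.
    // Presumably returns the input reshaped (e.g. via tensor.expand_shape)
    // with a unit dimension inserted at `dim`.
    FailureOr<Value> unsqueezeTensor(PatternRewriter &rewriter, Operation *op,
                                     Value input, int64_t dim);

One payoff of centralizing both helpers: any other lowering that needs a rank bump or a unit-dim collapse can reuse them instead of re-deriving the reassociation indices inline.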