@@ -1673,80 +1673,6 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
1673
1673
1674
1674
} // namespace
1675
1675
1676
- namespace {
1677
- class ConvertAtenOuterOp : public OpConversionPattern <AtenOuterOp> {
1678
- public:
1679
- using OpConversionPattern::OpConversionPattern;
1680
- LogicalResult
1681
- matchAndRewrite (AtenOuterOp op, OpAdaptor adaptor,
1682
- ConversionPatternRewriter &rewriter) const override {
1683
-
1684
- Location loc = op->getLoc ();
1685
- Value lhs = adaptor.getSelf ();
1686
- Value rhs = adaptor.getVec2 ();
1687
-
1688
- if (failed (verifyLinalgCompatibleTypes (op, rewriter))) {
1689
- return failure ();
1690
- }
1691
- auto lhsType = dyn_cast<RankedTensorType>(lhs.getType ());
1692
- auto rhsType = dyn_cast<RankedTensorType>(rhs.getType ());
1693
-
1694
- if (!lhsType || !rhsType)
1695
- return rewriter.notifyMatchFailure (op,
1696
- " outer: expected ranked tensor types" );
1697
- if (lhsType.getRank () != 1 || rhsType.getRank () != 1 )
1698
- return rewriter.notifyMatchFailure (
1699
- op, " outer: expected 1D tensors for outer op lowering" );
1700
-
1701
- Value lhsDim = getDimOp (rewriter, loc, lhs, 0 );
1702
- Value rhsDim = getDimOp (rewriter, loc, rhs, 0 );
1703
- Type elementType = lhsType.getElementType ();
1704
- Type newResultType = getTypeConverter ()->convertType (op.getType ());
1705
-
1706
- // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1707
- SmallVector<OpFoldResult> resultShape =
1708
- getAsOpFoldResult (ValueRange{lhsDim, rhsDim});
1709
- Value initTensor =
1710
- rewriter.create <tensor::EmptyOp>(loc, resultShape, elementType);
1711
-
1712
- // Set up affine indexing maps:
1713
- // We create a 2D loop iteration space. For the lhs, we use the first index
1714
- // (i), for the rhs, the second index (j), and for the result, both (i, j).
1715
- AffineMap mapLhs =
1716
- AffineMap::get (2 , /* symbolCount=*/ 0 , {rewriter.getAffineDimExpr (0 )},
1717
- rewriter.getContext ());
1718
- AffineMap mapRhs =
1719
- AffineMap::get (2 , /* symbolCount=*/ 0 , {rewriter.getAffineDimExpr (1 )},
1720
- rewriter.getContext ());
1721
- AffineMap mapOut =
1722
- AffineMap::getMultiDimIdentityMap (2 , rewriter.getContext ());
1723
-
1724
- SmallVector<utils::IteratorType, 2 > iteratorTypes = {
1725
- utils::IteratorType::parallel, utils::IteratorType::parallel};
1726
-
1727
- Value outerProd =
1728
- rewriter
1729
- .create <linalg::GenericOp>(
1730
- loc, initTensor.getType (),
1731
- /* inputs=*/ ValueRange{lhsDim, rhsDim},
1732
- /* outputs=*/ initTensor,
1733
- /* indexingMaps=*/
1734
- SmallVector<AffineMap, 3 >{mapLhs, mapRhs, mapOut},
1735
- /* iteratortType=*/ iteratorTypes,
1736
- [&](OpBuilder &b, Location loc, ValueRange args) {
1737
- Value lhsElem = args[0 ];
1738
- Value rhsElem = args[1 ];
1739
- Value mult = b.create <arith::MulFOp>(loc, lhsElem, rhsElem);
1740
- b.create <linalg::YieldOp>(loc, mult);
1741
- })
1742
- .getResult (0 );
1743
-
1744
- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1745
- return success ();
1746
- }
1747
- };
1748
- } // namespace
1749
-
1750
1676
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality (
1751
1677
TypeConverter &typeConverter, RewritePatternSet &patterns,
1752
1678
ConversionTarget &target) {
@@ -1763,6 +1689,4 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
1763
1689
patterns.add <ConvertAtenConvolutionOp>(typeConverter, context);
1764
1690
target.addIllegalOp <AtenFftRfftOp>();
1765
1691
patterns.add <ConvertAtenFftRfftOp>(typeConverter, context);
1766
- target.addIllegalOp <AtenOuterOp>();
1767
- patterns.add <ConvertAtenOuterOp>(typeConverter, context);
1768
1692
}
0 commit comments