@@ -1759,80 +1759,6 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
1759
1759
1760
1760
} // namespace
1761
1761
1762
- namespace {
1763
- class ConvertAtenOuterOp : public OpConversionPattern <AtenOuterOp> {
1764
- public:
1765
- using OpConversionPattern::OpConversionPattern;
1766
- LogicalResult
1767
- matchAndRewrite (AtenOuterOp op, OpAdaptor adaptor,
1768
- ConversionPatternRewriter &rewriter) const override {
1769
-
1770
- Location loc = op->getLoc ();
1771
- Value lhs = adaptor.getSelf ();
1772
- Value rhs = adaptor.getVec2 ();
1773
-
1774
- if (failed (verifyLinalgCompatibleTypes (op, rewriter))) {
1775
- return failure ();
1776
- }
1777
- auto lhsType = dyn_cast<RankedTensorType>(lhs.getType ());
1778
- auto rhsType = dyn_cast<RankedTensorType>(rhs.getType ());
1779
-
1780
- if (!lhsType || !rhsType)
1781
- return rewriter.notifyMatchFailure (op,
1782
- " outer: expected ranked tensor types" );
1783
- if (lhsType.getRank () != 1 || rhsType.getRank () != 1 )
1784
- return rewriter.notifyMatchFailure (
1785
- op, " outer: expected 1D tensors for outer op lowering" );
1786
-
1787
- Value lhsDim = getDimOp (rewriter, loc, lhs, 0 );
1788
- Value rhsDim = getDimOp (rewriter, loc, rhs, 0 );
1789
- Type elementType = lhsType.getElementType ();
1790
- Type newResultType = getTypeConverter ()->convertType (op.getType ());
1791
-
1792
- // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1793
- SmallVector<OpFoldResult> resultShape =
1794
- getAsOpFoldResult (ValueRange{lhsDim, rhsDim});
1795
- Value initTensor =
1796
- rewriter.create <tensor::EmptyOp>(loc, resultShape, elementType);
1797
-
1798
- // Set up affine indexing maps:
1799
- // We create a 2D loop iteration space. For the lhs, we use the first index
1800
- // (i), for the rhs, the second index (j), and for the result, both (i, j).
1801
- AffineMap mapLhs =
1802
- AffineMap::get (2 , /* symbolCount=*/ 0 , {rewriter.getAffineDimExpr (0 )},
1803
- rewriter.getContext ());
1804
- AffineMap mapRhs =
1805
- AffineMap::get (2 , /* symbolCount=*/ 0 , {rewriter.getAffineDimExpr (1 )},
1806
- rewriter.getContext ());
1807
- AffineMap mapOut =
1808
- AffineMap::getMultiDimIdentityMap (2 , rewriter.getContext ());
1809
-
1810
- SmallVector<utils::IteratorType, 2 > iteratorTypes = {
1811
- utils::IteratorType::parallel, utils::IteratorType::parallel};
1812
-
1813
- Value outerProd =
1814
- rewriter
1815
- .create <linalg::GenericOp>(
1816
- loc, initTensor.getType (),
1817
- /* inputs=*/ ValueRange{lhsDim, rhsDim},
1818
- /* outputs=*/ initTensor,
1819
- /* indexingMaps=*/
1820
- SmallVector<AffineMap, 3 >{mapLhs, mapRhs, mapOut},
1821
- /* iteratortType=*/ iteratorTypes,
1822
- [&](OpBuilder &b, Location loc, ValueRange args) {
1823
- Value lhsElem = args[0 ];
1824
- Value rhsElem = args[1 ];
1825
- Value mult = b.create <arith::MulFOp>(loc, lhsElem, rhsElem);
1826
- b.create <linalg::YieldOp>(loc, mult);
1827
- })
1828
- .getResult (0 );
1829
-
1830
- rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1831
- return success ();
1832
- }
1833
- };
1834
- } // namespace
1835
-
1836
1762
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality (
1837
1763
TypeConverter &typeConverter, RewritePatternSet &patterns,
1838
1764
ConversionTarget &target) {
@@ -1849,6 +1775,4 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
1849
1775
patterns.add <ConvertAtenConvolutionOp>(typeConverter, context);
1850
1776
target.addIllegalOp <AtenFftRfftOp>();
1851
1777
patterns.add <ConvertAtenFftRfftOp>(typeConverter, context);
1852
- target.addIllegalOp <AtenOuterOp>();
1853
- patterns.add <ConvertAtenOuterOp>(typeConverter, context);
1854
1778
}
0 commit comments