@@ -1760,136 +1760,75 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
 } // namespace
 
 namespace {
-class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    Location loc = op->getLoc();
-    Value lhs = adaptor.getSelf();
-    Value rhs = op->getOperand(1);
-
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
-      return failure();
-    }
-    auto lhsType = cast<RankedTensorType>(lhs.getType());
-    auto rhsType = cast<RankedTensorType>(rhs.getType());
-
-    auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
-    auto rhsTorchType = cast<ValueTensorType>(op.getOperand(1).getType());
-
-    // Get the rank of both matrix.
-    unsigned lhsRank = lhsType.getRank();
-    unsigned rhsRank = rhsType.getRank();
-
-    Value lhsZeroPoint, rhsZeroPoint;
-    getZeroPoint(op.getSelf(), lhsZeroPoint);
-    getZeroPoint(op.getOperand(1), rhsZeroPoint);
-
-    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: aten.outer with mixed quantization");
-    }
-
-    bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
-    bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);
-
-    if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
-      // Allows quantized types to mismatch
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: aten.outer with different input element types");
-    }
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    auto resultType = cast<RankedTensorType>(newResultType);
-    Type elementType = resultType.getElementType();
-
-    // Quantized case
-    if (lhsZeroPoint) {
-      // get each zero point ready to pass to a quantized_matmul
-      lhsZeroPoint = typeConverter->materializeTargetConversion(
-          rewriter, loc,
-          getTypeConverter()->convertType(lhsZeroPoint.getType()),
-          lhsZeroPoint);
-      rhsZeroPoint = typeConverter->materializeTargetConversion(
-          rewriter, loc,
-          getTypeConverter()->convertType(rhsZeroPoint.getType()),
-          rhsZeroPoint);
-      lhsZeroPoint = rewriter.create<arith::TruncIOp>(
-          loc, rewriter.getI32Type(), lhsZeroPoint);
-      rhsZeroPoint = rewriter.create<arith::TruncIOp>(
-          loc, rewriter.getI32Type(), rhsZeroPoint);
-
-      // change uint8 quantization -> int8 quantization
-      int64_t numBits =
-          cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
-      signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
-      numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
-      signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
-
-      if (lhsRank == 1 && rhsRank == 1) {
-        int64_t lhsDim = lhsType.getShape()[0];
-        int64_t rhsDim = rhsType.getShape()[0];
-
-        // Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
-        auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
-        auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
-        SmallVector<ReassociationIndices> reassociation = {{0, 1}};
-        lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
-        rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
-
-        // Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
-        Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-        Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-        Value zeroTensor = createZeroInitTensor(rewriter, loc,
-                                                ValueRange{lhsDimVal, rhsDimVal},
-                                                elementType);
-
-        // Use the quantized version of matmul.
-        Value outerProd = rewriter.create<linalg::QuantizedMatmulOp>(
-            loc, zeroTensor.getType(),
-            ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
-            zeroTensor).getResult(0);
-
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-        return success();
-      }
-      return rewriter.notifyMatchFailure(op, "unsupported: quantized aten.outer op case");
-    }
-
-
-    // Non Quantized Outter Product
-    if (lhsRank == 1 && rhsRank == 1) {
-      int64_t lhsDim = lhsType.getShape()[0];
-      int64_t rhsDim = rhsType.getShape()[0];
-
-      // Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
-      auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
-      auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
-      SmallVector<ReassociationIndices> reassociation = {{0, 1}};
-      lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
-      rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
-
-      // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
-      Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-      Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-      Value zeroTensor = createZeroInitTensor(rewriter, loc,
-                                              ValueRange{lhsDimVal, rhsDimVal},
-                                              elementType);
-
-      // Use linalg::MatmulOp to compute the outer product.
-      Value outerProd = rewriter.create<linalg::MatmulOp>(
-          loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor).getResult(0);
-
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-      return success();
-    }
-
-    return failure();
-  }
-};
+class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Location loc = op->getLoc();
+    Value lhs = adaptor.getSelf();
+    Value rhs = op->getOperand(1);
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
+      return failure();
+    }
+    auto lhsType = cast<RankedTensorType>(lhs.getType());
+    auto rhsType = cast<RankedTensorType>(rhs.getType());
+
+    if (!lhsType || !rhsType)
+      return rewriter.notifyMatchFailure(op,
+                                         "outer: expected ranked tensor types");
+    if (lhsType.getRank() != 1 || rhsType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "outer: expected 1D tensors for outer op lowering");
+
+    // Both operands are 1D, so their only dimension is dimension 0.
+    Value lhsDim = getDimOp(rewriter, loc, lhs, 0);
+    Value rhsDim = getDimOp(rewriter, loc, rhs, 0);
+    Type elementType = lhsType.getElementType();
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+
+    // Create a zero-initialized tensor with shape [lhsDim, rhsDim].
+    Value zeroTensor = createZeroInitTensor(
+        rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType);
+
+    // Set up affine indexing maps:
+    // We create a 2D loop iteration space. For the lhs, we use the first index
+    // (i), for the rhs, the second index (j), and for the result, both (i, j).
+    AffineMap mapLhs =
+        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(0)},
+                       rewriter.getContext());
+    AffineMap mapRhs =
+        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(1)},
+                       rewriter.getContext());
+    AffineMap mapOut =
+        AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+
+    SmallVector<utils::IteratorType, 2> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel};
+
+    Value outerProd =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, zeroTensor.getType(),
+                /*inputs=*/ValueRange{lhs, rhs},
+                /*outputs=*/zeroTensor,
+                /*indexingMaps=*/
+                SmallVector<AffineMap, 3>{mapLhs, mapRhs, mapOut},
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value lhsElem = args[0];
+                  Value rhsElem = args[1];
+                  // Multiply one lhs element with one rhs element; assumes a
+                  // floating-point element type.
+                  Value mult = b.create<arith::MulFOp>(loc, lhsElem, rhsElem);
+                  b.create<linalg::YieldOp>(loc, mult);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
+    return success();
+  }
+};
 } // namespace
 
 
 void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
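
Note: the hunk cuts off right where pattern registration begins. For orientation, here is a minimal sketch of how ConvertAtenOuterOp would typically be wired up inside populateLinearPatternsAndLegality, following the registration style used elsewhere in this file. The actual body of that function is outside this hunk, so the signature, the neighboring registrations, and the exact placement are assumptions, not part of the change shown above.

// Sketch only: assumed registration, not shown in the hunk above.
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
    TypeConverter &typeConverter, RewritePatternSet &patterns,
    ConversionTarget &target) {
  MLIRContext *context = patterns.getContext();
  // ... existing registrations for the other linear ops ...
  // Mark aten.outer as illegal so the conversion must apply the new pattern.
  target.addIllegalOp<AtenOuterOp>();
  patterns.add<ConvertAtenOuterOp>(typeConverter, context);
}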