@@ -1759,6 +1759,139 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
1759
1759
1760
1760
} // namespace
1761
1761
1762
+ namespace {
1763
+ class ConvertAtenOuterOp : public OpConversionPattern <AtenOuterOp> {
1764
+ public:
1765
+ using OpConversionPattern::OpConversionPattern;
1766
+ LogicalResult
1767
+ matchAndRewrite (AtenOuterOp op, OpAdaptor adaptor,
1768
+ ConversionPatternRewriter &rewriter) const override {
1769
+
1770
+ Location loc = op->getLoc ();
1771
+ Value lhs = adaptor.getSelf ();
1772
+ Value rhs = op->getOperand (1 );
1773
+
1774
+ if (failed (verifyLinalgCompatibleTypes (op, rewriter))) {
1775
+ return failure ();
1776
+ }
1777
+ auto lhsType = cast<RankedTensorType>(lhs.getType ());
1778
+ auto rhsType = cast<RankedTensorType>(rhs.getType ());
1779
+
1780
+ auto lhsTorchType = cast<ValueTensorType>(op.getSelf ().getType ());
1781
+ auto rhsTorchType = cast<ValueTensorType>(op.getOperand (1 ).getType ());
1782
+
1783
+ // Get the rank of both matrix.
1784
+ unsigned lhsRank = lhsType.getRank ();
1785
+ unsigned rhsRank = rhsType.getRank ();
1786
+
1787
+ Value lhsZeroPoint, rhsZeroPoint;
1788
+ getZeroPoint (op.getSelf (), lhsZeroPoint);
1789
+ getZeroPoint (op.getOperand (1 ), rhsZeroPoint);
1790
+
1791
+ if (static_cast <bool >(lhsZeroPoint) != static_cast <bool >(rhsZeroPoint)) {
1792
+ return rewriter.notifyMatchFailure (
1793
+ op, " unsupported: aten.outer with mixed quantization" );
1794
+ }
1795
+
1796
+ bool isUnsigned = torch_to_linalg::isUnsignedTorchType (lhsTorchType);
1797
+ bool isUnsignedR = torch_to_linalg::isUnsignedTorchType (rhsTorchType);
1798
+
1799
+ if (!lhsZeroPoint && lhsTorchType.getDtype () != rhsTorchType.getDtype ()) {
1800
+ // Allows quantized types to mismatch
1801
+ return rewriter.notifyMatchFailure (
1802
+ op, " unsupported: aten.outer with different input element types" );
1803
+ }
1804
+
1805
+ Type newResultType = getTypeConverter ()->convertType (op.getType ());
1806
+ auto resultType = cast<RankedTensorType>(newResultType);
1807
+ Type elementType = resultType.getElementType ();
1808
+
1809
+ // Quantized case
1810
+ if (lhsZeroPoint) {
1811
+ // get each zero point ready to pass to a quantized_matmul
1812
+ lhsZeroPoint = typeConverter->materializeTargetConversion (
1813
+ rewriter, loc,
1814
+ getTypeConverter ()->convertType (lhsZeroPoint.getType ()),
1815
+ lhsZeroPoint);
1816
+ rhsZeroPoint = typeConverter->materializeTargetConversion (
1817
+ rewriter, loc,
1818
+ getTypeConverter ()->convertType (rhsZeroPoint.getType ()),
1819
+ rhsZeroPoint);
1820
+ lhsZeroPoint = rewriter.create <arith::TruncIOp>(
1821
+ loc, rewriter.getI32Type (), lhsZeroPoint);
1822
+ rhsZeroPoint = rewriter.create <arith::TruncIOp>(
1823
+ loc, rewriter.getI32Type (), rhsZeroPoint);
1824
+
1825
+ // change uint8 quantization -> int8 quantization
1826
+ int64_t numBits =
1827
+ cast<mlir::IntegerType>(lhsType.getElementType ()).getWidth ();
1828
+ signShift (rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
1829
+ numBits = cast<mlir::IntegerType>(rhsType.getElementType ()).getWidth ();
1830
+ signShift (rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
1831
+
1832
+ if (lhsRank == 1 && rhsRank == 1 ) {
1833
+ int64_t lhsDim = lhsType.getShape ()[0 ];
1834
+ int64_t rhsDim = rhsType.getShape ()[0 ];
1835
+
1836
+ // Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
1837
+ auto lhsUnsqueezeType = RankedTensorType::get ({lhsDim, 1 }, lhsType.getElementType ());
1838
+ auto rhsUnsqueezeType = RankedTensorType::get ({1 , rhsDim}, rhsType.getElementType ());
1839
+ SmallVector<ReassociationIndices> reassociation = {{0 , 1 }};
1840
+ lhs = rewriter.create <tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1841
+ rhs = rewriter.create <tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1842
+
1843
+ // Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
1844
+ Value lhsDimVal = rewriter.create <tensor::DimOp>(loc, lhs, 0 );
1845
+ Value rhsDimVal = rewriter.create <tensor::DimOp>(loc, rhs, 1 );
1846
+ Value zeroTensor = createZeroInitTensor (rewriter, loc,
1847
+ ValueRange{lhsDimVal, rhsDimVal},
1848
+ elementType);
1849
+
1850
+ // Use the quantized version of matmul.
1851
+ Value outerProd = rewriter.create <linalg::QuantizedMatmulOp>(
1852
+ loc, zeroTensor.getType (),
1853
+ ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
1854
+ zeroTensor).getResult (0 );
1855
+
1856
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1857
+ return success ();
1858
+ }
1859
+ return rewriter.notifyMatchFailure (op, " unsupported: quantized aten.outer op case" );
1860
+ }
1861
+
1862
+
1863
+ // Non Quantized Outter Product
1864
+ if (lhsRank == 1 && rhsRank == 1 ) {
1865
+ int64_t lhsDim = lhsType.getShape ()[0 ];
1866
+ int64_t rhsDim = rhsType.getShape ()[0 ];
1867
+
1868
+ // Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
1869
+ auto lhsUnsqueezeType = RankedTensorType::get ({lhsDim, 1 }, lhsType.getElementType ());
1870
+ auto rhsUnsqueezeType = RankedTensorType::get ({1 , rhsDim}, rhsType.getElementType ());
1871
+ SmallVector<ReassociationIndices> reassociation = {{0 , 1 }};
1872
+ lhs = rewriter.create <tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1873
+ rhs = rewriter.create <tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1874
+
1875
+ // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1876
+ Value lhsDimVal = rewriter.create <tensor::DimOp>(loc, lhs, 0 );
1877
+ Value rhsDimVal = rewriter.create <tensor::DimOp>(loc, rhs, 1 );
1878
+ Value zeroTensor = createZeroInitTensor (rewriter, loc,
1879
+ ValueRange{lhsDimVal, rhsDimVal},
1880
+ elementType);
1881
+
1882
+ // Use linalg::MatmulOp to compute the outer product.
1883
+ Value outerProd = rewriter.create <linalg::MatmulOp>(
1884
+ loc, zeroTensor.getType (), ValueRange{lhs, rhs}, zeroTensor).getResult (0 );
1885
+
1886
+ rewriter.replaceOpWithNewOp <tensor::CastOp>(op, newResultType, outerProd);
1887
+ return success ();
1888
+ }
1889
+
1890
+ return failure ();
1891
+ }
1892
+ };
1893
+ } // namespace
1894
+
1762
1895
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality (
1763
1896
TypeConverter &typeConverter, RewritePatternSet &patterns,
1764
1897
ConversionTarget &target) {
@@ -1775,4 +1908,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
1775
1908
patterns.add <ConvertAtenConvolutionOp>(typeConverter, context);
1776
1909
target.addIllegalOp <AtenFftRfftOp>();
1777
1910
patterns.add <ConvertAtenFftRfftOp>(typeConverter, context);
1911
+ target.addIllegalOp <AtenOuterOp>();
1912
+ patterns.add <ConvertAtenOuterOp>(typeConverter, context);
1778
1913
}
0 commit comments