Skip to content

Commit c0d65be

Browse files
rootamemov
authored and committed
Initial implementation of AtenOuterOp
- Defined the op in Linear.cpp. TODO: testing, and perhaps add some tests inside torch-mlir.
1 parent 11d0853 commit c0d65be

File tree

1 file changed

+135
-0
lines changed

1 file changed

+135
-0
lines changed

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,139 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
17591759

17601760
} // namespace
17611761

1762+
namespace {
1763+
class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
1764+
public:
1765+
using OpConversionPattern::OpConversionPattern;
1766+
LogicalResult
1767+
matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
1768+
ConversionPatternRewriter &rewriter) const override {
1769+
1770+
Location loc = op->getLoc();
1771+
Value lhs = adaptor.getSelf();
1772+
Value rhs = op->getOperand(1);
1773+
1774+
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
1775+
return failure();
1776+
}
1777+
auto lhsType = cast<RankedTensorType>(lhs.getType());
1778+
auto rhsType = cast<RankedTensorType>(rhs.getType());
1779+
1780+
auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
1781+
auto rhsTorchType = cast<ValueTensorType>(op.getOperand(1).getType());
1782+
1783+
// Get the rank of both matrix.
1784+
unsigned lhsRank = lhsType.getRank();
1785+
unsigned rhsRank = rhsType.getRank();
1786+
1787+
Value lhsZeroPoint, rhsZeroPoint;
1788+
getZeroPoint(op.getSelf(), lhsZeroPoint);
1789+
getZeroPoint(op.getOperand(1), rhsZeroPoint);
1790+
1791+
if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
1792+
return rewriter.notifyMatchFailure(
1793+
op, "unsupported: aten.outer with mixed quantization");
1794+
}
1795+
1796+
bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
1797+
bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);
1798+
1799+
if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
1800+
// Allows quantized types to mismatch
1801+
return rewriter.notifyMatchFailure(
1802+
op, "unsupported: aten.outer with different input element types");
1803+
}
1804+
1805+
Type newResultType = getTypeConverter()->convertType(op.getType());
1806+
auto resultType = cast<RankedTensorType>(newResultType);
1807+
Type elementType = resultType.getElementType();
1808+
1809+
// Quantized case
1810+
if (lhsZeroPoint) {
1811+
// get each zero point ready to pass to a quantized_matmul
1812+
lhsZeroPoint = typeConverter->materializeTargetConversion(
1813+
rewriter, loc,
1814+
getTypeConverter()->convertType(lhsZeroPoint.getType()),
1815+
lhsZeroPoint);
1816+
rhsZeroPoint = typeConverter->materializeTargetConversion(
1817+
rewriter, loc,
1818+
getTypeConverter()->convertType(rhsZeroPoint.getType()),
1819+
rhsZeroPoint);
1820+
lhsZeroPoint = rewriter.create<arith::TruncIOp>(
1821+
loc, rewriter.getI32Type(), lhsZeroPoint);
1822+
rhsZeroPoint = rewriter.create<arith::TruncIOp>(
1823+
loc, rewriter.getI32Type(), rhsZeroPoint);
1824+
1825+
// change uint8 quantization -> int8 quantization
1826+
int64_t numBits =
1827+
cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
1828+
signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
1829+
numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
1830+
signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
1831+
1832+
if (lhsRank == 1 && rhsRank == 1) {
1833+
int64_t lhsDim = lhsType.getShape()[0];
1834+
int64_t rhsDim = rhsType.getShape()[0];
1835+
1836+
// Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
1837+
auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
1838+
auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
1839+
SmallVector<ReassociationIndices> reassociation = {{0, 1}};
1840+
lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1841+
rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1842+
1843+
// Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
1844+
Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
1845+
Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
1846+
Value zeroTensor = createZeroInitTensor(rewriter, loc,
1847+
ValueRange{lhsDimVal, rhsDimVal},
1848+
elementType);
1849+
1850+
// Use the quantized version of matmul.
1851+
Value outerProd = rewriter.create<linalg::QuantizedMatmulOp>(
1852+
loc, zeroTensor.getType(),
1853+
ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
1854+
zeroTensor).getResult(0);
1855+
1856+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
1857+
return success();
1858+
}
1859+
return rewriter.notifyMatchFailure(op, "unsupported: quantized aten.outer op case");
1860+
}
1861+
1862+
1863+
// Non Quantized Outter Product
1864+
if (lhsRank == 1 && rhsRank == 1) {
1865+
int64_t lhsDim = lhsType.getShape()[0];
1866+
int64_t rhsDim = rhsType.getShape()[0];
1867+
1868+
// Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
1869+
auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
1870+
auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
1871+
SmallVector<ReassociationIndices> reassociation = {{0, 1}};
1872+
lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
1873+
rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
1874+
1875+
// Create a zero-initialized tensor with shape [lhsDim, rhsDim]
1876+
Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
1877+
Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
1878+
Value zeroTensor = createZeroInitTensor(rewriter, loc,
1879+
ValueRange{lhsDimVal, rhsDimVal},
1880+
elementType);
1881+
1882+
// Use linalg::MatmulOp to compute the outer product.
1883+
Value outerProd = rewriter.create<linalg::MatmulOp>(
1884+
loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor).getResult(0);
1885+
1886+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
1887+
return success();
1888+
}
1889+
1890+
return failure();
1891+
}
1892+
};
1893+
} // namespace
1894+
17621895
void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
17631896
TypeConverter &typeConverter, RewritePatternSet &patterns,
17641897
ConversionTarget &target) {
@@ -1775,4 +1908,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
17751908
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
17761909
target.addIllegalOp<AtenFftRfftOp>();
17771910
patterns.add<ConvertAtenFftRfftOp>(typeConverter, context);
1911+
target.addIllegalOp<AtenOuterOp>();
1912+
patterns.add<ConvertAtenOuterOp>(typeConverter, context);
17781913
}

0 commit comments

Comments
 (0)