Commit c7e3e05
Addressed the comments:
- Rewrote the ConvertAtenOuterOp without unsqueezing
- Replaced linalg::MatmulOp with linalg::GenericOp for building the result of the op
- Added error messages for unsupported cases
- Added test case in e2e tests, placed in matmul.py
1 parent c0d65be commit c7e3e05
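For context (not stated in the commit message itself): torch.outer of two 1-D tensors u (length n) and v (length m) produces the n x m matrix

    outer(u, v)[i, j] = u[i] * v[j],    0 <= i < n, 0 <= j < m

which is why the rewritten lowering can emit a single linalg.generic over a 2-D parallel iteration space instead of unsqueezing the operands and calling a matmul.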

File tree

  • lib/Conversion/TorchToLinalg
  • projects/pt1/python/torch_mlir_e2e_test/test_suite

2 files changed: +90 -127 lines

lib/Conversion/TorchToLinalg/Linear.cpp

Lines changed: 66 additions & 127 deletions
@@ -1760,136 +1760,75 @@ struct ConvertAtenFftRfftOp final : OpConversionPattern<AtenFftRfftOp> {
 } // namespace
 
 namespace {
-class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
-public:
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-
-    Location loc = op->getLoc();
-    Value lhs = adaptor.getSelf();
-    Value rhs = op->getOperand(1);
-
-    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
-      return failure();
-    }
-    auto lhsType = cast<RankedTensorType>(lhs.getType());
-    auto rhsType = cast<RankedTensorType>(rhs.getType());
-
-    auto lhsTorchType = cast<ValueTensorType>(op.getSelf().getType());
-    auto rhsTorchType = cast<ValueTensorType>(op.getOperand(1).getType());
-
-    // Get the rank of both matrix.
-    unsigned lhsRank = lhsType.getRank();
-    unsigned rhsRank = rhsType.getRank();
-
-    Value lhsZeroPoint, rhsZeroPoint;
-    getZeroPoint(op.getSelf(), lhsZeroPoint);
-    getZeroPoint(op.getOperand(1), rhsZeroPoint);
-
-    if (static_cast<bool>(lhsZeroPoint) != static_cast<bool>(rhsZeroPoint)) {
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: aten.outer with mixed quantization");
-    }
-
-    bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType);
-    bool isUnsignedR = torch_to_linalg::isUnsignedTorchType(rhsTorchType);
-
-    if (!lhsZeroPoint && lhsTorchType.getDtype() != rhsTorchType.getDtype()) {
-      // Allows quantized types to mismatch
-      return rewriter.notifyMatchFailure(
-          op, "unsupported: aten.outer with different input element types");
-    }
-
-    Type newResultType = getTypeConverter()->convertType(op.getType());
-    auto resultType = cast<RankedTensorType>(newResultType);
-    Type elementType = resultType.getElementType();
-
-    // Quantized case
-    if (lhsZeroPoint) {
-      // get each zero point ready to pass to a quantized_matmul
-      lhsZeroPoint = typeConverter->materializeTargetConversion(
-          rewriter, loc,
-          getTypeConverter()->convertType(lhsZeroPoint.getType()),
-          lhsZeroPoint);
-      rhsZeroPoint = typeConverter->materializeTargetConversion(
-          rewriter, loc,
-          getTypeConverter()->convertType(rhsZeroPoint.getType()),
-          rhsZeroPoint);
-      lhsZeroPoint = rewriter.create<arith::TruncIOp>(
-          loc, rewriter.getI32Type(), lhsZeroPoint);
-      rhsZeroPoint = rewriter.create<arith::TruncIOp>(
-          loc, rewriter.getI32Type(), rhsZeroPoint);
-
-      // change uint8 quantization -> int8 quantization
-      int64_t numBits =
-          cast<mlir::IntegerType>(lhsType.getElementType()).getWidth();
-      signShift(rewriter, loc, lhs, lhsZeroPoint, isUnsigned, numBits);
-      numBits = cast<mlir::IntegerType>(rhsType.getElementType()).getWidth();
-      signShift(rewriter, loc, rhs, rhsZeroPoint, isUnsignedR, numBits);
-
-      if (lhsRank == 1 && rhsRank == 1) {
-        int64_t lhsDim = lhsType.getShape()[0];
-        int64_t rhsDim = rhsType.getShape()[0];
+class ConvertAtenOuterOp : public OpConversionPattern<AtenOuterOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(AtenOuterOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
 
-        // Unsqueeze: lhs: [n] -> [n, 1] and rhs: [m] -> [1, m]
-        auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
-        auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
-        SmallVector<ReassociationIndices> reassociation = {{0, 1}};
-        lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
-        rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
-
-        // Create a zero tensor with shape [lhsDim, rhsDim] for the accumulator.
-        Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-        Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-        Value zeroTensor = createZeroInitTensor(rewriter, loc,
-                                                ValueRange{lhsDimVal, rhsDimVal},
-                                                elementType);
-
-        // Use the quantized version of matmul.
-        Value outerProd = rewriter.create<linalg::QuantizedMatmulOp>(
-            loc, zeroTensor.getType(),
-            ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint},
-            zeroTensor).getResult(0);
-
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-        return success();
-      }
-      return rewriter.notifyMatchFailure(op, "unsupported: quantized aten.outer op case");
-    }
-
-
-    // Non Quantized Outter Product
-    if (lhsRank == 1 && rhsRank == 1) {
-      int64_t lhsDim = lhsType.getShape()[0];
-      int64_t rhsDim = rhsType.getShape()[0];
-
-      // Unsqueeze: lhs from [n] -> [n, 1] and rhs from [m] -> [1, m]
-      auto lhsUnsqueezeType = RankedTensorType::get({lhsDim, 1}, lhsType.getElementType());
-      auto rhsUnsqueezeType = RankedTensorType::get({1, rhsDim}, rhsType.getElementType());
-      SmallVector<ReassociationIndices> reassociation = {{0, 1}};
-      lhs = rewriter.create<tensor::ExpandShapeOp>(loc, lhsUnsqueezeType, lhs, reassociation);
-      rhs = rewriter.create<tensor::ExpandShapeOp>(loc, rhsUnsqueezeType, rhs, reassociation);
-
-      // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
-      Value lhsDimVal = rewriter.create<tensor::DimOp>(loc, lhs, 0);
-      Value rhsDimVal = rewriter.create<tensor::DimOp>(loc, rhs, 1);
-      Value zeroTensor = createZeroInitTensor(rewriter, loc,
-                                              ValueRange{lhsDimVal, rhsDimVal},
-                                              elementType);
-
-      // Use linalg::MatmulOp to compute the outer product.
-      Value outerProd = rewriter.create<linalg::MatmulOp>(
-          loc, zeroTensor.getType(), ValueRange{lhs, rhs}, zeroTensor).getResult(0);
-
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
-      return success();
-    }
-
+    Location loc = op->getLoc();
+    Value lhs = adaptor.getSelf();
+    Value rhs = adaptor.getVec2();
+
+    if (failed(verifyLinalgCompatibleTypes(op, rewriter))) {
       return failure();
     }
-};
+    auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
+    auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
+
+    if (!lhsType || !rhsType)
+      return rewriter.notifyMatchFailure(op,
+                                         "outer: expected ranked tensor types");
+    if (lhsType.getRank() != 1 || rhsType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "outer: expected 1D tensors for outer op lowering");
+
+    Value lhsDim = getDimOp(rewriter, loc, lhs, 0);
+    Value rhsDim = getDimOp(rewriter, loc, rhs, 0);
+    Type elementType = lhsType.getElementType();
+    Type newResultType = getTypeConverter()->convertType(op.getType());
+
+    // Create a zero-initialized tensor with shape [lhsDim, rhsDim]
+    Value zeroTensor = createZeroInitTensor(
+        rewriter, loc, ValueRange{lhsDim, rhsDim}, elementType);
+
+    // Set up affine indexing maps:
+    // We create a 2D loop iteration space. For the lhs, we use the first index
+    // (i), for the rhs, the second index (j), and for the result, both (i, j).
+    AffineMap mapLhs =
+        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(0)},
+                       rewriter.getContext());
+    AffineMap mapRhs =
+        AffineMap::get(2, /*symbolCount=*/0, {rewriter.getAffineDimExpr(1)},
+                       rewriter.getContext());
+    AffineMap mapOut =
+        AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+
+    SmallVector<utils::IteratorType, 2> iteratorTypes = {
+        utils::IteratorType::parallel, utils::IteratorType::parallel};
+
+    Value outerProd =
+        rewriter
+            .create<linalg::GenericOp>(
+                loc, zeroTensor.getType(),
+                /*inputs=*/ValueRange{lhs, rhs},
+                /*outputs=*/zeroTensor,
+                /*indexingMaps=*/
+                SmallVector<AffineMap, 3>{mapLhs, mapRhs, mapOut},
+                /*iteratorTypes=*/iteratorTypes,
+                [&](OpBuilder &b, Location loc, ValueRange args) {
+                  Value lhsElem = args[0];
+                  Value rhsElem = args[1];
+                  Value mult = b.create<arith::MulFOp>(loc, lhsElem, rhsElem);
+                  b.create<linalg::YieldOp>(loc, mult);
+                })
+            .getResult(0);
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, outerProd);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
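For reference, a minimal Python sketch of the semantics the new linalg.generic lowering implements (the outer_reference helper below is illustrative only and not part of the commit): the op walks a 2-D parallel iteration space (i, j), reads lhs at i and rhs at j, multiplies, and writes the product into a zero-initialized [n, m] accumulator.

import torch

def outer_reference(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    # The pattern only matches 1-D inputs.
    assert lhs.dim() == 1 and rhs.dim() == 1
    # Zero-initialized [n, m] accumulator (createZeroInitTensor in the C++ code).
    out = torch.zeros(lhs.size(0), rhs.size(0), dtype=lhs.dtype)
    for i in range(lhs.size(0)):        # parallel dim d0
        for j in range(rhs.size(0)):    # parallel dim d1
            out[i, j] = lhs[i] * rhs[j]  # generic body: multiply + yield
    return out

a, b = torch.rand(3), torch.rand(5)
assert torch.allclose(outer_reference(a, b), torch.outer(a, b))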

projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py

Lines changed: 24 additions & 0 deletions
@@ -918,3 +918,27 @@ def forward(self, a, b):
 @register_test_case(module_factory=lambda: AtenLinalgCrossDynamic())
 def AtenLinalgCrossDynamic_basic(module, tu: TestUtils):
     module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1))
+
+
+# ==============================================================================
+
+
+class AtenOuter(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1], torch.float32, True),
+            ([-1], torch.float32, True),
+        ]
+    )
+    def forward(self, lhs, rhs):
+        return torch.outer(lhs, rhs)
+
+
+@register_test_case(module_factory=lambda: AtenOuter())
+def AtenOuter_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3), tu.rand(3))
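As a quick local sanity check of the new test module (hypothetical usage, not part of the commit or the e2e harness), it can be run eagerly and compared against the unsqueeze-based formulation the lowering previously relied on:

import torch

m = AtenOuter()
a, b = torch.rand(3), torch.rand(3)
out = m.forward(a, b)
assert out.shape == (3, 3)
# torch.outer(a, b) is elementwise equal to the broadcasted product of the
# unsqueezed vectors, the shape the old lowering materialized explicitly.
assert torch.allclose(out, a.unsqueeze(1) * b.unsqueeze(0))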
