Commit cadea67
[tosa] Implement torch.linear support. (#535)
Refactor matmul into a separate class and derive variants:

- matmul
- mm, bmm
- linear

Signed-off-by: Suraj Sudhir <[email protected]>
1 parent ad4b9e0 commit cadea67
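The refactor below makes each matmul-flavored op supply only its operand parsing, while the shared base class handles broadcasting, dynamic shapes, and the tosa.matmul emission. As a rough illustration of the extension point this commit introduces, a future variant could be added along the following lines. This is a minimal sketch only; `ConvertAtenFooMatmulOp` and its accessors are hypothetical and not part of this commit:

```cpp
// Sketch only: a hypothetical matmul-flavored variant. It reuses the shared
// ConvertAtenMatmulBaseOp::matchAndRewrite/performMatmul machinery and only
// overrides how its lhs/rhs operands are read from the op adaptor.
template <typename AtenOpT>
class ConvertAtenFooMatmulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
public:
  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
  using OpAdaptor = typename AtenOpT::Adaptor;
  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter,
                                 Value &lhs, Value &rhs) const override {
    // Accessor names are per-op (cf. self()/other() for matmul,
    // self()/mat2() for mm/bmm, input()/weight() for linear).
    lhs = adaptor.self();
    rhs = adaptor.other();
    if (!lhs.getType().isa<RankedTensorType>() ||
        !rhs.getType().isa<RankedTensorType>())
      return op.emitError("Only ranked tensor types supported in TOSA matmul");
    return success();
  }
};
```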

File tree: 1 file changed, +242 −39 lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp (242 additions & 39 deletions)
@@ -817,40 +817,37 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
   return success();
 }
 
-// Perform torch matmul, mm and bmm
+// Perform the basic n-dim matmul operation encompassing the handling of
+// broadcasting and dynamic shape propagation.
+// All PyTorch ops that leverage matrix multiplication will derive this and
+// implement their specialized input processing (e.g. transpose), and output
+// processing, e.g. GEMM or fully connected bias handling.
 template <typename AtenOpT>
-class ConvertAtenMatMulOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
 public:
   using OpConversionPattern<AtenOpT>::OpConversionPattern;
   using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value lhs = adaptor.self();
-    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+  // Each variant must implement corresponding parameter parsing options.
+  // Maintain separate input read functions for each variant because it is not
+  // necessarily true with all variants that the first two operands are the lhs
+  // and rhs.
+  virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                         ConversionPatternRewriter &rewriter,
+                                         Value &lhs, Value &rhs) const {
+    return rewriter.notifyMatchFailure(
+        op,
+        "Unimplemented matrix multiplication variant input parsing function");
+  }
+  LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter, Value &lhs,
+                              Value &rhs, Value &output) const {
 
-    // Aten matmul, mm and bmm call operand2 by different names.
-    Value rhs = adaptor.getOperands()[1];
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
     auto rhsTy = rhs.getType().cast<RankedTensorType>();
 
-    if (!lhsTy || !rhsTy)
-      return op.emitError("Only ranked tensor types supported in TOSA matmul");
-
     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
 
-    // Mm takes two 2D tensors
-    if (isa<AtenMmOp>(op)) {
-      assert(lhsRank == 2 && rhsRank == 2 &&
-             "aten.mm called but matrix rank != 2");
-    }
-
-    // Bmm takes two 2D tensors
-    if (isa<AtenBmmOp>(op)) {
-      assert(lhsRank == 3 && rhsRank == 3 &&
-             "aten.bmm called but matrix rank != 2");
-    }
-
     auto lhsShape = lhsTy.getShape();
     auto rhsShape = rhsTy.getShape();
 
@@ -1248,11 +1245,8 @@ class ConvertAtenMatMulOp : public OpConversionPattern<AtenOpT> {
     // Perform the reshape to output shape. This is always required unless both
     // inputs are rank=3, in which case the tosa.matmul output itself is
     // correctly shaped.
-    bool performOpReshape = !(lhsRank == 3 && rhsRank == 3);
-
-    auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
-                        ->convertType(op.getType())
-                        .template cast<RankedTensorType>();
+    bool performOpReshape =
+        !(lhsRank == 3 && rhsRank == 3 && lhsShape[0] == rhsShape[0]);
 
     if (performOpReshape) {
       // Since the output shape may be unknown, we construct it
@@ -1358,20 +1352,218 @@ class ConvertAtenMatMulOp : public OpConversionPattern<AtenOpT> {
 
         auto transposedOpType =
             RankedTensorType::get(transposedOpShape, outputElemTy);
-        auto transposedOp = rewriter.create<tosa::TransposeOp>(
-            op->getLoc(),
-            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-                transposedOpType),
-            reshapedOp.getResult(), transposedOpShapeConst.getValue());
+        output =
+            rewriter
+                .create<tosa::TransposeOp>(
+                    op->getLoc(),
+                    OpConversionPattern<AtenOpT>::getTypeConverter()
+                        ->convertType(transposedOpType),
+                    reshapedOp.getResult(), transposedOpShapeConst.getValue())
+                .getResult();
 
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, transposedOp);
       } else {
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, reshapedOp);
+        output = reshapedOp.getResult();
       }
     } else {
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, mmOpResult);
+      output = mmOpResult;
     }
 
+    return success();
+  }
+  // The default version just reads two inputs, computes output and returns it.
+  // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc.
+  virtual LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("Failed to read matmul inputs");
+
+    Value output;
+
+    if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
+      return op.emitError("Failed to perform matmul operation");
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        output);
+
+    return success();
+  }
+};
+
+// Legalizes the torch.matmul op for general n-dim matmul.
+template <typename AtenOpT>
+class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.other();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    return success();
+  }
+};
+
+// Implements handling of aten.mm and aten.bmm ops.
+template <typename AtenOpT>
+class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.mat2();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (isa<AtenMmOp>(op)) {
+      // Mm takes two 2D tensors.
+      if (lhsRank != 2 || rhsRank != 2)
+        return op.emitError("aten.mm called but matrix rank != 2");
+    } else if (isa<AtenBmmOp>(op)) {
+      // Bmm takes two 3D tensors.
+      if (lhsRank != 3 || rhsRank != 3)
+        return op.emitError("aten.bmm called but matrix rank != 3");
+    }
+
+    return success();
+  }
+};
+
+// Implements handling of aten.linear op.
+template <typename AtenOpT>
+class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+
+    lhs = adaptor.input();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.weight();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (lhsRank != 2 && lhsRank != 3)
+      return op.emitError("aten.Linear called but input rank not 2 or 3");
+    if (rhsRank != 2 && rhsRank != 3)
+      return op.emitError("aten.Linear called but weight rank not 2 or 3");
+
+    // Protection against crash due to unguarded code in TOSA->LinAlg.
+    if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape())
+      return op.emitError("aten.Linear needs statically shaped input");
+
+    return success();
+  }
+  // Override the default rewriter to perform RHS transpose and bias addition
+  // as well.
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("Failed to read matmul inputs");
+
+    // The aten.Linear op has a bias tensor that is added to the matmul output.
+    auto bias = adaptor.bias();
+    auto biasTy = bias.getType();
+
+    // TOSA does not mandate that elementwise op tensors need to be ranked.
+    if (!biasTy.template isa<Torch::NoneType>() &&
+        !biasTy.template isa<TensorType>())
+      return op.emitError("Only tensor types supported in GEMM to "
+                          "TOSA for bias tensor");
+
+    // RHS must have its last two dims transposed prior to matrix
+    // multiplication.
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto rhsRank = rhsTy.getRank();
+    auto rhsShape = rhsTy.getShape();
+    auto rhsElemTy = rhsTy.getElementType();
+
+    // Create a non-const shape array to transpose dims.
+    SmallVector<int64_t> transposedRhsShape;
+    for (auto &shape : rhsShape)
+      transposedRhsShape.push_back(shape);
+    SmallVector<int32_t> transposedRhsDims;
+    for (int32_t i = 0; i < rhsRank; i++)
+      transposedRhsDims.push_back(i);
+
+    // Swap the last two dims.
+    std::swap(transposedRhsShape[rhsRank - 1], transposedRhsShape[rhsRank - 2]);
+    std::swap(transposedRhsDims[rhsRank - 1], transposedRhsDims[rhsRank - 2]);
+
+    llvm::Optional<Value> transposedRhsShapeConst =
+        tosa::getConstTensor<int32_t>(
+            rewriter, op,
+            /*vec=*/transposedRhsDims,
+            /*shape=*/{static_cast<int32_t>(transposedRhsDims.size())});
+
+    auto transposedRhsType =
+        RankedTensorType::get(transposedRhsShape, rhsElemTy);
+    rhs = rewriter.create<tosa::TransposeOp>(
+        op->getLoc(),
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            transposedRhsType),
+        rhs, transposedRhsShapeConst.getValue());
+
+    Value matmulOutput;
+    if (failed(
+            this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput)))
+      return op.emitError("Failed to perform matmul operation");
+
+    Value matmulPlusBias = matmulOutput;
+    if (!biasTy.template isa<Torch::NoneType>()) {
+      // Bias addition broadcasts to the matmul output shape.
+      matmulPlusBias =
+          rewriter
+              .create<tosa::AddOp>(op->getLoc(), matmulOutput.getType(),
+                                   matmulOutput, bias)
+              .getResult();
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        matmulPlusBias);
+
     return success();
   }
 };
@@ -1544,10 +1736,21 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
     INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
-    INSERT_MATMUL_ATENOP_PATTERN(AtenMmOp);
-    INSERT_MATMUL_ATENOP_PATTERN(AtenBmmOp);
 #undef INSERT_MATMUL_ATENOP_PATTERN
 
+#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                       \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
+    INSERT_MM_ATENOP_PATTERN(AtenMmOp);
+    INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
+#undef INSERT_MM_ATENOP_PATTERN
+
+#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                   \
+  target.addIllegalOp<AtenOp>();                                               \
+  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
+    INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
+#undef INSERT_LINEAR_ATENOP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                          \
   target.addIllegalOp<AtenOp>();                                               \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
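Registration of any further matmul variant would follow the same macro idiom as the hunk above. A minimal sketch, continuing the hypothetical `ConvertAtenFooMatmulOp` from the note after the commit message (the macro name and `AtenFooOp` are likewise illustrative, not part of this commit):

```cpp
// Sketch only: wiring a hypothetical variant into ConvertTorchToTosa's
// pattern registration, mirroring the INSERT_*_ATENOP_PATTERN blocks above.
#define INSERT_FOO_ATENOP_PATTERN(AtenOp)                                      \
  target.addIllegalOp<AtenOp>();                                               \
  patterns.add<ConvertAtenFooMatmulOp<AtenOp>>(typeConverter, context);
    INSERT_FOO_ATENOP_PATTERN(AtenFooOp);
#undef INSERT_FOO_ATENOP_PATTERN
```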
