@@ -817,40 +817,37 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
   return success();
 }
 
-// Perform torch matmul, mm and bmm
+// Perform the basic n-dim matmul operation encompassing the handling of
+// broadcasting and dynamic shape propagation.
+// All PyTorch ops that leverage matrix multiplication derive from this class
+// and implement their own specialized input processing (e.g. transpose) and
+// output processing (e.g. GEMM or fully-connected bias handling).
 template <typename AtenOpT>
-class ConvertAtenMatMulOp : public OpConversionPattern<AtenOpT> {
+class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
 public:
   using OpConversionPattern<AtenOpT>::OpConversionPattern;
   using OpAdaptor = typename AtenOpT::Adaptor;
-  LogicalResult
-  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value lhs = adaptor.self();
-    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+  // Each variant must implement its own input parsing. We keep a separate
+  // input read function per variant because the first two operands are not
+  // necessarily the lhs and rhs in every variant.
+  virtual LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                         ConversionPatternRewriter &rewriter,
+                                         Value &lhs, Value &rhs) const {
+    return rewriter.notifyMatchFailure(
+        op,
+        "Unimplemented matrix multiplication variant input parsing function");
+  }
+  LogicalResult performMatmul(AtenOpT op, OpAdaptor adaptor,
+                              ConversionPatternRewriter &rewriter, Value &lhs,
+                              Value &rhs, Value &output) const {
 
-    // Aten matmul, mm and bmm call operand2 by different names.
-    Value rhs = adaptor.getOperands()[1];
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
     auto rhsTy = rhs.getType().cast<RankedTensorType>();
 
-    if (!lhsTy || !rhsTy)
-      return op.emitError("Only ranked tensor types supported in TOSA matmul");
-
     auto lhsRank = lhsTy.getRank();
     auto rhsRank = rhsTy.getRank();
 
-    // Mm takes two 2D tensors
-    if (isa<AtenMmOp>(op)) {
-      assert(lhsRank == 2 && rhsRank == 2 &&
-             "aten.mm called but matrix rank != 2");
-    }
-
-    // Bmm takes two 2D tensors
-    if (isa<AtenBmmOp>(op)) {
-      assert(lhsRank == 3 && rhsRank == 3 &&
-             "aten.bmm called but matrix rank != 2");
-    }
-
     auto lhsShape = lhsTy.getShape();
     auto rhsShape = rhsTy.getShape();
@@ -1248,11 +1245,8 @@ class ConvertAtenMatMulOp : public OpConversionPattern<AtenOpT> {
     // Perform the reshape to output shape. This is always required unless both
     // inputs are rank=3, in which case the tosa.matmul output itself is
     // correctly shaped.
-    bool performOpReshape = !(lhsRank == 3 && rhsRank == 3);
-
-    auto outputTy = OpConversionPattern<AtenOpT>::getTypeConverter()
-                        ->convertType(op.getType())
-                        .template cast<RankedTensorType>();
+    bool performOpReshape =
+        !(lhsRank == 3 && rhsRank == 3 && lhsShape[0] == rhsShape[0]);
 
     if (performOpReshape) {
       // Since the output shape may be unknown, we construct it
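To make the tightened predicate concrete: tosa.matmul consumes rank-3 operands [N, H, C] x [N, C, W] and produces [N, H, W], so the only case needing no output reshape is two rank-3 operands sharing a batch size. A paraphrase with worked shape instances (the helper name and shapes are assumed for illustration, not from the patch):

```cpp
#include <cstdint>

// lhs [3, 4, 5] x rhs [3, 5, 6]  -> both rank 3, batch 3 == 3, so
//                                   tosa.matmul already yields [3, 4, 6].
// lhs [2, 3, 4, 5] x rhs [5, 6]  -> operands are first reshaped to rank 3
//                                   for tosa.matmul, so its result must be
//                                   reshaped back to [2, 3, 4, 6].
bool matmulNeedsOutputReshape(int64_t lhsRank, int64_t rhsRank,
                              int64_t lhsBatch, int64_t rhsBatch) {
  return !(lhsRank == 3 && rhsRank == 3 && lhsBatch == rhsBatch);
}
```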
@@ -1358,20 +1352,218 @@ class ConvertAtenMatMulOp : public OpConversionPattern<AtenOpT> {
 
         auto transposedOpType =
             RankedTensorType::get(transposedOpShape, outputElemTy);
-        auto transposedOp = rewriter.create<tosa::TransposeOp>(
-            op->getLoc(),
-            OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-                transposedOpType),
-            reshapedOp.getResult(), transposedOpShapeConst.getValue());
+        output =
+            rewriter
+                .create<tosa::TransposeOp>(
+                    op->getLoc(),
+                    OpConversionPattern<AtenOpT>::getTypeConverter()
+                        ->convertType(transposedOpType),
+                    reshapedOp.getResult(), transposedOpShapeConst.getValue())
+                .getResult();
 
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, transposedOp);
       } else {
-        rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, reshapedOp);
+        output = reshapedOp.getResult();
       }
     } else {
-      rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outputTy, mmOpResult);
+      output = mmOpResult;
     }
 
+    return success();
+  }
+  // The default version just reads the two inputs, computes the output and
+  // returns it. Other versions may add a bias, apply GEMM-style alpha/beta
+  // scaling, etc.
+  virtual LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const {
+
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("Failed to read matmul inputs");
+
+    Value output;
+
+    if (failed(performMatmul(op, adaptor, rewriter, lhs, rhs, output)))
+      return op.emitError("Failed to perform matmul operation");
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        output);
+
+    return success();
+  }
+};
+
+// Legalizes the torch.matmul op for general n-dim matmul.
+template <typename AtenOpT>
+class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.other();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    return success();
+  }
+};
+
+// Implements handling of the aten.mm and aten.bmm ops.
+template <typename AtenOpT>
+class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+
+    lhs = adaptor.self();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.mat2();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (isa<AtenMmOp>(op)) {
+      // Mm takes two 2D tensors.
+      if (lhsRank != 2 || rhsRank != 2)
+        return op.emitError("aten.mm called but matrix rank != 2");
+    } else if (isa<AtenBmmOp>(op)) {
+      // Bmm takes two 3D tensors.
+      if (lhsRank != 3 || rhsRank != 3)
+        return op.emitError("aten.bmm called but matrix rank != 3");
+    }
+
+    return success();
+  }
+};
+
+// Implements handling of the aten.linear op.
+template <typename AtenOpT>
+class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
+public:
+  using ConvertAtenMatmulBaseOp<AtenOpT>::ConvertAtenMatmulBaseOp;
+  using OpAdaptor = typename AtenOpT::Adaptor;
+  LogicalResult readMatMulInputs(AtenOpT op, OpAdaptor adaptor,
+                                 ConversionPatternRewriter &rewriter,
+                                 Value &lhs, Value &rhs) const override {
+
+    lhs = adaptor.input();
+    auto lhsTy = lhs.getType().cast<RankedTensorType>();
+
+    rhs = adaptor.weight();
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+
+    if (!lhsTy || !rhsTy)
+      return op.emitError("Only ranked tensor types supported in TOSA matmul");
+
+    auto lhsRank = lhsTy.getRank();
+    auto rhsRank = rhsTy.getRank();
+
+    if (lhsRank != 2 && lhsRank != 3)
+      return op.emitError("aten.Linear called but input rank not 2 or 3");
+    if (rhsRank != 2 && rhsRank != 3)
+      return op.emitError("aten.Linear called but weight rank not 2 or 3");
+
+    // Protection against crash due to unguarded code in TOSA->LinAlg.
+    if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape())
+      return op.emitError("aten.Linear needs statically shaped input");
+
+    return success();
+  }
+  // Override the default rewriter to perform the RHS transpose and bias
+  // addition as well.
+  LogicalResult
+  matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Value lhs, rhs;
+
+    if (failed(readMatMulInputs(op, adaptor, rewriter, lhs, rhs)))
+      return op.emitError("Failed to read matmul inputs");
+
+    // The aten.Linear op has a bias tensor that is added to the matmul output.
+    auto bias = adaptor.bias();
+    auto biasTy = bias.getType();
+
+    // TOSA does not mandate that elementwise op tensors be ranked.
+    if (!biasTy.template isa<Torch::NoneType>() &&
+        !biasTy.template isa<TensorType>())
+      return op.emitError("Only tensor types supported in GEMM to "
+                          "TOSA for bias tensor");
+
+    // RHS must have its last two dims transposed prior to matrix
+    // multiplication.
+    auto rhsTy = rhs.getType().cast<RankedTensorType>();
+    auto rhsRank = rhsTy.getRank();
+    auto rhsShape = rhsTy.getShape();
+    auto rhsElemTy = rhsTy.getElementType();
+
+    // Create a mutable copy of the shape and dim order so the last two dims
+    // can be swapped.
+    SmallVector<int64_t> transposedRhsShape;
+    for (auto &shape : rhsShape)
+      transposedRhsShape.push_back(shape);
+    SmallVector<int32_t> transposedRhsDims;
+    for (int32_t i = 0; i < rhsRank; i++)
+      transposedRhsDims.push_back(i);
+
+    // Swap the last two dims.
+    std::swap(transposedRhsShape[rhsRank - 1], transposedRhsShape[rhsRank - 2]);
+    std::swap(transposedRhsDims[rhsRank - 1], transposedRhsDims[rhsRank - 2]);
+
+    llvm::Optional<Value> transposedRhsShapeConst =
+        tosa::getConstTensor<int32_t>(
+            rewriter, op,
+            /*vec=*/transposedRhsDims,
+            /*shape=*/{static_cast<int32_t>(transposedRhsDims.size())});
+
+    auto transposedRhsType =
+        RankedTensorType::get(transposedRhsShape, rhsElemTy);
+    rhs = rewriter.create<tosa::TransposeOp>(
+        op->getLoc(),
+        OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
+            transposedRhsType),
+        rhs, transposedRhsShapeConst.getValue());
+
+    Value matmulOutput;
+    if (failed(
+            this->performMatmul(op, adaptor, rewriter, lhs, rhs, matmulOutput)))
+      return op.emitError("Failed to perform matmul operation");
+
+    Value matmulPlusBias = matmulOutput;
+    if (!biasTy.template isa<Torch::NoneType>()) {
+      // Bias addition broadcasts to the matmul output shape.
+      matmulPlusBias =
+          rewriter
+              .create<tosa::AddOp>(op->getLoc(), matmulOutput.getType(),
+                                   matmulOutput, bias)
+              .getResult();
+    }
+
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(
+        op,
+        OpConversionPattern<AtenOpT>::getTypeConverter()
+            ->convertType(op.getType())
+            .template cast<RankedTensorType>(),
+        matmulPlusBias);
+
     return success();
   }
 };
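The semantics ConvertAtenLinearOp lowers is y = x·wᵀ + b with the bias broadcast over the leading dims, which is why the RHS transpose precedes the shared matmul and the tosa.AddOp follows it. A plain-C++ reference for the 2-D case (function name and container types are illustrative only, not part of the patch):

```cpp
#include <cstddef>
#include <vector>

// y = x @ w^T + b; x is [batch, in], w is [out, in], b is [out].
std::vector<std::vector<float>>
linearRef(const std::vector<std::vector<float>> &x,
          const std::vector<std::vector<float>> &w,
          const std::vector<float> &b) {
  std::vector<std::vector<float>> y(x.size(),
                                    std::vector<float>(w.size(), 0.0f));
  for (std::size_t i = 0; i < x.size(); ++i)
    for (std::size_t o = 0; o < w.size(); ++o) {
      for (std::size_t k = 0; k < w[o].size(); ++k)
        y[i][o] += x[i][k] * w[o][k]; // row of x dotted with row of w (= w^T)
      y[i][o] += b[o];                // bias broadcast across the batch dim
    }
  return y;
}
```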
@@ -1544,10 +1736,21 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
   target.addIllegalOp<AtenOp>();                                             \
   patterns.add<ConvertAtenMatMulOp<AtenOp>>(typeConverter, context);
     INSERT_MATMUL_ATENOP_PATTERN(AtenMatmulOp);
-    INSERT_MATMUL_ATENOP_PATTERN(AtenMmOp);
-    INSERT_MATMUL_ATENOP_PATTERN(AtenBmmOp);
 #undef INSERT_MATMUL_ATENOP_PATTERN
 
+#define INSERT_MM_ATENOP_PATTERN(AtenOp)                                     \
+  target.addIllegalOp<AtenOp>();                                             \
+  patterns.add<ConvertAtenMmOp<AtenOp>>(typeConverter, context);
+    INSERT_MM_ATENOP_PATTERN(AtenMmOp);
+    INSERT_MM_ATENOP_PATTERN(AtenBmmOp);
+#undef INSERT_MM_ATENOP_PATTERN
+
+#define INSERT_LINEAR_ATENOP_PATTERN(AtenOp)                                 \
+  target.addIllegalOp<AtenOp>();                                             \
+  patterns.add<ConvertAtenLinearOp<AtenOp>>(typeConverter, context);
+    INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
+#undef INSERT_LINEAR_ATENOP_PATTERN
+
 #define INSERT_ATENOP_PATTERN(AtenOp)                                        \
   target.addIllegalOp<AtenOp>();                                             \
   patterns.add<ConvertAtenOp<AtenOp>>(typeConverter, context);
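As a reading aid: each registration macro above marks the op illegal for the conversion target and registers the corresponding pattern. For example, INSERT_MM_ATENOP_PATTERN(AtenMmOp) expands to:

```cpp
target.addIllegalOp<AtenMmOp>();
patterns.add<ConvertAtenMmOp<AtenMmOp>>(typeConverter, context);
```

Splitting the original single matmul macro into three lets each op family (torch.matmul, aten.mm/aten.bmm, aten.linear) bind to its specialized pattern class rather than the generic one.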