@@ -1058,8 +1058,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
 
     // Step: generate the common dim/shape information
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
-      bool isDynamicDim =
-          ShapedType::isDynamic(lhsBroadcastedShape[dim]);
+      bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
       if (isDynamicDim ||
           lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
         commonValue *= lhsBroadcastedShape[dim];
@@ -1070,8 +1069,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
     // Step: generate the LHS squeezed dim/shape information.
     bool hasDynamicDims = false;
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
-      bool isDynamicDim =
-          ShapedType::isDynamic(lhsBroadcastedShape[dim]);
+      bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
       hasDynamicDims |= isDynamicDim;
       if (!isDynamicDim &&
           lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
@@ -1155,8 +1153,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
     // finally all the rhs_squeeze dims
     hasDynamicDims = false;
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
-      bool isDynamicDim =
-          ShapedType::isDynamic(rhsBroadcastedShape[dim]);
+      bool isDynamicDim = ShapedType::isDynamic(rhsBroadcastedShape[dim]);
       hasDynamicDims |= isDynamicDim;
       if (!isDynamicDim &&
           rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) {
@@ -1374,7 +1371,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
   // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc.
   virtual LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const {
+                  ConversionPatternRewriter &rewriter) const override {
 
     Value lhs, rhs;
 
@@ -1607,6 +1604,182 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenConv2dOp>::matchAndRewrite(
+    AtenConv2dOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto input = adaptor.input();
+  auto weight = adaptor.weight();
+
+  auto inputTy = input.getType().template cast<RankedTensorType>();
+  auto weightTy = weight.getType().template cast<RankedTensorType>();
+  auto outputTy = getTypeConverter()
+                      ->convertType(op.getType())
+                      .template cast<RankedTensorType>();
+
+  if (!inputTy || !weightTy || !outputTy)
+    return op.emitError(
+        "Input, weight and output to Conv2d must be ranked tensors");
+
+  auto inputElemTy = inputTy.getElementType();
+  auto weightElemTy = weightTy.getElementType();
+  auto inputShape = inputTy.getShape();
+  auto weightShape = weightTy.getShape();
+
+  // Bias is optional. TOSA mandates a zero tensor here, so construct one if
+  // required.
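+  // E.g. a float conv with 8 output channels (weightShape[0] == 8) and no
+  // bias gets a constant zero bias of type tensor<8xf32>.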
+  auto bias = adaptor.bias();
+  if (adaptor.bias().getType().template isa<Torch::NoneType>()) {
+    // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
+    // accumulator) are 48-bit and not 32-bit, and require the use of APInt to
+    // define a 48-bit int.
+    if (inputElemTy.isa<quant::QuantizedType>()) {
+      SmallVector<int32_t> zeroVec(weightShape[0], 0);
+      bias = tosa::getConstTensor<int32_t>(
+                 rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
+                 .getValue();
+    } else {
+      SmallVector<float> zeroVec(weightShape[0], 0);
+      bias = tosa::getConstTensor<float>(rewriter, op, zeroVec,
+                                         {static_cast<int32_t>(weightShape[0])})
+                 .getValue();
+    }
+  } else {
+    if (!bias.getType().cast<RankedTensorType>())
+      return op.emitError("Bias provided but not a ranked tensor");
+  }
+  auto biasElemTy = inputElemTy.template isa<mlir::FloatType>()
+                        ? inputElemTy
+                        : rewriter.getI32Type();
+
+  SmallVector<int64_t, 2> stride;
+  if (!matchPattern(adaptor.stride(), m_TorchConstantIntList(stride)))
+    return rewriter.notifyMatchFailure(op, "non-const stride list unsupported");
+
+  SmallVector<int64_t, 2> padding_2d;
+  if (!matchPattern(adaptor.padding(), m_TorchConstantIntList(padding_2d)))
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const padding list unsupported");
+  // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}.
+  // The Torch OFM computation uses 2*pad in each spatial direction, implying
+  // the same t=b and l=r values for TOSA.
+  SmallVector<int64_t> padding(
+      {padding_2d[0], padding_2d[0], padding_2d[1], padding_2d[1]});
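+  // E.g. Torch padding {1, 2} expands to TOSA padding {1, 1, 2, 2}.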
+
+  SmallVector<int64_t, 2> dilation;
+  if (!matchPattern(adaptor.dilation(), m_TorchConstantIntList(dilation)))
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const dilation list unsupported");
+
+  // TOSA works in NHWC and takes OHWI weights. Perform the necessary transpose.
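+  // E.g. perm {0, 2, 3, 1} maps an NCHW input {1, 3, 32, 32} to NHWC
+  // {1, 32, 32, 3}.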
+  llvm::Optional<Value> nchwToNhwcTransposeConst =
+      tosa::getConstTensor<int32_t>(rewriter, op,
+                                    /*vec=*/{0, 2, 3, 1},
+                                    /*shape=*/{static_cast<int32_t>(4)});
+  SmallVector<int64_t> transposedInputShape(
+      {inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
+  auto transposedInputType =
+      RankedTensorType::get(transposedInputShape, inputElemTy);
+  auto transposedInput =
+      rewriter
+          .create<tosa::TransposeOp>(
+              op->getLoc(),
+              getTypeConverter()->convertType(transposedInputType), input,
+              nchwToNhwcTransposeConst.getValue())
+          .getResult();
+
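+  // Torch conv2d weights are OIHW; the same {0, 2, 3, 1} permutation yields
+  // the OHWI layout TOSA expects.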
+  SmallVector<int64_t> transposedWeightShape(
+      {weightShape[0], weightShape[2], weightShape[3], weightShape[1]});
+  auto transposedWeightType =
+      RankedTensorType::get(transposedWeightShape, weightElemTy);
+  auto transposedWeight =
+      rewriter
+          .create<tosa::TransposeOp>(
+              op->getLoc(),
+              getTypeConverter()->convertType(transposedWeightType), weight,
+              nchwToNhwcTransposeConst.getValue())
+          .getResult();
+
+  int64_t outputHDim, outputWDim;
+  if (inputTy.hasStaticShape()) {
+    outputHDim = (transposedInputShape[1] + padding[0] + padding[1] -
+                  dilation[0] * (transposedWeightShape[1] - 1) - 1) /
+                     stride[0] +
+                 1;
+    outputWDim = (transposedInputShape[2] + padding[2] + padding[3] -
+                  dilation[1] * (transposedWeightShape[2] - 1) - 1) /
+                     stride[1] +
+                 1;
+  } else {
+    outputHDim = ShapedType::kDynamicSize;
+    outputWDim = ShapedType::kDynamicSize;
+  }
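+  // Worked example: H == 32, padding {1, 1}, dilation 1, kernel 3, stride 1
+  // gives (32 + 1 + 1 - 1 * (3 - 1) - 1) / 1 + 1 == 32, matching Torch's OFM
+  // size for a "same" 3x3 convolution.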
+
+  // Output shape is NHWC, to be transposed back to NCHW. Output elemTy for
+  // quantized input is i32, which gets rescaled down to quantized output range.
+  SmallVector<int64_t> outputShape = {transposedInputShape[0], outputHDim,
+                                      outputWDim, transposedWeightShape[0]};
+  auto convOpTy = RankedTensorType::get(outputShape, biasElemTy);
+
+  Value convOpResult =
+      rewriter
+          .create<tosa::Conv2DOp>(op->getLoc(),
+                                  getTypeConverter()->convertType(convOpTy),
+                                  transposedInput, transposedWeight, bias,
+                                  rewriter.getI64ArrayAttr(padding),
+                                  rewriter.getI64ArrayAttr(stride),
+                                  rewriter.getI64ArrayAttr(dilation))
+          .getResult();
+
+  llvm::Optional<Value> nhwcToNchwTransposeConst =
+      tosa::getConstTensor<int32_t>(rewriter, op,
+                                    /*vec=*/{0, 3, 1, 2},
+                                    /*shape=*/{static_cast<int32_t>(4)});
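+  // Perm {0, 3, 1, 2} is the inverse of {0, 2, 3, 1}, restoring NCHW.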
+  SmallVector<int64_t> transposedOutputShape(
+      {outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
+  auto transposedOutputType =
+      RankedTensorType::get(transposedOutputShape, biasElemTy);
+  auto transposedOutput =
+      rewriter
+          .create<tosa::TransposeOp>(
+              op->getLoc(),
+              getTypeConverter()->convertType(transposedOutputType),
+              convOpResult, nhwcToNchwTransposeConst.getValue())
+          .getResult();
+
+  Value rescaledResult = transposedOutput;
+  if (inputElemTy.template isa<quant::QuantizedType>()) {
+    rescaledResult = tosa::buildRescaleOpConvOutput(
+        rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
+  }
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()), rescaledResult);
+
+  return success();
+}
+
+// Torch constants are converted to tosa.const.
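+// Only the type changes here, e.g. !torch.vtensor<[2,2],f32> becomes
+// tensor<2x2xf32>; the dense elements attribute carries over as-is.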
+template <>
+LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
+    ValueTensorLiteralOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto outputTy = getTypeConverter()
+                      ->convertType(op.getType())
+                      .template cast<RankedTensorType>();
+  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, adaptor.value());
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -1760,6 +1933,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
     INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
     INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
+    INSERT_ATENOP_PATTERN(AtenConv2dOp);
+    INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,