Commit eb06d21

[tosa] Implement conv2d support (#541)
Signed-off-by: Suraj Sudhir <[email protected]>
1 parent 3fd9b77 commit eb06d21

File tree: 4 files changed, 270 additions and 7 deletions

e2e_testing/torchscript/xfail_sets.py
Lines changed: 1 addition & 0 deletions

@@ -81,4 +81,5 @@
     "ElementwiseCeilModule_basic",
     "ElementwiseReciprocalModule_basic",
     "TypePromotionAlphaWiderModule_basic",
+    "Conv2dWithPaddingDilationStrideStaticModule_basic",
 }

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
Lines changed: 8 additions & 0 deletions

@@ -33,6 +33,14 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
                           Value input_val, double input_scale,
                           int64_t input_zp);
 
+// Creates a TOSA rescale op based on conv2d parameters.
+Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
+                               Value conv_val, ShapedType input_type,
+                               ShapedType weight_type, ShapedType output_type);
+
+// Check if scale32 mode is used for the given output_element_type.
+bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
+
 // Create a 32-bit float constant operator from a float
 Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                   float val);
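
For orientation, the two new helpers compose as follows when lowering a quantized conv: the conv accumulates in i32, and buildRescaleOpConvOutput rescales that accumulator into the quantized output type. A minimal sketch of a call site (the wrapper name and placeholder arguments here are illustrative, not part of this commit):

// Hypothetical wrapper: rescale a quantized conv2d accumulator back into the
// quantized output type using the utilities declared above.
Value rescaleQuantizedConv(PatternRewriter &rewriter, Operation *op,
                           Value conv_result, ShapedType input_ty,
                           ShapedType weight_ty, ShapedType output_ty) {
  // isScale32 selects 32-bit multipliers for 8-bit storage types; the rescale
  // builder consults it internally, so the call here is only illustrative.
  auto out_qtype =
      output_ty.getElementType().cast<mlir::quant::UniformQuantizedType>();
  (void)tosa::isScale32(out_qtype);
  return tosa::buildRescaleOpConvOutput(rewriter, op, conv_result, input_ty,
                                        weight_ty, output_ty);
}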

lib/Conversion/TorchToTosa/TorchToTosa.cpp
Lines changed: 175 additions & 7 deletions

@@ -1058,8 +1058,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
 
     // Step: generate the common dim/shape information
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
-      bool isDynamicDim =
-          ShapedType::isDynamic(lhsBroadcastedShape[dim]);
+      bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
       if (isDynamicDim ||
           lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) {
         commonValue *= lhsBroadcastedShape[dim];
@@ -1070,8 +1069,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
     // Step: generate the LHS squeezed dim/shape information.
     bool hasDynamicDims = false;
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
-      bool isDynamicDim =
-          ShapedType::isDynamic(lhsBroadcastedShape[dim]);
+      bool isDynamicDim = ShapedType::isDynamic(lhsBroadcastedShape[dim]);
       hasDynamicDims |= isDynamicDim;
       if (!isDynamicDim &&
           lhsBroadcastedShape[dim] != rhsBroadcastedShape[dim]) {
@@ -1155,8 +1153,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
     // finally all the rhs_squeeze dims
     hasDynamicDims = false;
     for (uint32_t dim = 0; dim < maxInputRank - 2; dim++) {
-      bool isDynamicDim =
-          ShapedType::isDynamic(rhsBroadcastedShape[dim]);
+      bool isDynamicDim = ShapedType::isDynamic(rhsBroadcastedShape[dim]);
       hasDynamicDims |= isDynamicDim;
       if (!isDynamicDim &&
           rhsBroadcastedShape[dim] != lhsBroadcastedShape[dim]) {
@@ -1374,7 +1371,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern<AtenOpT> {
   // Other versions may add a bias, apply GEMM-style alpha/beta scaling etc.
   virtual LogicalResult
   matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const {
+                  ConversionPatternRewriter &rewriter) const override {
 
     Value lhs, rhs;
 
@@ -1607,6 +1604,175 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenConv2dOp>::matchAndRewrite(
+    AtenConv2dOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto input = adaptor.input();
+  auto weight = adaptor.weight();
+
+  auto inputTy = input.getType().template cast<RankedTensorType>();
+  auto weightTy = weight.getType().template cast<RankedTensorType>();
+  auto outputTy = getTypeConverter()
+                      ->convertType(op.getType())
+                      .template cast<RankedTensorType>();
+
+  if (!inputTy || !weightTy || !outputTy)
+    return op.emitError(
+        "Input, weight and output to Conv2d must be ranked tensors");
+
+  auto inputElemTy = inputTy.getElementType();
+  auto weightElemTy = weightTy.getElementType();
+  auto inputShape = inputTy.getShape();
+  auto weightShape = weightTy.getShape();
+
+  // Bias is optional. TOSA mandates a zero tensor here, so construct one if
+  // required.
+  auto bias = adaptor.bias();
+  if (adaptor.bias().getType().template isa<Torch::NoneType>()) {
+    // TBD: This is only valid for quantized 8-bit. For 16-bit, the bias (and
+    // accumulator) are 48-bit and not 32-bit, which requires the use of APInt
+    // to define a 48-bit int.
+    if (inputElemTy.isa<quant::QuantizedType>()) {
+      SmallVector<int32_t> zeroVec(weightShape[0], 0);
+      bias = tosa::getConstTensor<int32_t>(
+                 rewriter, op, zeroVec, {static_cast<int32_t>(weightShape[0])})
+                 .getValue();
+    } else {
+      SmallVector<float> zeroVec(weightShape[0], 0);
+      bias = tosa::getConstTensor<float>(rewriter, op, zeroVec,
+                                         {static_cast<int32_t>(weightShape[0])})
+                 .getValue();
+    }
+  } else {
+    if (!bias.getType().cast<RankedTensorType>())
+      return op.emitError("Bias provided but not a ranked tensor");
+  }
+  auto biasElemTy = inputElemTy.template isa<mlir::FloatType>()
+                        ? inputElemTy
+                        : rewriter.getI32Type();
+
+  SmallVector<int64_t, 2> stride;
+  if (!matchPattern(adaptor.stride(), m_TorchConstantIntList(stride)))
+    return rewriter.notifyMatchFailure(op, "non-const stride list unsupported");
+
+  SmallVector<int64_t, 2> padding_2d;
+  if (!matchPattern(adaptor.padding(), m_TorchConstantIntList(padding_2d)))
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const padding list unsupported");
+  // TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}.
+  // The Torch OFM computation uses 2*pad in each spatial direction, implying
+  // the same t=b and l=r values for TOSA.
+  SmallVector<int64_t> padding(
+      {padding_2d[0], padding_2d[0], padding_2d[1], padding_2d[1]});
+
+  SmallVector<int64_t, 2> dilation;
+  if (!matchPattern(adaptor.dilation(), m_TorchConstantIntList(dilation)))
+    return rewriter.notifyMatchFailure(op,
+                                       "non-const dilation list unsupported");
+
+  // TOSA works in NHWC and takes OHWI weights. Perform the necessary
+  // transposes.
+  llvm::Optional<Value> nchwToNhwcTransposeConst =
+      tosa::getConstTensor<int32_t>(rewriter, op,
+                                    /*vec=*/{0, 2, 3, 1},
+                                    /*shape=*/{static_cast<int32_t>(4)});
+  SmallVector<int64_t> transposedInputShape(
+      {inputShape[0], inputShape[2], inputShape[3], inputShape[1]});
+  auto transposedInputType =
+      RankedTensorType::get(transposedInputShape, inputElemTy);
+  auto transposedInput =
+      rewriter
+          .create<tosa::TransposeOp>(
+              op->getLoc(),
+              getTypeConverter()->convertType(transposedInputType), input,
+              nchwToNhwcTransposeConst.getValue())
+          .getResult();
+
+  SmallVector<int64_t> transposedWeightShape(
+      {weightShape[0], weightShape[2], weightShape[3], weightShape[1]});
+  auto transposedWeightType =
+      RankedTensorType::get(transposedWeightShape, weightElemTy);
+  auto transposedWeight =
+      rewriter
+          .create<tosa::TransposeOp>(
+              op->getLoc(),
+              getTypeConverter()->convertType(transposedWeightType), weight,
+              nchwToNhwcTransposeConst.getValue())
+          .getResult();
+
+  int64_t outputHDim, outputWDim;
+  if (inputTy.hasStaticShape()) {
+    outputHDim = (transposedInputShape[1] + padding[0] + padding[1] -
+                  dilation[0] * (transposedWeightShape[1] - 1) - 1) /
+                     stride[0] +
+                 1;
+    outputWDim = (transposedInputShape[2] + padding[2] + padding[3] -
+                  dilation[1] * (transposedWeightShape[2] - 1) - 1) /
+                     stride[1] +
+                 1;
+  } else {
+    outputHDim = ShapedType::kDynamicSize;
+    outputWDim = ShapedType::kDynamicSize;
+  }
+
+  // Output shape is NHWC, to be transposed back to NCHW. Output elemTy for
+  // quantized input is i32, which gets rescaled down to the quantized output
+  // range.
+  SmallVector<int64_t> outputShape = {transposedInputShape[0], outputHDim,
+                                      outputWDim, transposedWeightShape[0]};
+  auto convOpTy = RankedTensorType::get(outputShape, biasElemTy);
+
+  Value convOpResult =
+      rewriter
+          .create<tosa::Conv2DOp>(op->getLoc(),
+                                  getTypeConverter()->convertType(convOpTy),
+                                  transposedInput, transposedWeight, bias,
+                                  rewriter.getI64ArrayAttr(padding),
+                                  rewriter.getI64ArrayAttr(stride),
+                                  rewriter.getI64ArrayAttr(dilation))
+          .getResult();
+
+  llvm::Optional<Value> nhwcToNchwTransposeConst =
+      tosa::getConstTensor<int32_t>(rewriter, op,
+                                    /*vec=*/{0, 3, 1, 2},
+                                    /*shape=*/{static_cast<int32_t>(4)});
+  SmallVector<int64_t> transposedOutputShape(
+      {outputShape[0], outputShape[3], outputShape[1], outputShape[2]});
+  auto transposedOutputType =
+      RankedTensorType::get(transposedOutputShape, biasElemTy);
+  auto transposedOutput =
+      rewriter
+          .create<tosa::TransposeOp>(
+              op->getLoc(),
+              getTypeConverter()->convertType(transposedOutputType),
+              convOpResult, nhwcToNchwTransposeConst.getValue())
+          .getResult();
+
+  Value rescaledResult = transposedOutput;
+  if (inputElemTy.template isa<quant::QuantizedType>()) {
+    rescaledResult = tosa::buildRescaleOpConvOutput(
+        rewriter, op, transposedOutput, inputTy, weightTy, outputTy);
+  }
+
+  rewriter.replaceOpWithNewOp<tensor::CastOp>(
+      op, getTypeConverter()->convertType(op.getType()), rescaledResult);
+
+  return success();
+}
+
+// Torch constants are converted to tosa.const.
+template <>
+LogicalResult ConvertAtenOp<ValueTensorLiteralOp>::matchAndRewrite(
+    ValueTensorLiteralOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto outputTy = getTypeConverter()
+                      ->convertType(op.getType())
+                      .template cast<RankedTensorType>();
+  rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputTy, adaptor.value());
+
+  return success();
+}
+
 } // namespace
 
 // -----------------------------------------------------------------------------
@@ -1760,6 +1926,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
     INSERT_ATENOP_PATTERN(AtenArgmaxOp);
     INSERT_ATENOP_PATTERN(AtenPowTensorScalarOp);
     INSERT_ATENOP_PATTERN(AtenRsubScalarOp);
+    INSERT_ATENOP_PATTERN(AtenConv2dOp);
+    INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
 #undef INSERT_ATENOP_PATTERN
 
     if (failed(applyPartialConversion(getOperation(), target,
lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
Lines changed: 86 additions & 0 deletions

@@ -52,6 +52,92 @@ Value buildRescaleToInt32(PatternRewriter &rewriter, Operation *op,
                            input_zp, 0, false, true);
 }
 
+// Creates a TOSA rescale op based on conv2d parameters.
+Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op,
+                               Value conv_val, ShapedType input_type,
+                               ShapedType weight_type,
+                               ShapedType output_type) {
+  auto input_qtype =
+      input_type.getElementType().dyn_cast<mlir::quant::UniformQuantizedType>();
+  auto output_qtype = output_type.getElementType()
+                          .dyn_cast<mlir::quant::UniformQuantizedType>();
+
+  double input_scale = input_qtype.getScale();
+
+  int64_t output_zp = output_qtype.getZeroPoint();
+  double output_scale = output_qtype.getScale();
+
+  bool scale32 = isScale32(output_qtype);
+  int32_t scale_width = scale32 ? 32 : 16;
+
+  if (auto weight_per_tensor_qtype =
+          weight_type.getElementType()
+              .dyn_cast<mlir::quant::UniformQuantizedType>()) {
+    // Per-tensor quantization
+    double weight_scale = weight_per_tensor_qtype.getScale();
+
+    int32_t multiplier;
+    int32_t shift;
+
+    double op_tensor_scale = (input_scale * weight_scale) / output_scale;
+
+    computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width);
+
+    auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
+        rewriter, op->getLoc(), output_type, conv_val,
+        rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
+        rewriter.getI32ArrayAttr({multiplier}),
+        rewriter.getI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32),
+        rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));
+
+    return rescale_op.getResult();
+
+  } else if (auto weight_per_channel_qtype =
+                 weight_type.getElementType()
+                     .dyn_cast<mlir::quant::UniformQuantizedPerAxisType>()) {
+    // Per-channel quantization
+    SmallVector<int32_t> multiplier_arr;
+    SmallVector<int32_t> shift_arr;
+
+    SmallVector<double> weight_scale_arr(
+        weight_per_channel_qtype.getScales().begin(),
+        weight_per_channel_qtype.getScales().end());
+
+    int64_t output_zp = output_qtype.getZeroPoint();
+    double output_scale = output_qtype.getScale();
+
+    for (double weight_scale : weight_scale_arr) {
+      int32_t multiplier;
+      int32_t shift;
+
+      double op_channel_scale = (input_scale * weight_scale) / output_scale;
+
+      computeMultiplierAndShift(op_channel_scale, multiplier, shift,
+                                scale_width);
+
+      multiplier_arr.push_back(multiplier);
+      shift_arr.push_back(shift);
+    }
+
+    auto rescale_op = CreateOpAndInfer<tosa::RescaleOp>(
+        rewriter, op->getLoc(), output_type, conv_val,
+        rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp),
+        rewriter.getI32ArrayAttr(multiplier_arr),
+        rewriter.getI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32),
+        rewriter.getBoolAttr(true), rewriter.getBoolAttr(true));
+
+    return rescale_op.getResult();
+
+  } else {
+    op->emitOpError("buildRescaleOpConvOutput: unknown weight quantized type");
+    return nullptr;
+  }
+}
+
+// Check if scale32 mode is used for the given output_element_type.
+bool isScale32(mlir::quant::UniformQuantizedType output_element_type) {
+  return (output_element_type.getStorageTypeIntegralWidth() == 8);
+}
+
 // Create a 32-bit float constant operator from a float
 Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
                                   float val) {
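
The rescale built here folds the three quantization scales into a single fixed-point multiply: the real scale (input_scale * weight_scale) / output_scale is approximated as multiplier * 2^-shift, with a 32-bit multiplier in scale32 mode. A minimal sketch of such a decomposition (illustrative only; the in-tree computeMultiplierAndShift may differ in rounding details):

// Decompose a positive real scale into (multiplier, shift) such that
// scale ≈ multiplier * 2^-shift, with the multiplier held in int32.
#include <cmath>
#include <cstdint>

void decomposeScale32(double scale, int32_t &multiplier, int32_t &shift) {
  int exp;
  // scale = mantissa * 2^exp with mantissa in [0.5, 1).
  double mantissa = std::frexp(scale, &exp);
  // Round the mantissa into 31 fractional bits.
  int64_t m = static_cast<int64_t>(std::round(mantissa * (1ll << 31)));
  if (m == (1ll << 31)) { // Rounding carried the mantissa up to 1.0.
    m >>= 1;
    ++exp;
  }
  multiplier = static_cast<int32_t>(m);
  shift = 31 - exp;
}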
