Skip to content

Commit 8b2cbe2

Browse files
tatwaichongFranklandJack
authored andcommitted
[mlir][tosa] Make Convolution Zero Points Inputs
The TOSA-v1.0 specification moves the the "zero point" parameters of the convolution opertors CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs. Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately.
1 parent 8d306cc commit 8b2cbe2

21 files changed

+800
-338
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,38 @@ class TosaElementwiseOperator
101101
#define GET_OP_CLASSES
102102
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
103103

104+
namespace mlir {
105+
namespace tosa {
106+
107+
// Create a rank-0 const tensor for zero point of the source tensor.
108+
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
109+
Type srcElemType, int64_t zp = 0);
110+
111+
// Get zero point value from the attribute argument.
112+
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
113+
114+
// Verify if zero point falls into valid range.
115+
template <typename T>
116+
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
117+
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
118+
!std::is_same_v<T, DepthwiseConv2DOp> &&
119+
!std::is_same_v<T, TransposeConv2DOp>) {
120+
return failure();
121+
}
122+
123+
if (!zpElemType.isIntOrFloat())
124+
return failure();
125+
126+
if (!zpElemType.isInteger(8) && zp != 0)
127+
return failure();
128+
129+
if (zp < -128 || zp > 127)
130+
return failure();
131+
132+
return success();
133+
}
134+
135+
} // namespace tosa
136+
} // namespace mlir
137+
104138
#endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,13 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
103103
Tosa_Tensor4D:$input,
104104
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
105105
Tosa_Tensor1D:$bias,
106+
Tosa_ScalarTensor:$input_zp,
107+
Tosa_ScalarTensor:$weight_zp,
108+
106109
Tosa_IntArrayAttr4:$pad,
107110
Tosa_IntArrayAttr2:$stride,
108111
Tosa_IntArrayAttr2:$dilation,
109112
TypeAttrOf<Tosa_AccType>:$acc_type,
110-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
111113
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
112114
);
113115

@@ -133,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
133135
Tosa_Tensor5D:$input,
134136
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
135137
Tosa_Tensor1D:$bias,
138+
Tosa_ScalarTensor:$input_zp,
139+
Tosa_ScalarTensor:$weight_zp,
136140
Tosa_IntArrayAttr6:$pad,
137141
Tosa_IntArrayAttr3:$stride,
138142
Tosa_IntArrayAttr3:$dilation,
139143
TypeAttrOf<Tosa_AccType>:$acc_type,
140-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
141144
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
142145
);
143146

@@ -164,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
164167
Tosa_Tensor4D:$input,
165168
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
166169
Tosa_Tensor1D:$bias,
170+
Tosa_ScalarTensor:$input_zp,
171+
Tosa_ScalarTensor:$weight_zp,
167172
Tosa_IntArrayAttr4:$pad,
168173
Tosa_IntArrayAttr2:$stride,
169174
Tosa_IntArrayAttr2:$dilation,
170175
TypeAttrOf<Tosa_AccType>:$acc_type,
171-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
172176
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
173177
);
174178

@@ -346,13 +350,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
346350

347351
let arguments = (ins
348352
Tosa_Tensor4D:$input,
349-
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
353+
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
350354
Tosa_Tensor1D:$bias,
355+
Tosa_ScalarTensor:$input_zp,
356+
Tosa_ScalarTensor:$weight_zp,
351357
Tosa_IntArrayAttr4:$out_pad,
352358
Tosa_IntArrayAttr2:$stride,
353359
Tosa_IntArrayAttr4:$out_shape,
354360
TypeAttrOf<Tosa_AccType>:$acc_type,
355-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
356361
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
357362
);
358363

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
3535
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
3636
Value input, Value weight);
3737

38+
std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
39+
Value weight);
40+
3841
//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
3942
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
4043
Value a, Value b);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
2222
#include "mlir/Dialect/Utils/IndexingUtils.h"
2323
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24+
#include "mlir/IR/BuiltinTypes.h"
2425
#include "mlir/IR/Matchers.h"
2526
#include "mlir/IR/PatternMatch.h"
2627
#include "mlir/Transforms/DialectConversion.h"
@@ -258,7 +259,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
258259
DenseI64ArrayAttr padAttr = op.getPadAttr();
259260
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
260261
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
261-
bool isQuantized = op.getQuantizationInfo().has_value();
262+
263+
ElementsAttr inputZpAttr;
264+
ElementsAttr weightZpAttr;
265+
if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
266+
!matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
267+
return rewriter.notifyMatchFailure(
268+
op,
269+
"bail out if the actual value of zero points cannot be determined");
270+
271+
// Get and verify explicit zero points.
272+
int64_t inputZpVal;
273+
int64_t weightZpVal;
274+
275+
if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
276+
tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(inputZpAttr),
277+
inputZpVal)
278+
.failed())
279+
return rewriter.notifyMatchFailure(
280+
op, "input zero point must be zero for non-int8 integer types");
281+
282+
if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
283+
tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(weightZpAttr),
284+
weightZpVal)
285+
.failed())
286+
return rewriter.notifyMatchFailure(
287+
op, "weight zero point must be zero for non-int8 integer types");
288+
289+
const bool hasZp =
290+
(inputZpVal != 0) || (weightZpVal != 0) || isa<IntegerType>(inputETy);
262291

263292
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
264293
return rewriter.notifyMatchFailure(
@@ -284,22 +313,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
284313

285314
// Apply padding as necessary.
286315
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
287-
if (isQuantized) {
288-
auto quantizationInfo = *op.getQuantizationInfo();
289-
int64_t iZp = quantizationInfo.getInputZp();
290-
316+
if (hasZp) {
291317
int64_t intMin =
292318
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
293319
.getSExtValue();
294320
int64_t intMax =
295321
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
296322
.getSExtValue();
297323

298-
if (iZp < intMin || iZp > intMax)
324+
if (inputZpVal < intMin || inputZpVal > intMax)
299325
return rewriter.notifyMatchFailure(
300326
op, "tosa.conv op quantization has zp outside of input range");
301327

302-
zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
328+
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
303329
}
304330

305331
llvm::SmallVector<int64_t> pad;
@@ -312,8 +338,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
312338
// For 2D convolutions, we need to check if the target convolution op
313339
// wants a HWCF kernel layout.
314340
bool wantHwcf =
315-
isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
316-
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
341+
hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
342+
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317343
if (wantHwcf) {
318344
// Transpose the kernel to match dimension ordering of the linalg
319345
// convolution operation.
@@ -374,10 +400,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
374400
Value broadcastBias =
375401
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
376402

377-
if (isQuantized) {
378-
auto quantizationInfo = *op.getQuantizationInfo();
379-
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
380-
auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
403+
if (hasZp) {
404+
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
405+
auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
381406

382407
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
383408
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,25 +465,40 @@ class DepthwiseConvConverter
440465
/*inputSizeDims=*/{1, 2},
441466
/*kernelSizeDims=*/{0, 1}, rewriter);
442467

443-
bool isQuantized = op->hasAttr("quantization_info");
444-
IntegerAttr iZp;
445-
IntegerAttr kZp;
446-
if (isQuantized) {
447-
auto quantizationInfo =
448-
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
449-
iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
450-
kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
451-
}
468+
ElementsAttr inputZpAttr;
469+
ElementsAttr weightZpAttr;
470+
if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
471+
!matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
472+
return rewriter.notifyMatchFailure(
473+
op,
474+
"bail out if the actual value of zero points cannot be determined");
475+
476+
// Get and verify explicit zero points.
477+
int64_t inputZpVal;
478+
int64_t weightZpVal;
479+
480+
if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
481+
tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
482+
getElementTypeOrSelf(inputZpAttr), inputZpVal)
483+
.failed())
484+
return rewriter.notifyMatchFailure(
485+
op, "input zero point must be zero for non-int8 integer types");
486+
487+
if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
488+
tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
489+
getElementTypeOrSelf(weightZpAttr), weightZpVal)
490+
.failed())
491+
return rewriter.notifyMatchFailure(
492+
op, "weight zero point must be zero for non-int8 integer types");
452493

494+
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
453495
auto weightShape = weightTy.getShape();
454496
auto resultShape = resultTy.getShape();
455497

456498
// Apply padding as necessary.
457499
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
458-
if (isQuantized) {
459-
auto quantizationInfo =
460-
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
461-
int64_t iZp = quantizationInfo.getInputZp();
500+
if (inputZpVal) {
501+
const int64_t iZp = inputZpVal;
462502

463503
int64_t intMin =
464504
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -512,7 +552,7 @@ class DepthwiseConvConverter
512552
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
513553
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
514554

515-
if (!isQuantized) {
555+
if (!hasZp && isa<FloatType>(inputETy)) {
516556
Value conv = rewriter
517557
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
518558
loc, linalgConvTy, ValueRange{input, weight},
@@ -539,6 +579,8 @@ class DepthwiseConvConverter
539579
.getResult(0);
540580
rewriter.replaceOp(op, result);
541581
} else {
582+
IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
583+
IntegerAttr kZp = rewriter.getI32IntegerAttr(weightZpVal);
542584
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
543585
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
544586
Value conv =

0 commit comments

Comments
 (0)