Skip to content

Commit bbd925b

Browse files
tatwaichongFranklandJack
authored andcommitted
[mlir][tosa] Make Convolution Zero Points Inputs
The TOSA-v1.0 specification moves the "zero point" parameters of the convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs. Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately.
1 parent eaa5897 commit bbd925b

21 files changed

+802
-334
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,38 @@ bool isa_tosa_shape_type(mlir::Type t);
142142
#define GET_OP_CLASSES
143143
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
144144

145+
namespace mlir {
146+
namespace tosa {
147+
148+
// Create a rank-1 const tensor for zero point of the source tensor.
149+
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
150+
Type srcElemType, int64_t zp = 0);
151+
152+
// Get zero point value from the attribute argument.
153+
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
154+
155+
// Verify that a zero point value is valid for the convolution op type `T`.
//
// `T` is expected to be one of Conv2DOp, Conv3DOp, DepthwiseConv2DOp or
// TransposeConv2DOp; for any other `T` the check fails unconditionally.
//
// `zpElemType` is the element type the zero point is checked against and `zp`
// is the zero point value.
//
// Returns failure() when:
//   - `zpElemType` is neither an integer nor a float type,
//   - `zpElemType` is not an 8-bit integer and `zp` is non-zero, or
//   - `zp` lies outside the range [-128, 127].
// Returns success() otherwise.
156+
template <typename T>
157+
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
158+
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
159+
!std::is_same_v<T, DepthwiseConv2DOp> &&
160+
!std::is_same_v<T, TransposeConv2DOp>) {
161+
return failure();
162+
}
163+
164+
if (!zpElemType.isIntOrFloat())
165+
return failure();
166+
167+
if (!zpElemType.isInteger(8) && zp != 0)
168+
return failure();
169+
170+
if (zp < -128 || zp > 127)
171+
return failure();
172+
173+
return success();
174+
}
175+
176+
} // namespace tosa
177+
} // namespace mlir
178+
145179
#endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,13 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
104104
Tosa_Tensor4D:$input,
105105
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
106106
Tosa_Tensor1D:$bias,
107+
TosaTensorRankOf<[Tosa_AnyNumber], [1]>:$input_zp,
108+
TosaTensorRankOf<[Tosa_AnyNumber], [1]>:$weight_zp,
109+
107110
Tosa_IntArrayAttr4:$pad,
108111
Tosa_IntArrayAttr2:$stride,
109112
Tosa_IntArrayAttr2:$dilation,
110113
TypeAttrOf<Tosa_AccType>:$acc_type,
111-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
112114
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
113115
);
114116

@@ -134,11 +136,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
134136
Tosa_Tensor5D:$input,
135137
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
136138
Tosa_Tensor1D:$bias,
139+
TosaTensorRankOf<[Tosa_AnyNumber], [1]>:$input_zp,
140+
TosaTensorRankOf<[Tosa_AnyNumber], [1]>:$weight_zp,
137141
Tosa_IntArrayAttr6:$pad,
138142
Tosa_IntArrayAttr3:$stride,
139143
Tosa_IntArrayAttr3:$dilation,
140144
TypeAttrOf<Tosa_AccType>:$acc_type,
141-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
142145
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
143146
);
144147

@@ -165,11 +168,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
165168
Tosa_Tensor4D:$input,
166169
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
167170
Tosa_Tensor1D:$bias,
171+
TosaTensorRankOf<[Tosa_AnyNumber], [1]>:$input_zp,
172+
TosaTensorRankOf<[Tosa_AnyNumber], [1]>:$weight_zp,
168173
Tosa_IntArrayAttr4:$pad,
169174
Tosa_IntArrayAttr2:$stride,
170175
Tosa_IntArrayAttr2:$dilation,
171176
TypeAttrOf<Tosa_AccType>:$acc_type,
172-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
173177
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
174178
);
175179

@@ -348,13 +352,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
348352

349353
let arguments = (ins
350354
Tosa_Tensor4D:$input,
351-
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
355+
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
352356
Tosa_Tensor1D:$bias,
357+
TosaTensorRankOf<[Tosa_AnyNumber], [1]>:$input_zp,
358+
TosaTensorRankOf<[Tosa_AnyNumber], [1]>:$weight_zp,
353359
Tosa_IntArrayAttr4:$out_pad,
354360
Tosa_IntArrayAttr2:$stride,
355361
Tosa_IntArrayAttr4:$out_shape,
356362
TypeAttrOf<Tosa_AccType>:$acc_type,
357-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
358363
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
359364
);
360365

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
3535
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
3636
Value input, Value weight);
3737

38+
std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
39+
Value weight);
40+
3841
//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
3942
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
4043
Value a, Value b);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
2222
#include "mlir/Dialect/Utils/IndexingUtils.h"
2323
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
24+
#include "mlir/IR/BuiltinTypes.h"
2425
#include "mlir/IR/Matchers.h"
2526
#include "mlir/IR/PatternMatch.h"
2627
#include "mlir/Transforms/DialectConversion.h"
@@ -258,7 +259,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
258259
DenseI64ArrayAttr padAttr = op.getPadAttr();
259260
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
260261
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
261-
bool isQuantized = op.getQuantizationInfo().has_value();
262+
263+
ElementsAttr inputZpAttr;
264+
ElementsAttr weightZpAttr;
265+
if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
266+
!matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
267+
return rewriter.notifyMatchFailure(
268+
op,
269+
"bail out if the actual value of zero points cannot be determined");
270+
271+
// Get and verify explicit zero points.
272+
int64_t inputZpVal;
273+
int64_t weightZpVal;
274+
275+
if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
276+
tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(inputZpAttr),
277+
inputZpVal)
278+
.failed())
279+
return rewriter.notifyMatchFailure(
280+
op, "input zero point must be zero for non-int8 integer types");
281+
282+
if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
283+
tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(weightZpAttr),
284+
weightZpVal)
285+
.failed())
286+
return rewriter.notifyMatchFailure(
287+
op, "weight zero point must be zero for non-int8 integer types");
288+
289+
const bool hasZp =
290+
(inputZpVal != 0) || (weightZpVal != 0) || isa<IntegerType>(inputETy);
262291

263292
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
264293
return rewriter.notifyMatchFailure(
@@ -284,22 +313,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
284313

285314
// Apply padding as necessary.
286315
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
287-
if (isQuantized) {
288-
auto quantizationInfo = *op.getQuantizationInfo();
289-
int64_t iZp = quantizationInfo.getInputZp();
290-
316+
if (hasZp) {
291317
int64_t intMin =
292318
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
293319
.getSExtValue();
294320
int64_t intMax =
295321
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
296322
.getSExtValue();
297323

298-
if (iZp < intMin || iZp > intMax)
324+
if (inputZpVal < intMin || inputZpVal > intMax)
299325
return rewriter.notifyMatchFailure(
300326
op, "tosa.conv op quantization has zp outside of input range");
301327

302-
zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
328+
zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
303329
}
304330

305331
llvm::SmallVector<int64_t> pad;
@@ -312,8 +338,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
312338
// For 2D convolutions, we need to check if the target convolution op
313339
// wants a HWCF kernel layout.
314340
bool wantHwcf =
315-
isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
316-
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
341+
hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
342+
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317343
if (wantHwcf) {
318344
// Transpose the kernel to match dimension ordering of the linalg
319345
// convolution operation.
@@ -374,10 +400,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
374400
Value broadcastBias =
375401
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
376402

377-
if (isQuantized) {
378-
auto quantizationInfo = *op.getQuantizationInfo();
379-
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
380-
auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
403+
if (hasZp) {
404+
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
405+
auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
381406

382407
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
383408
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,25 +465,40 @@ class DepthwiseConvConverter
440465
/*inputSizeDims=*/{1, 2},
441466
/*kernelSizeDims=*/{0, 1}, rewriter);
442467

443-
bool isQuantized = op->hasAttr("quantization_info");
444-
IntegerAttr iZp;
445-
IntegerAttr kZp;
446-
if (isQuantized) {
447-
auto quantizationInfo =
448-
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
449-
iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
450-
kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
451-
}
468+
ElementsAttr inputZpAttr;
469+
ElementsAttr weightZpAttr;
470+
if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
471+
!matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
472+
return rewriter.notifyMatchFailure(
473+
op,
474+
"bail out if the actual value of zero points cannot be determined");
475+
476+
// Get and verify explicit zero points.
477+
int64_t inputZpVal;
478+
int64_t weightZpVal;
479+
480+
if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
481+
tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
482+
getElementTypeOrSelf(inputZpAttr), inputZpVal)
483+
.failed())
484+
return rewriter.notifyMatchFailure(
485+
op, "input zero point must be zero for non-int8 integer types");
486+
487+
if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
488+
tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
489+
getElementTypeOrSelf(weightZpAttr), weightZpVal)
490+
.failed())
491+
return rewriter.notifyMatchFailure(
492+
op, "weight zero point must be zero for non-int8 integer types");
452493

494+
bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
453495
auto weightShape = weightTy.getShape();
454496
auto resultShape = resultTy.getShape();
455497

456498
// Apply padding as necessary.
457499
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
458-
if (isQuantized) {
459-
auto quantizationInfo =
460-
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
461-
int64_t iZp = quantizationInfo.getInputZp();
500+
if (inputZpVal) {
501+
const int64_t iZp = inputZpVal;
462502

463503
int64_t intMin =
464504
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -512,7 +552,7 @@ class DepthwiseConvConverter
512552
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
513553
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
514554

515-
if (!isQuantized) {
555+
if (!hasZp && isa<FloatType>(inputETy)) {
516556
Value conv = rewriter
517557
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
518558
loc, linalgConvTy, ValueRange{input, weight},
@@ -539,6 +579,8 @@ class DepthwiseConvConverter
539579
.getResult(0);
540580
rewriter.replaceOp(op, result);
541581
} else {
582+
IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
583+
IntegerAttr kZp = rewriter.getI32IntegerAttr(weightZpVal);
542584
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
543585
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
544586
Value conv =

0 commit comments

Comments
 (0)