Skip to content

Commit 4b6a660

Browse files
tatwaichongFranklandJack
authored andcommitted
[mlir][tosa] Make Convolution Zero Points Inputs
The TOSA-v1.0 specification moves the the "zero point" parameters of the convolution opertors CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs. Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately.
1 parent b68b4f6 commit 4b6a660

File tree

19 files changed

+546
-198
lines changed

19 files changed

+546
-198
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,4 +264,11 @@ class Tosa_InferShapedTypeOp<string mnemonic, list<Trait> traits = []>
264264
"operands attr-dict `:` functional-type(operands, results)";
265265
}
266266

267+
// The "SameVariadicOperandSize" trait allows us to pass optional arguments
268+
// for multiple zero points in convolution ops.
269+
class Tosa_ConvOp<string mnemonic, list<Trait> traits = []>
270+
: Tosa_InferShapedTypeOp<mnemonic, !listconcat(traits,
271+
[SameVariadicOperandSize])> {
272+
}
273+
267274
#endif // TOSA_OP_BASE

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
1717
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
1818
#include "mlir/Dialect/Traits.h"
19+
#include "mlir/IR/Matchers.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "mlir/IR/OpImplementation.h"
2122
#include "mlir/IR/TypeUtilities.h"
@@ -29,6 +30,7 @@
2930
//===----------------------------------------------------------------------===//
3031

3132
#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
33+
#include "mlir/Transforms/DialectConversion.h"
3234

3335
namespace mlir {
3436
class PatternRewriter;
@@ -152,4 +154,120 @@ bool isa_tosa_shape_type(mlir::Type t);
152154
#define GET_OP_CLASSES
153155
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
154156

157+
namespace mlir {
158+
namespace tosa {
159+
160+
// Create a rank-1 const tensor for zero point of the source tensor.
161+
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
162+
Type srcElemType, int64_t zp = 0);
163+
164+
// Get zero point value from the attribute argument.
165+
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
166+
167+
// Verify if zero point falls into valid range.
168+
template <typename T>
169+
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
170+
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
171+
!std::is_same_v<T, DepthwiseConv2DOp> &&
172+
!std::is_same_v<T, TransposeConv2DOp>) {
173+
return failure();
174+
}
175+
176+
if (!zpElemType.isIntOrFloat())
177+
return failure();
178+
179+
if (!zpElemType.isInteger(8) && zp != 0)
180+
return failure();
181+
182+
if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
183+
return failure();
184+
185+
if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
186+
return failure();
187+
188+
return success();
189+
}
190+
191+
// Helper type trait to determine if an operation is a tosa convolution.
192+
template <typename Op>
193+
struct IsTosaConv : std::false_type {};
194+
195+
template <>
196+
struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
197+
template <>
198+
struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
199+
template <>
200+
struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
201+
template <>
202+
struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};
203+
204+
template <typename Op>
205+
constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;
206+
207+
// Helper struct to hold the zero points of a TOSA convolution operation as
208+
// named 64-bit integer fields.
209+
struct ConvZpPair {
210+
ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
211+
: inputZp(inputZp), weightZp(weightZp) {}
212+
std::int64_t inputZp;
213+
std::int64_t weightZp;
214+
};
215+
216+
// Helper function which attempts to extract the zero points from a TOSA
217+
// convolution by matching them against defining ops which should be tosa.const
218+
// operations.
219+
//
220+
// There are three possible results:
221+
// 1. Failed to extract the zero-points i.e. they should exist and don't or they
222+
// do exist but are invalid.
223+
// 2. Succeeded in extracting zero-points.
224+
// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
225+
// convolution.
226+
using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
227+
template <typename TosaConvOp>
228+
std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
229+
extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
230+
// Strictly speaking the base TOSA spec requires that for non int8 types
231+
// zero points must be zero. However, in the dialect these operands are
232+
// optional and only required for int8. They have no semantic meaning for
233+
// non-quantized types and can therefore be safely ignored. This is case 3.
234+
if (auto opElementTY =
235+
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
236+
!opElementTY.isInteger(8))
237+
return FailOrMaybeZP(std::nullopt);
238+
239+
// Now we know we should have a zero point check it is valid.
240+
if (!op.getInputZp())
241+
return rewriter.notifyMatchFailure(op, "missing input zero point");
242+
243+
// Helper to extract the zero point by matching its definition against a
244+
// constant.
245+
auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
246+
ElementsAttr zpAttr;
247+
if (!matchPattern(zpValue, m_Constant(&zpAttr)))
248+
return std::nullopt;
249+
250+
int64_t zp;
251+
if (tosa::getZeroPoint(zpAttr, zp).failed())
252+
return std::nullopt;
253+
254+
return std::make_optional(zp);
255+
};
256+
257+
auto maybeInputZp = extractZeroPoint(op.getInputZp());
258+
if (!maybeInputZp)
259+
return rewriter.notifyMatchFailure(op, "unable to extract input zp");
260+
261+
if (!op.getWeightZp())
262+
return rewriter.notifyMatchFailure(op, "missing weight zero point");
263+
264+
auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
265+
if (!maybeWeightZp)
266+
return rewriter.notifyMatchFailure(op, "unable to extract weight zp");
267+
268+
return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
269+
}
270+
} // namespace tosa
271+
} // namespace mlir
272+
155273
#endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
9292
//===----------------------------------------------------------------------===//
9393
// Operator: conv2d
9494
//===----------------------------------------------------------------------===//
95-
def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
95+
def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
9696
let summary = "2D Convolution Operator";
9797

9898
let description = [{
@@ -104,11 +104,12 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
104104
Tosa_Tensor4D:$input,
105105
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
106106
Tosa_Tensor1D:$bias,
107+
Optional<Tosa_ZeroPointTensor>:$input_zp,
108+
Optional<Tosa_ZeroPointTensor>:$weight_zp,
107109
Tosa_IntArrayAttr4:$pad,
108110
Tosa_IntArrayAttr2:$stride,
109111
Tosa_IntArrayAttr2:$dilation,
110112
TypeAttrOf<Tosa_AccType>:$acc_type,
111-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
112113
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
113114
);
114115

@@ -123,7 +124,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
123124
//===----------------------------------------------------------------------===//
124125
// Operator: conv3d
125126
//===----------------------------------------------------------------------===//
126-
def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
127+
def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
127128
let summary = "3D Convolution operator";
128129

129130
let description = [{
@@ -134,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
134135
Tosa_Tensor5D:$input,
135136
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
136137
Tosa_Tensor1D:$bias,
138+
Optional<Tosa_ZeroPointTensor>:$input_zp,
139+
Optional<Tosa_ZeroPointTensor>:$weight_zp,
137140
Tosa_IntArrayAttr6:$pad,
138141
Tosa_IntArrayAttr3:$stride,
139142
Tosa_IntArrayAttr3:$dilation,
140143
TypeAttrOf<Tosa_AccType>:$acc_type,
141-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
142144
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
143145
);
144146

@@ -153,7 +155,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
153155
//===----------------------------------------------------------------------===//
154156
// Operator: depthwise_conv2d
155157
//===----------------------------------------------------------------------===//
156-
def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
158+
def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
157159
let summary = "Depthwise 2D Convolution operator";
158160

159161
let description = [{
@@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
165167
Tosa_Tensor4D:$input,
166168
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
167169
Tosa_Tensor1D:$bias,
170+
Optional<Tosa_ZeroPointTensor>:$input_zp,
171+
Optional<Tosa_ZeroPointTensor>:$weight_zp,
168172
Tosa_IntArrayAttr4:$pad,
169173
Tosa_IntArrayAttr2:$stride,
170174
Tosa_IntArrayAttr2:$dilation,
171175
TypeAttrOf<Tosa_AccType>:$acc_type,
172-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
173176
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
174177
);
175178

@@ -338,7 +341,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
338341
//===----------------------------------------------------------------------===//
339342
// Operator: transpose_conv2d
340343
//===----------------------------------------------------------------------===//
341-
def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
344+
def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
342345
let summary = "Transpose 2D Convolution operator.";
343346

344347
let description = [{
@@ -348,13 +351,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
348351

349352
let arguments = (ins
350353
Tosa_Tensor4D:$input,
351-
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
354+
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
352355
Tosa_Tensor1D:$bias,
356+
Optional<Tosa_ZeroPointTensor>:$input_zp,
357+
Optional<Tosa_ZeroPointTensor>:$weight_zp,
353358
Tosa_IntArrayAttr4:$out_pad,
354359
Tosa_IntArrayAttr2:$stride,
355360
Tosa_IntArrayAttr4:$out_shape,
356361
TypeAttrOf<Tosa_AccType>:$acc_type,
357-
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
358362
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
359363
);
360364

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,4 +288,6 @@ def Rank1TosaShape : TosaShapeOfRank<1>;
288288
def Rank2TosaShape : TosaShapeOfRank<2>;
289289
def Rank4TosaShape : TosaShapeOfRank<4>;
290290

291+
def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;
292+
291293
#endif // TOSA_TYPES_BASE

mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
3535
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
3636
Value input, Value weight);
3737

38+
std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
39+
Value weight);
40+
3841
//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
3942
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
4043
Value a, Value b);

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
258258
DenseI64ArrayAttr padAttr = op.getPadAttr();
259259
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
260260
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
261-
bool isQuantized = op.getQuantizationInfo().has_value();
261+
262+
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
263+
if (llvm::failed(failureOrMaybeZps))
264+
return failure();
265+
266+
auto maybeZps = failureOrMaybeZps.value();
262267

263268
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
264269
return rewriter.notifyMatchFailure(
@@ -284,22 +289,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
284289

285290
// Apply padding as necessary.
286291
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
287-
if (isQuantized) {
288-
auto quantizationInfo = *op.getQuantizationInfo();
289-
int64_t iZp = quantizationInfo.getInputZp();
290-
292+
if (maybeZps) {
291293
int64_t intMin =
292294
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
293295
.getSExtValue();
294296
int64_t intMax =
295297
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
296298
.getSExtValue();
297299

298-
if (iZp < intMin || iZp > intMax)
300+
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
299301
return rewriter.notifyMatchFailure(
300302
op, "tosa.conv op quantization has zp outside of input range");
301303

302-
zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
304+
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
303305
}
304306

305307
llvm::SmallVector<int64_t> pad;
@@ -312,8 +314,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
312314
// For 2D convolutions, we need to check if the target convolution op
313315
// wants a HWCF kernel layout.
314316
bool wantHwcf =
315-
isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
316-
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317+
maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
318+
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317319
if (wantHwcf) {
318320
// Transpose the kernel to match dimension ordering of the linalg
319321
// convolution operation.
@@ -374,10 +376,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
374376
Value broadcastBias =
375377
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
376378

377-
if (isQuantized) {
378-
auto quantizationInfo = *op.getQuantizationInfo();
379-
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
380-
auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
379+
if (maybeZps) {
380+
auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
381+
auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
381382

382383
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
383384
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,39 +441,31 @@ class DepthwiseConvConverter
440441
/*inputSizeDims=*/{1, 2},
441442
/*kernelSizeDims=*/{0, 1}, rewriter);
442443

443-
bool isQuantized = op->hasAttr("quantization_info");
444-
IntegerAttr iZp;
445-
IntegerAttr kZp;
446-
if (isQuantized) {
447-
auto quantizationInfo =
448-
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
449-
iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
450-
kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
451-
}
444+
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
445+
if (llvm::failed(failureOrMaybeZps))
446+
return failure();
447+
448+
auto maybeZps = failureOrMaybeZps.value();
452449

453450
auto weightShape = weightTy.getShape();
454451
auto resultShape = resultTy.getShape();
455452

456453
// Apply padding as necessary.
457454
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
458-
if (isQuantized) {
459-
auto quantizationInfo =
460-
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
461-
int64_t iZp = quantizationInfo.getInputZp();
462-
455+
if (maybeZps) {
463456
int64_t intMin =
464457
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
465458
.getSExtValue();
466459
int64_t intMax =
467460
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
468461
.getSExtValue();
469462

470-
if (iZp < intMin || iZp > intMax)
463+
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
471464
return rewriter.notifyMatchFailure(
472465
op, "tosa.depthwise_conv op quantization has zp outside of input "
473466
"range");
474467

475-
zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
468+
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
476469
}
477470

478471
llvm::SmallVector<int64_t> pad;
@@ -512,7 +505,7 @@ class DepthwiseConvConverter
512505
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
513506
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
514507

515-
if (!isQuantized) {
508+
if (!maybeZps) {
516509
Value conv = rewriter
517510
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
518511
loc, linalgConvTy, ValueRange{input, weight},
@@ -539,6 +532,8 @@ class DepthwiseConvConverter
539532
.getResult(0);
540533
rewriter.replaceOp(op, result);
541534
} else {
535+
IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
536+
IntegerAttr kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
542537
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
543538
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
544539
Value conv =

0 commit comments

Comments
 (0)