2121#include " mlir/Dialect/Tosa/Utils/ConversionUtils.h"
2222#include " mlir/Dialect/Utils/IndexingUtils.h"
2323#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
24+ #include " mlir/IR/BuiltinTypes.h"
2425#include " mlir/IR/Matchers.h"
2526#include " mlir/IR/PatternMatch.h"
2627#include " mlir/Transforms/DialectConversion.h"
@@ -258,7 +259,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
258259 DenseI64ArrayAttr padAttr = op.getPadAttr ();
259260 DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr ();
260261 DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr ();
261- bool isQuantized = op.getQuantizationInfo ().has_value ();
262+
263+ ElementsAttr inputZpAttr;
264+ ElementsAttr weightZpAttr;
265+ if (!matchPattern (op.getInputZp (), m_Constant (&inputZpAttr)) ||
266+ !matchPattern (op.getWeightZp (), m_Constant (&weightZpAttr)))
267+ return rewriter.notifyMatchFailure (
268+ op,
269+ " bail out if the actual value of zero points cannot be determined" );
270+
271+ // Get and verify explicit zero points.
272+ int64_t inputZpVal;
273+ int64_t weightZpVal;
274+
275+ if (tosa::getZeroPoint (inputZpAttr, inputZpVal).failed () ||
276+ tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf (inputZpAttr),
277+ inputZpVal)
278+ .failed ())
279+ return rewriter.notifyMatchFailure (
280+ op, " input zero point must be zero for non-int8 integer types" );
281+
282+ if (tosa::getZeroPoint (weightZpAttr, weightZpVal).failed () ||
283+ tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf (weightZpAttr),
284+ weightZpVal)
285+ .failed ())
286+ return rewriter.notifyMatchFailure (
287+ op, " weight zero point must be zero for non-int8 integer types" );
288+
289+ const bool hasZp =
290+ (inputZpVal != 0 ) || (weightZpVal != 0 ) || isa<IntegerType>(inputETy);
262291
263292 if (!weightTy.hasStaticShape () || !biasTy.hasStaticShape ())
264293 return rewriter.notifyMatchFailure (
@@ -284,22 +313,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
284313
285314 // Apply padding as necessary.
286315 TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
287- if (isQuantized) {
288- auto quantizationInfo = *op.getQuantizationInfo ();
289- int64_t iZp = quantizationInfo.getInputZp ();
290-
316+ if (hasZp) {
291317 int64_t intMin =
292318 APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
293319 .getSExtValue ();
294320 int64_t intMax =
295321 APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
296322 .getSExtValue ();
297323
298- if (iZp < intMin || iZp > intMax)
324+ if (inputZpVal < intMin || inputZpVal > intMax)
299325 return rewriter.notifyMatchFailure (
300326 op, " tosa.conv op quantization has zp outside of input range" );
301327
302- zeroAttr = rewriter.getIntegerAttr (inputETy, iZp );
328+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
303329 }
304330
305331 llvm::SmallVector<int64_t > pad;
@@ -312,8 +338,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
312338 // For 2D convolutions, we need to check if the target convolution op
313339 // wants a HWCF kernel layout.
314340 bool wantHwcf =
315- isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
316- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
341+ hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
342+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317343 if (wantHwcf) {
318344 // Transpose the kernel to match dimension ordering of the linalg
319345 // convolution operation.
@@ -374,10 +400,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
374400 Value broadcastBias =
375401 linalgBroadcastAndMaybeExtSI (rewriter, loc, bias, biasEmptyTensor);
376402
377- if (isQuantized) {
378- auto quantizationInfo = *op.getQuantizationInfo ();
379- auto iZp = rewriter.getI32IntegerAttr (quantizationInfo.getInputZp ());
380- auto kZp = rewriter.getI32IntegerAttr (quantizationInfo.getWeightZp ());
403+ if (hasZp) {
404+ auto iZp = rewriter.getI32IntegerAttr (inputZpVal);
405+ auto kZp = rewriter.getI32IntegerAttr (weightZpVal);
381406
382407 auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
383408 auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
@@ -440,25 +465,40 @@ class DepthwiseConvConverter
440465 /* inputSizeDims=*/ {1 , 2 },
441466 /* kernelSizeDims=*/ {0 , 1 }, rewriter);
442467
443- bool isQuantized = op->hasAttr (" quantization_info" );
444- IntegerAttr iZp;
445- IntegerAttr kZp ;
446- if (isQuantized) {
447- auto quantizationInfo =
448- cast<tosa::ConvOpQuantizationAttr>(op->getAttr (" quantization_info" ));
449- iZp = rewriter.getI32IntegerAttr (quantizationInfo.getInputZp ());
450- kZp = rewriter.getI32IntegerAttr (quantizationInfo.getWeightZp ());
451- }
468+ ElementsAttr inputZpAttr;
469+ ElementsAttr weightZpAttr;
470+ if (!matchPattern (op.getInputZp (), m_Constant (&inputZpAttr)) ||
471+ !matchPattern (op.getWeightZp (), m_Constant (&weightZpAttr)))
472+ return rewriter.notifyMatchFailure (
473+ op,
474+ " bail out if the actual value of zero points cannot be determined" );
475+
476+ // Get and verify explicit zero points.
477+ int64_t inputZpVal;
478+ int64_t weightZpVal;
479+
480+ if (tosa::getZeroPoint (inputZpAttr, inputZpVal).failed () ||
481+ tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
482+ getElementTypeOrSelf (inputZpAttr), inputZpVal)
483+ .failed ())
484+ return rewriter.notifyMatchFailure (
485+ op, " input zero point must be zero for non-int8 integer types" );
486+
487+ if (tosa::getZeroPoint (weightZpAttr, weightZpVal).failed () ||
488+ tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
489+ getElementTypeOrSelf (weightZpAttr), weightZpVal)
490+ .failed ())
491+ return rewriter.notifyMatchFailure (
492+ op, " weight zero point must be zero for non-int8 integer types" );
452493
494+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
453495 auto weightShape = weightTy.getShape ();
454496 auto resultShape = resultTy.getShape ();
455497
456498 // Apply padding as necessary.
457499 TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
458- if (isQuantized) {
459- auto quantizationInfo =
460- cast<tosa::ConvOpQuantizationAttr>(op->getAttr (" quantization_info" ));
461- int64_t iZp = quantizationInfo.getInputZp ();
500+ if (inputZpVal) {
501+ const int64_t iZp = inputZpVal;
462502
463503 int64_t intMin =
464504 APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
@@ -512,7 +552,7 @@ class DepthwiseConvConverter
512552 indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
513553 indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
514554
515- if (!isQuantized ) {
555+ if (!hasZp && isa<FloatType>(inputETy) ) {
516556 Value conv = rewriter
517557 .create <linalg::DepthwiseConv2DNhwcHwcmOp>(
518558 loc, linalgConvTy, ValueRange{input, weight},
@@ -539,6 +579,8 @@ class DepthwiseConvConverter
539579 .getResult (0 );
540580 rewriter.replaceOp (op, result);
541581 } else {
582+ IntegerAttr iZp = rewriter.getI32IntegerAttr (inputZpVal);
583+ IntegerAttr kZp = rewriter.getI32IntegerAttr (weightZpVal);
542584 auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
543585 auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
544586 Value conv =
0 commit comments