@@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
259259 DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr ();
260260 DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr ();
261261
262- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
263- if (llvm::failed (failureOrMaybeZps))
264- return failure ();
262+ // Get and verify zero points.
263+ int64_t inputZpVal;
264+ int64_t weightZpVal;
265+
266+ if (op.getInputZeroPoint (inputZpVal).failed () ||
267+ op.getWeightZeroPoint (weightZpVal).failed ())
268+ return rewriter.notifyMatchFailure (
269+ op, " bail out if zero points cannot statically be determined" );
270+
271+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
272+ op.verifyWeightZeroPoint (weightZpVal).failed ())
273+ return rewriter.notifyMatchFailure (
274+ op, " zero point must be zero for non-int8 integer types" );
265275
266- auto maybeZps = failureOrMaybeZps. value ( );
276+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
267277
268278 if (!weightTy.hasStaticShape () || !biasTy.hasStaticShape ())
269279 return rewriter.notifyMatchFailure (
@@ -289,19 +299,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
289299
290300 // Apply padding as necessary.
291301 TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
292- if (maybeZps ) {
302+ if (hasZp ) {
293303 int64_t intMin =
294304 APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
295305 .getSExtValue ();
296306 int64_t intMax =
297307 APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
298308 .getSExtValue ();
299309
300- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
310+ if (inputZpVal < intMin || inputZpVal > intMax)
301311 return rewriter.notifyMatchFailure (
302312 op, " tosa.conv op quantization has zp outside of input range" );
303313
304- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
314+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
305315 }
306316
307317 llvm::SmallVector<int64_t > pad;
@@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
314324 // For 2D convolutions, we need to check if the target convolution op
315325 // wants a HWCF kernel layout.
316326 bool wantHwcf =
317- maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
318- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
327+ hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
328+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
319329 if (wantHwcf) {
320330 // Transpose the kernel to match dimension ordering of the linalg
321331 // convolution operation.
@@ -376,9 +386,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
376386 Value broadcastBias =
377387 linalgBroadcastAndMaybeExtSI (rewriter, loc, bias, biasEmptyTensor);
378388
379- if (maybeZps ) {
380- auto iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
381- auto kZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
389+ if (hasZp ) {
390+ auto iZp = rewriter.getI32IntegerAttr (inputZpVal );
391+ auto kZp = rewriter.getI32IntegerAttr (weightZpVal );
382392
383393 auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
384394 auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
@@ -441,31 +451,40 @@ class DepthwiseConvConverter
441451 /* inputSizeDims=*/ {1 , 2 },
442452 /* kernelSizeDims=*/ {0 , 1 }, rewriter);
443453
444- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
445- if (llvm::failed (failureOrMaybeZps))
446- return failure ();
454+ // Get and verify zero points.
455+ int64_t inputZpVal;
456+ int64_t weightZpVal;
457+
458+ if (op.getInputZeroPoint (inputZpVal).failed () ||
459+ op.getWeightZeroPoint (weightZpVal).failed ())
460+ return rewriter.notifyMatchFailure (
461+ op, " bail out if zero points cannot statically be determined" );
447462
448- auto maybeZps = failureOrMaybeZps.value ();
463+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
464+ op.verifyWeightZeroPoint (weightZpVal).failed ())
465+ return rewriter.notifyMatchFailure (
466+ op, " zero point must be zero for non-int8 integer types" );
449467
468+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
450469 auto weightShape = weightTy.getShape ();
451470 auto resultShape = resultTy.getShape ();
452471
453472 // Apply padding as necessary.
454473 TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
455- if (maybeZps ) {
474+ if (hasZp ) {
456475 int64_t intMin =
457476 APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
458477 .getSExtValue ();
459478 int64_t intMax =
460479 APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
461480 .getSExtValue ();
462481
463- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
482+ if (inputZpVal < intMin || inputZpVal > intMax)
464483 return rewriter.notifyMatchFailure (
465484 op, " tosa.depthwise_conv op quantization has zp outside of input "
466485 " range" );
467486
468- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
487+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
469488 }
470489
471490 llvm::SmallVector<int64_t > pad;
@@ -505,7 +524,7 @@ class DepthwiseConvConverter
505524 indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
506525 indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
507526
508- if (!maybeZps ) {
527+ if (!hasZp ) {
509528 Value conv = rewriter
510529 .create <linalg::DepthwiseConv2DNhwcHwcmOp>(
511530 loc, linalgConvTy, ValueRange{input, weight},
@@ -532,8 +551,8 @@ class DepthwiseConvConverter
532551 .getResult (0 );
533552 rewriter.replaceOp (op, result);
534553 } else {
535- IntegerAttr iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
536- IntegerAttr wZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
554+ IntegerAttr iZp = rewriter.getI32IntegerAttr (inputZpVal );
555+ IntegerAttr wZp = rewriter.getI32IntegerAttr (weightZpVal );
537556 auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
538557 auto kZpVal = rewriter.create <arith::ConstantOp>(loc, wZp);
539558 Value conv =
0 commit comments