@@ -259,11 +259,21 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
259259 DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr ();
260260 DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr ();
261261
262- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
263- if (llvm::failed (failureOrMaybeZps))
264- return failure ();
262+ // Get and verify zero points.
263+ int64_t inputZpVal;
264+ int64_t weightZpVal;
265+
266+ if (op.getInputZeroPoint (inputZpVal).failed () ||
267+ op.getWeightZeroPoint (weightZpVal).failed ())
268+ return rewriter.notifyMatchFailure (
269+ op, " bail out if zero points cannot statically be determined" );
270+
271+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
272+ op.verifyWeightZeroPoint (weightZpVal).failed ())
273+ return rewriter.notifyMatchFailure (
274+ op, " zero point must be zero for non-int8 integer types" );
265275
266- auto maybeZps = failureOrMaybeZps. value ( );
276+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
267277
268278 if (!weightTy.hasStaticShape () || !biasTy.hasStaticShape ())
269279 return rewriter.notifyMatchFailure (
@@ -289,19 +299,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
289299
290300 // Apply padding as necessary.
291301 TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
292- if (maybeZps ) {
302+ if (hasZp ) {
293303 int64_t intMin =
294304 APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
295305 .getSExtValue ();
296306 int64_t intMax =
297307 APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
298308 .getSExtValue ();
299309
300- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
310+ if (inputZpVal < intMin || inputZpVal > intMax)
301311 return rewriter.notifyMatchFailure (
302312 op, " tosa.conv op quantization has zp outside of input range" );
303313
304- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
314+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
305315 }
306316
307317 llvm::SmallVector<int64_t > pad;
@@ -314,8 +324,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
314324 // For 2D convolutions, we need to check if the target convolution op
315325 // wants a HWCF kernel layout.
316326 bool wantHwcf =
317- maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
318- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
327+ hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
328+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
319329 if (wantHwcf) {
320330 // Transpose the kernel to match dimension ordering of the linalg
321331 // convolution operation.
@@ -372,9 +382,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
372382 Value broadcastBias =
373383 linalgBroadcastAndMaybeExtSI (rewriter, loc, bias, biasEmptyTensor);
374384
375- if (maybeZps ) {
376- auto iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
377- auto kZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
385+ if (hasZp ) {
386+ auto iZp = rewriter.getI32IntegerAttr (inputZpVal );
387+ auto kZp = rewriter.getI32IntegerAttr (weightZpVal );
378388
379389 auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
380390 auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
@@ -437,31 +447,40 @@ class DepthwiseConvConverter
437447 /* inputSizeDims=*/ {1 , 2 },
438448 /* kernelSizeDims=*/ {0 , 1 }, rewriter);
439449
440- auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
441- if (llvm::failed (failureOrMaybeZps))
442- return failure ();
450+ // Get and verify zero points.
451+ int64_t inputZpVal;
452+ int64_t weightZpVal;
453+
454+ if (op.getInputZeroPoint (inputZpVal).failed () ||
455+ op.getWeightZeroPoint (weightZpVal).failed ())
456+ return rewriter.notifyMatchFailure (
457+ op, " bail out if zero points cannot statically be determined" );
443458
444- auto maybeZps = failureOrMaybeZps.value ();
459+ if (op.verifyInputZeroPoint (inputZpVal).failed () ||
460+ op.verifyWeightZeroPoint (weightZpVal).failed ())
461+ return rewriter.notifyMatchFailure (
462+ op, " zero point must be zero for non-int8 integer types" );
445463
464+ bool hasZp = (inputZpVal != 0 ) || (weightZpVal != 0 );
446465 auto weightShape = weightTy.getShape ();
447466 auto resultShape = resultTy.getShape ();
448467
449468 // Apply padding as necessary.
450469 TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
451- if (maybeZps ) {
470+ if (hasZp ) {
452471 int64_t intMin =
453472 APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
454473 .getSExtValue ();
455474 int64_t intMax =
456475 APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
457476 .getSExtValue ();
458477
459- if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
478+ if (inputZpVal < intMin || inputZpVal > intMax)
460479 return rewriter.notifyMatchFailure (
461480 op, " tosa.depthwise_conv op quantization has zp outside of input "
462481 " range" );
463482
464- zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
483+ zeroAttr = rewriter.getIntegerAttr (inputETy, inputZpVal );
465484 }
466485
467486 llvm::SmallVector<int64_t > pad;
@@ -501,7 +520,7 @@ class DepthwiseConvConverter
501520 indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
502521 indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
503522
504- if (!maybeZps ) {
523+ if (!hasZp ) {
505524 Value conv = rewriter
506525 .create <linalg::DepthwiseConv2DNhwcHwcmOp>(
507526 loc, linalgConvTy, ValueRange{input, weight},
@@ -528,8 +547,8 @@ class DepthwiseConvConverter
528547 .getResult (0 );
529548 rewriter.replaceOp (op, result);
530549 } else {
531- IntegerAttr iZp = rewriter.getI32IntegerAttr (maybeZps-> inputZp );
532- IntegerAttr wZp = rewriter.getI32IntegerAttr (maybeZps-> weightZp );
550+ IntegerAttr iZp = rewriter.getI32IntegerAttr (inputZpVal );
551+ IntegerAttr wZp = rewriter.getI32IntegerAttr (weightZpVal );
533552 auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
534553 auto kZpVal = rewriter.create <arith::ConstantOp>(loc, wZp);
535554 Value conv =
0 commit comments