@@ -258,7 +258,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
258258 DenseI64ArrayAttr padAttr = op.getPadAttr ();
259259 DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr ();
260260 DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr ();
261- bool isQuantized = op.getQuantizationInfo ().has_value ();
261+
262+ auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
263+ if (llvm::failed (failureOrMaybeZps))
264+ return failure ();
265+
266+ auto maybeZps = failureOrMaybeZps.value ();
262267
263268 if (!weightTy.hasStaticShape () || !biasTy.hasStaticShape ())
264269 return rewriter.notifyMatchFailure (
@@ -284,22 +289,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
284289
285290 // Apply padding as necessary.
286291 TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
287- if (isQuantized) {
288- auto quantizationInfo = *op.getQuantizationInfo ();
289- int64_t iZp = quantizationInfo.getInputZp ();
290-
292+ if (maybeZps) {
291293 int64_t intMin =
292294 APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
293295 .getSExtValue ();
294296 int64_t intMax =
295297 APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
296298 .getSExtValue ();
297299
298- if (iZp < intMin || iZp > intMax)
300+ if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
299301 return rewriter.notifyMatchFailure (
300302 op, " tosa.conv op quantization has zp outside of input range" );
301303
302- zeroAttr = rewriter.getIntegerAttr (inputETy, iZp );
304+ zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
303305 }
304306
305307 llvm::SmallVector<int64_t > pad;
@@ -312,8 +314,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
312314 // For 2D convolutions, we need to check if the target convolution op
313315 // wants a HWCF kernel layout.
314316 bool wantHwcf =
315- isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
316- : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317+ maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
318+ : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
317319 if (wantHwcf) {
318320 // Transpose the kernel to match dimension ordering of the linalg
319321 // convolution operation.
@@ -374,10 +376,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
374376 Value broadcastBias =
375377 linalgBroadcastAndMaybeExtSI (rewriter, loc, bias, biasEmptyTensor);
376378
377- if (isQuantized) {
378- auto quantizationInfo = *op.getQuantizationInfo ();
379- auto iZp = rewriter.getI32IntegerAttr (quantizationInfo.getInputZp ());
380- auto kZp = rewriter.getI32IntegerAttr (quantizationInfo.getWeightZp ());
379+ if (maybeZps) {
380+ auto iZp = rewriter.getI32IntegerAttr (maybeZps->inputZp );
381+ auto kZp = rewriter.getI32IntegerAttr (maybeZps->weightZp );
381382
382383 auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
383384 auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
@@ -440,39 +441,31 @@ class DepthwiseConvConverter
440441 /* inputSizeDims=*/ {1 , 2 },
441442 /* kernelSizeDims=*/ {0 , 1 }, rewriter);
442443
443- bool isQuantized = op->hasAttr (" quantization_info" );
444- IntegerAttr iZp;
445- IntegerAttr kZp ;
446- if (isQuantized) {
447- auto quantizationInfo =
448- cast<tosa::ConvOpQuantizationAttr>(op->getAttr (" quantization_info" ));
449- iZp = rewriter.getI32IntegerAttr (quantizationInfo.getInputZp ());
450- kZp = rewriter.getI32IntegerAttr (quantizationInfo.getWeightZp ());
451- }
444+ auto failureOrMaybeZps = extractConvZpPair (op, rewriter);
445+ if (llvm::failed (failureOrMaybeZps))
446+ return failure ();
447+
448+ auto maybeZps = failureOrMaybeZps.value ();
452449
453450 auto weightShape = weightTy.getShape ();
454451 auto resultShape = resultTy.getShape ();
455452
456453 // Apply padding as necessary.
457454 TypedAttr zeroAttr = rewriter.getZeroAttr (inputETy);
458- if (isQuantized) {
459- auto quantizationInfo =
460- cast<tosa::ConvOpQuantizationAttr>(op->getAttr (" quantization_info" ));
461- int64_t iZp = quantizationInfo.getInputZp ();
462-
455+ if (maybeZps) {
463456 int64_t intMin =
464457 APInt::getSignedMinValue (inputETy.getIntOrFloatBitWidth ())
465458 .getSExtValue ();
466459 int64_t intMax =
467460 APInt::getSignedMaxValue (inputETy.getIntOrFloatBitWidth ())
468461 .getSExtValue ();
469462
470- if (iZp < intMin || iZp > intMax)
463+ if (maybeZps-> inputZp < intMin || maybeZps-> inputZp > intMax)
471464 return rewriter.notifyMatchFailure (
472465 op, " tosa.depthwise_conv op quantization has zp outside of input "
473466 " range" );
474467
475- zeroAttr = rewriter.getIntegerAttr (inputETy, iZp );
468+ zeroAttr = rewriter.getIntegerAttr (inputETy, maybeZps-> inputZp );
476469 }
477470
478471 llvm::SmallVector<int64_t > pad;
@@ -512,7 +505,7 @@ class DepthwiseConvConverter
512505 indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
513506 indexingMaps.push_back (rewriter.getMultiDimIdentityMap (resultRank));
514507
515- if (!isQuantized ) {
508+ if (!maybeZps ) {
516509 Value conv = rewriter
517510 .create <linalg::DepthwiseConv2DNhwcHwcmOp>(
518511 loc, linalgConvTy, ValueRange{input, weight},
@@ -539,8 +532,10 @@ class DepthwiseConvConverter
539532 .getResult (0 );
540533 rewriter.replaceOp (op, result);
541534 } else {
535+ IntegerAttr iZp = rewriter.getI32IntegerAttr (maybeZps->inputZp );
536+ IntegerAttr wZp = rewriter.getI32IntegerAttr (maybeZps->weightZp );
542537 auto iZpVal = rewriter.create <arith::ConstantOp>(loc, iZp);
543- auto kZpVal = rewriter.create <arith::ConstantOp>(loc, kZp );
538+ auto kZpVal = rewriter.create <arith::ConstantOp>(loc, wZp );
544539 Value conv =
545540 rewriter
546541 .create <linalg::DepthwiseConv2DNhwcHwcmQOp>(
0 commit comments