@@ -48,6 +48,26 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
4848 return failure ();
4949 }
5050
51+ Type inputETy = inputType.getElementType ();
52+ Type weightETy = weightType.getElementType ();
53+ Type resultETy = resultType.getElementType ();
54+ if (!inputETy.isIntOrFloat () || !weightETy.isIntOrFloat ())
55+ return rewriter.notifyMatchFailure (op, " unsupported type" );
56+
57+ // Get and verify zero points.
58+ int64_t iZp;
59+ int64_t wZp;
60+
61+ if (op.getInputZeroPoint (iZp).failed () ||
62+ op.getWeightZeroPoint (wZp).failed ())
63+ return rewriter.notifyMatchFailure (
64+ op, " bail out if zero points cannot statically be determined" );
65+
66+ if (op.verifyInputZeroPoint (iZp).failed () ||
67+ op.verifyWeightZeroPoint (wZp).failed ())
68+ return rewriter.notifyMatchFailure (
69+ op, " zero point must be zero for non-int8 integer types" );
70+
5171 // Reshape input to [N, H, W, C] -> [N, H, W, C, 1].
5272 ArrayRef<int64_t > inputShape = inputType.getShape ();
5373 llvm::SmallVector<int64_t , 2 > revisedInputShape{
@@ -62,10 +82,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
6282 revisedInputShapeValue)
6383 .getResult ();
6484
65- Type inputETy = inputType.getElementType ();
66- Type weightETy = weightType.getElementType ();
67- Type resultETy = resultType.getElementType ();
68-
6985 if (inputETy != resultETy) {
7086 inputType = inputType.clone (resultETy);
7187 input = rewriter.create <tosa::CastOp>(op.getLoc (), inputType, input);
@@ -76,20 +92,6 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
7692 weight = rewriter.create <tosa::CastOp>(op.getLoc (), weightType, weight);
7793 }
7894
79- // Get and verify zero points.
80- int64_t iZp;
81- int64_t wZp;
82-
83- if (op.getInputZeroPoint (iZp).failed () ||
84- op.getWeightZeroPoint (wZp).failed ())
85- return rewriter.notifyMatchFailure (
86- op, " bail out if zero points cannot statically be determined" );
87-
88- if (op.verifyInputZeroPoint (iZp).failed () ||
89- op.verifyWeightZeroPoint (wZp).failed ())
90- return rewriter.notifyMatchFailure (
91- op, " zero point must be zero for non-int8 integer types" );
92-
9395 if (iZp != 0 || wZp != 0 ) {
9496
9597 auto applyZp = [&](Value val, int64_t zp) -> Value {
0 commit comments