@@ -596,8 +596,8 @@ bool isDenseONNXConstant(Value result) {
596596template <typename RESULT_TYPE>
597597RESULT_TYPE getScalarValue (ElementsAttr denseAttr, Type type) {
598598 Type elementaryType = getElementTypeOrSelf (type);
599- if (elementaryType.isInteger (16 ) || elementaryType.isInteger (32 ) ||
600- elementaryType.isInteger (64 )) {
599+ if (elementaryType.isInteger (8 ) || elementaryType.isInteger (16 ) ||
600+ elementaryType.isInteger (32 ) || elementaryType. isInteger ( 64 )) {
601601 auto valueIt = denseAttr.getValues <IntegerAttr>().begin ();
602602 return static_cast <RESULT_TYPE>(mlir::cast<IntegerAttr>(*valueIt).getInt ());
603603 } else if (mlir::isa<FloatType>(elementaryType)) {
@@ -794,6 +794,38 @@ IgnoreDiagnostic::~IgnoreDiagnostic() {
794794
795795bool hasIntegerPowerExponent (ONNXPowOp *op, int64_t &exponentValue) {
796796 Value exponent = op->getY ();
797+ // In case of QDQ quantized models: If exponent is from a DequantizeLinear op,
798+ // we want to check the dequantized value of the exponent
799+ if (auto dequantizeOp = mlir::dyn_cast_or_null<ONNXDequantizeLinearOp>(
800+ exponent.getDefiningOp ())) {
801+ ElementsAttr xAttr = getElementAttributeFromONNXValue (dequantizeOp.getX ());
802+ ElementsAttr scaleAttr =
803+ getElementAttributeFromONNXValue (dequantizeOp.getXScale ());
804+ ElementsAttr zeroPointAttr =
805+ getElementAttributeFromONNXValue (dequantizeOp.getXZeroPoint ());
806+
807+ if (!(isScalarConstantTensor (dequantizeOp.getXScale ()) &&
808+ isScalarConstantTensor (dequantizeOp.getXZeroPoint ())))
809+ return false ;
810+
811+ auto x = getScalarValue<double >(xAttr, xAttr.getElementType ());
812+ auto scale = getScalarValue<double >(scaleAttr, scaleAttr.getElementType ());
813+ auto zeroPoint =
814+ getScalarValue<double >(zeroPointAttr, zeroPointAttr.getElementType ());
815+
816+ // Calculate dequantized value for exponent (This is an approximation and
817+ // isn't expected to match the actual calculation done by the
818+ // DequantizeLinear op. However, it should be good enough for checking that
819+ // the exponent is an integer)
820+ double dequantizedExponent = (x - zeroPoint) * scale;
821+
822+ if (dequantizedExponent == ceil (dequantizedExponent)) {
823+ exponentValue = static_cast <int64_t >(dequantizedExponent);
824+ return true ;
825+ }
826+ return false ;
827+ }
828+
797829 ElementsAttr elementAttr = getElementAttributeFromONNXValue (exponent);
798830 if (!elementAttr)
799831 return false ;
0 commit comments