Skip to content

Commit 23803bf

Browse files
committed
feat: pow to mul canonicalization with quantized exponent
1 parent b205103 commit 23803bf

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,8 +596,8 @@ bool isDenseONNXConstant(Value result) {
596596
template <typename RESULT_TYPE>
597597
RESULT_TYPE getScalarValue(ElementsAttr denseAttr, Type type) {
598598
Type elementaryType = getElementTypeOrSelf(type);
599-
if (elementaryType.isInteger(16) || elementaryType.isInteger(32) ||
600-
elementaryType.isInteger(64)) {
599+
if (elementaryType.isInteger(8) || elementaryType.isInteger(16) ||
600+
elementaryType.isInteger(32) || elementaryType.isInteger(64)) {
601601
auto valueIt = denseAttr.getValues<IntegerAttr>().begin();
602602
return static_cast<RESULT_TYPE>(mlir::cast<IntegerAttr>(*valueIt).getInt());
603603
} else if (mlir::isa<FloatType>(elementaryType)) {
@@ -794,6 +794,30 @@ IgnoreDiagnostic::~IgnoreDiagnostic() {
794794

795795
bool hasIntegerPowerExponent(ONNXPowOp *op, int64_t &exponentValue) {
796796
Value exponent = op->getY();
797+
// In case of QDQ quantized models: If exponent is from a DequantizeLinear op,
798+
// we want to check the dequantized value of the exponent
799+
if (auto dequantizeOp = mlir::dyn_cast_or_null<ONNXDequantizeLinearOp>(
800+
exponent.getDefiningOp())) {
801+
ElementsAttr xAttr = getElementAttributeFromONNXValue(dequantizeOp.getX());
802+
ElementsAttr scaleAttr =
803+
getElementAttributeFromONNXValue(dequantizeOp.getXScale());
804+
ElementsAttr zeroPointAttr =
805+
getElementAttributeFromONNXValue(dequantizeOp.getXZeroPoint());
806+
auto x = getScalarValue<double>(xAttr, xAttr.getElementType());
807+
auto scale = getScalarValue<double>(scaleAttr, scaleAttr.getElementType());
808+
auto zeroPoint =
809+
getScalarValue<double>(zeroPointAttr, zeroPointAttr.getElementType());
810+
811+
// Calculate dequantized value for exponent
812+
double dequantizedExponent = (x - zeroPoint) * scale;
813+
814+
if (dequantizedExponent == ceil(dequantizedExponent)) {
815+
exponentValue = static_cast<int64_t>(dequantizedExponent);
816+
return true;
817+
}
818+
return false;
819+
}
820+
797821
ElementsAttr elementAttr = getElementAttributeFromONNXValue(exponent);
798822
if (!elementAttr)
799823
return false;

0 commit comments

Comments
 (0)