Skip to content

Commit 04fa90e

Browse files
authored
Merge pull request #424 from Xilinx/kosh.pow.to.mul.canonicalization.for.qdq
Enable Pow to Mul canonicalization for quantized exponents
2 parents b205103 + 3572831 commit 04fa90e

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

src/Dialect/ONNX/ONNXOps/OpHelper.cpp

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -596,8 +596,8 @@ bool isDenseONNXConstant(Value result) {
596596
template <typename RESULT_TYPE>
597597
RESULT_TYPE getScalarValue(ElementsAttr denseAttr, Type type) {
598598
Type elementaryType = getElementTypeOrSelf(type);
599-
if (elementaryType.isInteger(16) || elementaryType.isInteger(32) ||
600-
elementaryType.isInteger(64)) {
599+
if (elementaryType.isInteger(8) || elementaryType.isInteger(16) ||
600+
elementaryType.isInteger(32) || elementaryType.isInteger(64)) {
601601
auto valueIt = denseAttr.getValues<IntegerAttr>().begin();
602602
return static_cast<RESULT_TYPE>(mlir::cast<IntegerAttr>(*valueIt).getInt());
603603
} else if (mlir::isa<FloatType>(elementaryType)) {
@@ -794,6 +794,38 @@ IgnoreDiagnostic::~IgnoreDiagnostic() {
794794

795795
bool hasIntegerPowerExponent(ONNXPowOp *op, int64_t &exponentValue) {
796796
Value exponent = op->getY();
797+
// In case of QDQ quantized models: If exponent is from a DequantizeLinear op,
798+
// we want to check the dequantized value of the exponent
799+
if (auto dequantizeOp = mlir::dyn_cast_or_null<ONNXDequantizeLinearOp>(
800+
exponent.getDefiningOp())) {
801+
ElementsAttr xAttr = getElementAttributeFromONNXValue(dequantizeOp.getX());
802+
ElementsAttr scaleAttr =
803+
getElementAttributeFromONNXValue(dequantizeOp.getXScale());
804+
ElementsAttr zeroPointAttr =
805+
getElementAttributeFromONNXValue(dequantizeOp.getXZeroPoint());
806+
807+
if (!(isScalarConstantTensor(dequantizeOp.getXScale()) &&
808+
isScalarConstantTensor(dequantizeOp.getXZeroPoint())))
809+
return false;
810+
811+
auto x = getScalarValue<double>(xAttr, xAttr.getElementType());
812+
auto scale = getScalarValue<double>(scaleAttr, scaleAttr.getElementType());
813+
auto zeroPoint =
814+
getScalarValue<double>(zeroPointAttr, zeroPointAttr.getElementType());
815+
816+
// Calculate dequantized value for exponent (This is an approximation and
817+
// isn't expected to match the actual calculation done by the
818+
// DequantizeLinear op. However, it should be good enough for checking that
819+
// the exponent is an integer)
820+
double dequantizedExponent = (x - zeroPoint) * scale;
821+
822+
if (dequantizedExponent == ceil(dequantizedExponent)) {
823+
exponentValue = static_cast<int64_t>(dequantizedExponent);
824+
return true;
825+
}
826+
return false;
827+
}
828+
797829
ElementsAttr elementAttr = getElementAttributeFromONNXValue(exponent);
798830
if (!elementAttr)
799831
return false;

test/mlir/onnx/onnx_canonicalization.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,6 +1387,30 @@ func.func @expand_pow_into_constant(%arg0: tensor<3x4x5xf32>) -> tensor<3x4x5xf3
13871387
// CHECK: onnx.Return [[VAR_0_]] : tensor<3x4x5xf32>
13881388
// CHECK: }
13891389
}
1390+
// -----
1391+
1392+
// Verifies the Pow-to-Mul canonicalization when the Pow exponent comes from a
// DequantizeLinear op (QDQ-quantized model): the quantized exponent 64 with
// scale 3.125e-02 and zero point 0 dequantizes to (64 - 0) * 0.03125 = 2.0
// exactly, so the Pow is rewritten as a self-multiplication (onnx.Mul) and
// the exponent's DequantizeLinear is folded away, as shown by the CHECK lines.
func.func @test_pow_into_mul_with_qdq(%arg0: tensor<1x3x80x80x2xi8>) -> tensor<1x3x80x80x2xi8> {
  %0 = onnx.Constant dense<2.500000e-01> : tensor<f32>
  %1 = onnx.Constant dense<3.125000e-02> : tensor<f32>
  %2 = onnx.Constant dense<64> : tensor<i8>
  %3 = onnx.Constant dense<0> : tensor<i8>
  %6 = "onnx.DequantizeLinear"(%arg0, %0, %3) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x3x80x80x2xi8>, tensor<f32>, tensor<i8>) -> tensor<1x3x80x80x2xf32>
  %7 = "onnx.DequantizeLinear"(%2, %1, %3) {axis = 1 : si64, block_size = 0 : si64} : (tensor<i8>, tensor<f32>, tensor<i8>) -> tensor<f32>
  %8 = "onnx.Pow"(%6, %7) : (tensor<1x3x80x80x2xf32>, tensor<f32>) -> tensor<1x3x80x80x2xf32>
  %9 = "onnx.QuantizeLinear"(%8, %1, %3) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x3x80x80x2xf32>, tensor<f32>, tensor<i8>) -> tensor<1x3x80x80x2xi8>
  return %9 : tensor<1x3x80x80x2xi8>

// CHECK-LABEL: func.func @test_pow_into_mul_with_qdq
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x80x80x2xi8>) -> tensor<1x3x80x80x2xi8> {
// CHECK: [[VAR_0_:%.+]] = onnx.Constant dense<2.500000e-01> : tensor<f32>
// CHECK: [[VAR_1_:%.+]] = onnx.Constant dense<3.125000e-02> : tensor<f32>
// CHECK: [[VAR_2_:%.+]] = onnx.Constant dense<0> : tensor<i8>
// CHECK: [[VAR_3_:%.+]] = "onnx.DequantizeLinear"([[PARAM_0_]], [[VAR_0_]], [[VAR_2_]]) {axis = 1 : si64, block_size = 0 : si64} : (tensor<1x3x80x80x2xi8>, tensor<f32>, tensor<i8>) -> tensor<1x3x80x80x2xf32>
// CHECK: [[VAR_4_:%.+]] = "onnx.Mul"([[VAR_3_]], [[VAR_3_]]) : (tensor<1x3x80x80x2xf32>, tensor<1x3x80x80x2xf32>) -> tensor<1x3x80x80x2xf32>
// CHECK: [[VAR_5_:%.+]] = "onnx.QuantizeLinear"([[VAR_4_]], [[VAR_1_]], [[VAR_2_]]) {axis = 1 : si64, block_size = 0 : si64, output_dtype = 0 : si64, saturate = 1 : si64} : (tensor<1x3x80x80x2xf32>, tensor<f32>, tensor<i8>) -> tensor<1x3x80x80x2xi8>
// CHECK: return [[VAR_5_]] : tensor<1x3x80x80x2xi8>
// CHECK: }
}
13901414

13911415
// -----
13921416

0 commit comments

Comments (0)