Skip to content

Commit 3ca8a49

Browse files
[QNN-EP] Replace Upsample with Resize during Quantization (microsoft#24896)
### Description Replace the Upsample with Resize during quantization to avoid causing the invalid graph ### Motivation and Context After the quantization, if the opset of original onnx model is less than 10, the opset of QDQ model will be upgraded to 11. However, Upsample is deprecated in opset 11, which will make the onnx model invalid. So, we replace the Upsample with Resize if the opset needs to be upgraded to 11. --------- Co-authored-by: chuteng <[email protected]>
1 parent b34ae7c commit 3ca8a49

File tree

3 files changed

+115
-1
lines changed

3 files changed

+115
-1
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .fusion import Fusion # noqa: F401
22
from .fusion_gelu import FusionGelu # noqa: F401
33
from .fusion_layernorm import FusionLayerNormalization # noqa: F401
4+
from .replace_upsample_with_resize import ReplaceUpsampleWithResize # noqa: F401
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
from __future__ import annotations
7+
8+
import numpy as np
9+
import onnx
10+
11+
from ..onnx_model import ONNXModel
12+
from .fusion import Fusion
13+
14+
15+
class ReplaceUpsampleWithResize(Fusion):
16+
"""Replace Upsample with Resize."""
17+
18+
def __init__(self, model: ONNXModel, opset):
19+
"""Initialize."""
20+
super().__init__(model, "Resize", "Upsample")
21+
self.opset = opset
22+
23+
def fuse(
24+
self,
25+
node: onnx.NodeProto,
26+
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
27+
output_name_to_node: dict[str, onnx.NodeProto],
28+
):
29+
"""Replace Upsample with Resize."""
30+
mode = None
31+
for attr in node.attribute:
32+
if attr.name == "mode":
33+
mode = attr.s.decode("utf-8")
34+
break
35+
36+
scales_input = None
37+
if self.opset > 7:
38+
scales_input = node.input[1] if len(node.input) > 1 else ""
39+
resize_inputs = [node.input[0], node.name + "_roi", scales_input]
40+
else:
41+
if self.opset == 7:
42+
for attr in node.attribute:
43+
if attr.name == "scales":
44+
scales_input = attr.floats
45+
break
46+
47+
scales_input = np.array(list(scales_input), np.float32)
48+
else:
49+
h_scale = 1
50+
w_scale = 1
51+
for attr in node.attribute:
52+
if attr.name == "height_scale":
53+
h_scale = attr.float
54+
elif attr.name == "width_scale":
55+
w_scale = attr.float
56+
57+
scales_input = np.array([1, 1, h_scale, w_scale], np.float32)
58+
59+
scales_tensor = onnx.helper.make_tensor(
60+
name=node.name + "_scales",
61+
data_type=onnx.TensorProto.FLOAT,
62+
dims=scales_input.shape,
63+
vals=scales_input.flatten().tolist(),
64+
)
65+
66+
scales_node = onnx.helper.make_node(
67+
"Constant", inputs=[], outputs=[node.name + "_scales"], value=scales_tensor
68+
)
69+
70+
self.nodes_to_add.append(scales_node)
71+
72+
resize_inputs = [node.input[0], node.name + "_roi", node.name + "_scales"]
73+
74+
roi_tensor = onnx.helper.make_tensor(
75+
name=node.name + "_roi",
76+
data_type=onnx.TensorProto.FLOAT,
77+
dims=(len(scales_input) * 2,),
78+
vals=[0] * len(scales_input) + [1] * len(scales_input),
79+
)
80+
81+
roi_node = onnx.helper.make_node("Constant", inputs=[], outputs=[node.name + "_roi"], value=roi_tensor)
82+
83+
resize_node = onnx.helper.make_node(
84+
op_type="Resize", inputs=resize_inputs, outputs=node.output, mode=mode, nearest_mode="floor"
85+
)
86+
87+
self.nodes_to_remove.append(node)
88+
self.nodes_to_add.append(roi_node)
89+
self.nodes_to_add.append(resize_node)
90+
91+
def apply(self) -> bool:
92+
"""Apply."""
93+
if super().apply():
94+
self.model.topological_sort()
95+
return True
96+
return False

onnxruntime/python/tools/quantization/shape_inference.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
1717
from onnxruntime.transformers.onnx_utils import extract_raw_data_from_model, has_external_data
1818

19-
from .quant_utils import add_pre_process_metadata
19+
from .fusions import ReplaceUpsampleWithResize
20+
from .onnx_model import ONNXModel
21+
from .quant_utils import add_pre_process_metadata, save_and_reload_model_with_shape_infer
2022

2123
logger = logging.getLogger(__name__)
2224

@@ -85,6 +87,21 @@ def quant_pre_process(
8587
verbose,
8688
)
8789

90+
# Since Upsample is deprecated after opset v10, and the model's opset will
91+
# be upgraded to at least v11 during quantization, we need to replace Upsample
92+
# with Resize first to avoid generating an invalid model.
93+
if model:
94+
ai_onnx_domain = [opset for opset in model.opset_import if not opset.domain or opset.domain == "ai.onnx"]
95+
if len(ai_onnx_domain) == 1:
96+
opset_version = ai_onnx_domain[0].version
97+
if opset_version < 10:
98+
ReplaceUpsampleWithResize(ONNXModel(model), opset_version).apply()
99+
model.opset_import.remove(ai_onnx_domain[0])
100+
opset_version = 11
101+
model.opset_import.extend([onnx.helper.make_opsetid("", opset_version)])
102+
model = onnx.version_converter.convert_version(model, opset_version)
103+
model = save_and_reload_model_with_shape_infer(model)
104+
88105
if not skip_optimization:
89106
# Use ORT optimizers (native code) to optimize model
90107
if not skip_symbolic_shape:

0 commit comments

Comments
 (0)