Skip to content

Commit f78215a

Browse files
authored
Fix quantization tools for issue #19529 (#19591)
### Description Fix issue #19529, the code was using a variable loop outside a loop.
1 parent a46bab6 commit f78215a

File tree

3 files changed

+77
-7
lines changed

3 files changed

+77
-7
lines changed

onnxruntime/python/tools/quantization/calibrate.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -733,13 +733,11 @@ def collect_absolute_value(self, name_to_arr):
733733
for tensor, data_arr in name_to_arr.items():
734734
if isinstance(data_arr, list):
735735
for arr in data_arr:
736-
if not isinstance(arr, np.ndarray):
737-
raise ValueError(f"Unexpected type {type(arr)} for tensor={tensor!r}")
738-
dtypes = set(a.dtype for a in arr)
739-
if len(dtypes) != 1:
740-
raise ValueError(
741-
f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
742-
)
736+
assert isinstance(arr, np.ndarray), f"Unexpected type {type(arr)} for tensor={tensor!r}"
737+
dtypes = set(a.dtype for a in data_arr)
738+
assert (
739+
len(dtypes) == 1
740+
), f"The calibration expects only one element type but got {dtypes} for tensor={tensor!r}"
743741
data_arr_np = np.asarray(data_arr)
744742
elif not isinstance(data_arr, np.ndarray):
745743
raise ValueError(f"Unexpected type {type(data_arr)} for tensor={tensor!r}")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# -------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for
4+
# license information.
5+
# --------------------------------------------------------------------------
6+
import os
7+
import tempfile
8+
import unittest
9+
import warnings
10+
11+
12+
def ignore_warnings(warns):
13+
"""
14+
Catches warnings.
15+
16+
:param warns: warnings to ignore
17+
"""
18+
19+
def wrapper(fct):
20+
if warns is None:
21+
raise AssertionError(f"warns cannot be None for '{fct}'.")
22+
23+
def call_f(self):
24+
with warnings.catch_warnings():
25+
warnings.simplefilter("ignore", warns)
26+
return fct(self)
27+
28+
return call_f
29+
30+
return wrapper
31+
32+
33+
class TestQuantIssues(unittest.TestCase):
34+
@ignore_warnings(DeprecationWarning)
35+
def test_minimal_model(self):
36+
folder = os.path.join(os.path.dirname(__file__), "..", "..", "testdata")
37+
onnx_path = os.path.join(folder, "qdq_minimal_model.onnx")
38+
if not os.path.exists(onnx_path):
39+
# The file does seem to be the same location in every CI job.
40+
raise unittest.SkipTest("unable to find {onnx_path!r}")
41+
42+
import numpy as np
43+
44+
import onnxruntime.quantization as oq
45+
46+
class Mock:
47+
def __init__(self):
48+
self.i = 0
49+
50+
def get_next(self):
51+
if self.i > 10:
52+
return None
53+
self.i += 1
54+
return {"input": np.random.randint(0, 255, size=(1, 3, 32, 32), dtype=np.uint8)}
55+
56+
with tempfile.TemporaryDirectory() as temp:
57+
preprocessed_path = os.path.join(temp, "preprocessed.onnx")
58+
quantized_path = os.path.join(temp, "quantized.onnx")
59+
oq.quant_pre_process(onnx_path, preprocessed_path, skip_symbolic_shape=True)
60+
oq.quantize_static(
61+
preprocessed_path,
62+
quantized_path,
63+
Mock(),
64+
calibrate_method=oq.CalibrationMethod.Percentile,
65+
op_types_to_quantize=["Conv", "Mul", "Gemm"],
66+
)
67+
assert os.path.exists(preprocessed_path), f"missing output {preprocessed_path!r}"
68+
assert os.path.exists(quantized_path), f"missing output {quantized_path!r}"
69+
70+
71+
if __name__ == "__main__":
72+
unittest.main(verbosity=2)
6.22 KB
Binary file not shown.

0 commit comments

Comments
 (0)