Commit f95475f

Use QUInt8 when quantizing models produced by onnxruntime-genai
1 parent 01dd2d1 · commit f95475f

File tree

1 file changed: +24, -12 lines changed


scripts/quantize.py

Lines changed: 24 additions & 12 deletions
@@ -36,6 +36,27 @@ class QuantMode(Enum):
 
 QUANTIZE_OPTIONS = tuple(x.value for x in QuantMode)
 
+# A list of operators that, when detected in a model, should select QUInt8 as the weight type for 8-bit quantization.
+QUINT8_OPS = (
+    # NOTE:
+    # As of 2024/11/29, the latest version of onnxruntime-web is 1.20.1, and does not support INT8 weights for Conv layers.
+    # If you attempt to run a model with INT8 weights for Conv layers, you will get an error like:
+    # `Can't create a session. ERROR_CODE: 9, ERROR_MESSAGE: Could not find an implementation for ConvInteger(10) node with name '/.../Conv_quant'`
+    #
+    # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web.
+    #
+    # As per the docs, the signed weight type (QInt8) is faster on most CPUs, so we use it unless the model contains one of the operators below.
+    # For more information, see:
+    # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
+    # - https://github.com/microsoft/onnxruntime/issues/2339
+    "Conv",
+
+    # Models produced by onnxruntime-genai contain optimized operators that perform better with QUInt8 weights.
+    "GroupQueryAttention",
+    "MultiHeadAttention",
+
+    # TODO: "SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization"
+)
 
 @dataclass
 class IOArguments:
@@ -326,20 +347,11 @@ def quantize(input_folder, output_folder, quantization_args: QuantizationArgumen
 
     elif mode in (QuantMode.Q8, QuantMode.QI8, QuantMode.QU8):
         if mode == QuantMode.Q8:
-            # NOTE:
-            # As of 2024/06/28, the current latest version of onnxruntime-web is 1.18.0, and does not support INT8 weights for Conv layers.
-            # If you attempt to run a model with INT8 weights for Conv layers, you will get an error like:
-            # `Can't create a session. ERROR_CODE: 9, ERROR_MESSAGE: Could not find an implementation for ConvInteger(10) node with name '/.../Conv_quant'`
-            #
-            # For this reason, we choose model weight types to ensure compatibility with onnxruntime-web.
-            #
-            # As per docs, signed weight type (QInt8) is faster on most CPUs, so, we use that unless the model contains a Conv layer.
-            # For more information, see:
-            # - https://github.com/microsoft/onnxruntime/issues/3130#issuecomment-1105200621
-            # - https://github.com/microsoft/onnxruntime/issues/2339
             op_types = get_operators(model)
             weight_type = (
-                QuantType.QUInt8 if "Conv" in op_types else QuantType.QInt8
+                QuantType.QUInt8
+                if any(x in QUINT8_OPS for x in op_types)
+                else QuantType.QInt8
             )
 
         elif mode == QuantMode.QI8:
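
For context, below is a minimal, hypothetical sketch of how an operator scan like this typically feeds onnxruntime's dynamic quantizer. It is not the actual scripts/quantize.py: the get_operators helper, the file paths, and the call to quantize_dynamic are assumptions for illustration (the real script may, for example, also traverse subgraphs and pass extra options).

# Hedged sketch (assumed wiring, not the repository's implementation):
# pick QUInt8 weights when any QUINT8_OPS operator appears in the graph,
# otherwise keep the generally faster signed QInt8 weights.
import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic

QUINT8_OPS = ("Conv", "GroupQueryAttention", "MultiHeadAttention")

def get_operators(model: onnx.ModelProto) -> set[str]:
    # Collect every operator type in the top-level graph (subgraphs omitted for brevity).
    return {node.op_type for node in model.graph.node}

model = onnx.load("model.onnx")
op_types = get_operators(model)
weight_type = (
    QuantType.QUInt8
    if any(x in QUINT8_OPS for x in op_types)
    else QuantType.QInt8
)
quantize_dynamic(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    weight_type=weight_type,
)

With this wiring, graphs containing Conv, GroupQueryAttention, or MultiHeadAttention (as emitted by onnxruntime-genai) get unsigned 8-bit weights for onnxruntime-web compatibility, while all other models keep the signed QInt8 path that the docs recommend for CPU speed.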
