
Commit 69c258f

[5590225] Fixed regression introduced by PR #364 (FP64-to-FP32 conversion) (#462)
Signed-off-by: gcunhase <[email protected]>
1 parent 47ddd14 commit 69c258f

17 files changed: +276 −23 lines changed

modelopt/onnx/autocast/graphsanitizer.py

Lines changed: 0 additions & 4 deletions
@@ -92,7 +92,6 @@ def convert_fp64_to_fp32(self) -> None:
 
         if modified:
             logger.info("Converted FP64 initializers, I/O types, and nodes to FP32")
-            self.model = onnx_utils.infer_shapes(self.model, strict_mode=True)
 
     def ensure_custom_ops_precision(self) -> None:
         """Ensure that custom ops run in the requested precision."""
@@ -144,9 +143,6 @@ def remove_disconnected_outputs(self) -> None:
     def convert_opset(self) -> None:
         """Convert the model to the given opset version.
 
-        Args:
-            min_opset: minimum opset version to use
-
         The method checks all opset imports and converts the model if any are below the minimum version.
         """
         # Check all opset imports

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 10 additions & 1 deletion
@@ -139,6 +139,15 @@ def __init__(
         self.max_ir_version = max_ir_version
         self.trt_plugins = trt_plugins
 
+        # Detect additional ops not supported in low precision according to the model's opset version
+        self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + (
+            utils.get_op_types_not_supported_in_low_precision(
+                self.model,
+                self.min_opset,
+                self.low_precision_type.str_full,
+            )
+        )
+
     def convert(
         self,
         high_precision_nodes: list[str],
@@ -446,7 +455,7 @@ def _filter_unsupported_op_types(
         # precision so we need to set Resize and Upsample to high precision
         for node in self.model.graph.node:
             if (
-                node.op_type in OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION
+                node.op_type in self.op_types_not_supported_in_low_precision
                 and node.name in low_precision_nodes
             ):
                 low_precision_nodes.remove(node.name)
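
For illustration only (not part of this commit), a minimal sketch of the filtering pattern the change above enables: the static exclusion list is extended with opset-dependent op types, and any node whose op type appears there stays in high precision. All names and values below are placeholders, not the module's real constants.

# Hypothetical stand-ins for the module's real data; values are illustrative only.
STATIC_EXCLUSIONS = ["NonMaxSuppression"]   # assumed always-excluded example
opset_dependent = ["IsInf"]                  # e.g. detected for models below the supporting opset
not_supported = STATIC_EXCLUSIONS + opset_dependent

# Hypothetical graph: node name -> op type.
node_op_types = {"conv1": "Conv", "isinf1": "IsInf"}
low_precision_nodes = ["conv1", "isinf1"]

# Same pattern as _filter_unsupported_op_types: drop nodes whose op type
# cannot run in low precision at the model's opset.
for name in list(low_precision_nodes):
    if node_op_types[name] in not_supported:
        low_precision_nodes.remove(name)

print(low_precision_nodes)  # ['conv1']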

modelopt/onnx/autocast/utils.py

Lines changed: 63 additions & 0 deletions
@@ -21,8 +21,13 @@
 support the core functionality of model precision conversion.
 """
 
+import logging
+from collections import defaultdict
+
 import onnx
 
+from modelopt.onnx.utils import get_opset_version
+
 
 def setup_mappings(model: onnx.ModelProto) -> tuple[dict, dict, dict]:
     """Setup and return mappings for model components.
@@ -115,3 +120,61 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
         if attr.name == "to":
             return attr.i
     raise ValueError("Cast node does not have 'to' attribute")
+
+
+def get_op_types_not_supported_in_low_precision(
+    model: onnx.ModelProto,
+    min_opset: int,
+    low_precision_type: str = "float16",
+) -> list[str]:
+    """Get a list of ops not supported in low precision for the opset_version = max(model.opset, min_opset).
+
+    An op is considered to be supported if at least one of its inputs may be in low precision.
+    Ops where only some of the inputs may be in low precision are considered supported by this function
+    and may need special handling. See PrecisionConverter::_should_skip_low_precision_input_conversion.
+
+    Args:
+        model: ONNX model.
+        min_opset: Minimum opset version.
+        low_precision_type: Target precision to reduce to ('float16' or 'bfloat16').
+
+    Returns:
+        ops_without_support: List of ops not supported in low precision for the current opset version.
+    """
+    # Obtain the current model's opset version
+    opset_version = max(get_opset_version(model), min_opset)
+
+    # Get all ops precision support information
+    precision = "tensor(float16)" if low_precision_type == "float16" else "tensor(bfloat16)"
+    model_ops = {n.op_type for n in model.graph.node}
+    schemas_dict = defaultdict(dict)
+    for schema in onnx.defs.get_all_schemas_with_history():
+        if schema.name not in model_ops:
+            continue
+        float16_supported = False
+        for constr in schema.type_constraints:
+            if precision in constr.allowed_type_strs:
+                float16_supported = True
+                break
+        schemas_dict[schema.name].update({schema.since_version: float16_supported})
+
+    # Check that all ops are supported in low precision for the current opset version.
+    # Otherwise, exclude from conversion.
+    ops_without_support = {}
+    for op, schema in schemas_dict.items():
+        supported_opsets = [k for k, v in schema.items() if v]
+        if supported_opsets:
+            min_supported_opset = min(supported_opsets)
+            if min_supported_opset > opset_version:
+                ops_without_support[op] = min_supported_opset
+        else:
+            ops_without_support[op] = None
+
+    if ops_without_support:
+        logging.warning(
+            f"{len(ops_without_support)} ops are not supported in '{low_precision_type}' in opset {opset_version}, "
+            f"skipping those from conversion. Upgrade the model's opset version as follows to run them in low "
+            f" precision: {ops_without_support}."
+        )
+
+    return list(ops_without_support.keys())
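​
For context, a minimal usage sketch of the new helper (not part of this commit; the model path is hypothetical). It returns the op types present in the model whose ONNX schema, at the effective opset, has no type constraint allowing the requested low-precision type.

import onnx

from modelopt.onnx.autocast.utils import get_op_types_not_supported_in_low_precision

model = onnx.load("model.onnx")  # hypothetical path, any ONNX model works

# Ops whose schema at max(model opset, min_opset) never allows tensor(float16)
# are reported here and later excluded from low-precision conversion.
excluded = get_op_types_not_supported_in_low_precision(
    model, min_opset=13, low_precision_type="float16"
)
print(excluded)  # e.g. ['IsInf'] for a model below opset 20 that contains IsInf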

modelopt/onnx/quantization/quantize.py

Lines changed: 7 additions & 7 deletions
@@ -67,7 +67,12 @@
     remove_input_dq_and_output_q,
 )
 from modelopt.onnx.trt_utils import interpret_trt_plugins_precision_flag, load_onnx_model
-from modelopt.onnx.utils import duplicate_shared_constants, name_onnx_nodes, save_onnx
+from modelopt.onnx.utils import (
+    duplicate_shared_constants,
+    get_opset_version,
+    name_onnx_nodes,
+    save_onnx,
+)
 
 __all__ = ["quantize"]
 
@@ -113,12 +118,7 @@ def _preprocess_onnx(
     )
 
     # Per-Channel support with QDQ format requires onnx opset version 13 or above
-    ai_onnx_domain = [
-        opset
-        for opset in onnx_model.opset_import
-        if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib"]
-    ]
-    opset_version = ai_onnx_domain[0].version
+    opset_version = get_opset_version(onnx_model)
 
     required_opset_version = 13
     if opset_version < required_opset_version and opset_version != 1:

modelopt/onnx/utils.py

Lines changed: 10 additions & 0 deletions
@@ -686,6 +686,16 @@ def update_domain(onnx_model: onnx.ModelProto, op_type: str, domain: str) -> onn
     return onnx_model
 
 
+def get_opset_version(model: onnx.ModelProto) -> int:
+    """Returns the opset version of the given model."""
+    ai_onnx_domain = [
+        opset
+        for opset in model.opset_import
+        if not opset.domain or opset.domain in ["ai.onnx", "ai.onnx.contrib", "trt.plugins"]
+    ]
+    return ai_onnx_domain[0].version
+
+
 def bfloat16_to_float32(bf16_array):
     """Converts a bfloat16 array (as raw data) to a float32 array."""
     uint32_array = bf16_array.astype(np.uint32) << 16
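​
For illustration (not part of this commit), a small self-contained check of the new helper using a toy model built with onnx.helper; the graph contents are irrelevant, only the opset imports matter.

import onnx
from onnx import helper

from modelopt.onnx.utils import get_opset_version

# Empty toy graph; only the opset imports matter for this helper.
graph = helper.make_graph([], "toy", [], [])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])

# The default ("") domain is treated as ai.onnx, so this returns 17.
assert get_opset_version(model) == 17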

tests/_test_utils/onnx/quantization/lib_test_models.py renamed to tests/_test_utils/onnx/lib_test_models.py

Lines changed: 99 additions & 0 deletions
@@ -822,3 +822,102 @@ def build_conv_act_pool_model(include_reshape_node=False):
     onnx.checker.check_model(model_inferred)
 
     return model_inferred
+
+
+def build_conv_isinf_model(opset_version=13):
+    # Define your model inputs and outputs
+    input_names = ["input_0"]
+    output_names = ["output_0"]
+    input_shapes = [(6, 32, 900, 256)]
+    output_shapes = [(6, 32, 900, 256)]
+
+    inputs = [
+        helper.make_tensor_value_info(input_name, onnx.TensorProto.FLOAT, input_shape)
+        for input_name, input_shape in zip(input_names, input_shapes)
+    ]
+    outputs = [
+        helper.make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, output_shape)
+        for output_name, output_shape in zip(output_names, output_shapes)
+    ]
+
+    # Create the ONNX graph with the nodes
+    nodes = [
+        helper.make_node(
+            op_type="Conv",
+            inputs=["input_0", "weights_1"],
+            outputs=["conv1_conv/Conv2D:0"],
+            name="conv1_conv/Conv2D",
+            dilations=[1, 1],
+            group=1,
+            kernel_shape=[3, 3],
+            pads=[1, 1, 1, 1],
+            strides=[1, 1],
+        ),
+        helper.make_node(
+            op_type="Cast",
+            inputs=["conv1_conv/Conv2D:0"],
+            outputs=["cast1_cast/Cast:0"],
+            name="cast1_cast/Cast",
+            to=onnx.TensorProto.DOUBLE,
+        ),
+        helper.make_node(
+            op_type="IsInf",
+            inputs=["cast1_cast/Cast:0"],
+            outputs=["isinf1_isinf/IsInf:0"],
+            name="isinf1_isinf/IsInf",
+        ),
+        helper.make_node(
+            op_type="Greater",
+            inputs=["conv1_conv/Conv2D:0", "greater_const1"],
+            outputs=["greater1_greater/Greater:0"],
+            name="greater1_greater/Greater",
+        ),
+        helper.make_node(
+            op_type="And",
+            inputs=["isinf1_isinf/IsInf:0", "greater1_greater/Greater:0"],
+            outputs=["and1_and/And:0"],
+            name="and1_and/And",
+        ),
+        helper.make_node(
+            op_type="Where",
+            inputs=["and1_and/And:0", "conv1_conv/Conv2D:0", "where_const1"],
+            outputs=["output_0"],
+            name="where1_where/Where",
+        ),
+    ]
+
+    # Create the ONNX initializers
+    initializers = [
+        helper.make_tensor(
+            name="weights_1",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(32, 32, 3, 3),
+            vals=np.random.uniform(low=0.5, high=1.0, size=32 * 32 * 3 * 3),
+        ),
+        helper.make_tensor(
+            name="greater_const1",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(1,),
+            vals=[0],
+        ),
+        helper.make_tensor(
+            name="where_const1",
+            data_type=onnx.TensorProto.FLOAT,
+            dims=(1,),
+            vals=[10000],
+        ),
+    ]
+
+    # Create the ONNX graph with the nodes and initializers
+    graph = helper.make_graph(nodes, "conv_isinf", inputs, outputs, initializer=initializers)
+
+    # Create the ONNX model
+    model = helper.make_model(graph)
+    model.opset_import[0].version = opset_version
+    model.ir_version = 10
+
+    # Check the ONNX model
+    model_inferred = onnx.shape_inference.infer_shapes(model)
+    onnx.checker.check_model(model_inferred)
+
+    return model_inferred
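​
A hedged sketch (not a test added by this commit) of how the new model builder could be combined with the opset-aware detection above; it assumes tests/_test_utils is on the import path. IsInf only gained float16 support in a later opset, so at opset 13 it is expected to be reported as unsupported.

from _test_utils.onnx.lib_test_models import build_conv_isinf_model

from modelopt.onnx.autocast.utils import get_op_types_not_supported_in_low_precision

# Conv -> Cast(double) -> IsInf -> Greater/And/Where graph pinned to opset 13.
model = build_conv_isinf_model(opset_version=13)

unsupported = get_op_types_not_supported_in_low_precision(model, min_opset=13)
assert "IsInf" in unsupported  # expected: IsInf needs a newer opset for float16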

tests/gpu/onnx/test_concat_elim.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 import onnx
 import onnx_graphsurgeon as gs
-from _test_utils.onnx.quantization.lib_test_models import build_conv_concat_model
+from _test_utils.onnx.lib_test_models import build_conv_concat_model
 
 from modelopt.onnx.quantization.quantize import quantize

tests/gpu/onnx/test_qdq_utils_fp8.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 import onnx_graphsurgeon as gs
 import pytest
 import torch
-from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx
+from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx
 
 from modelopt.onnx.quantization.quantize import quantize

tests/gpu/onnx/test_quantize_fp8.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 import onnx
 import onnx_graphsurgeon as gs
 import torch
-from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx
+from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx
 
 import modelopt.onnx.quantization as moq

tests/gpu/onnx/test_quantize_onnx_torch_int4_awq.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 
 import torch
 from _test_utils.import_helper import skip_if_no_libcudnn
-from _test_utils.onnx.quantization.lib_test_models import SimpleMLP, export_as_onnx, find_init
+from _test_utils.onnx.lib_test_models import SimpleMLP, export_as_onnx, find_init
 from _test_utils.torch.quantization.quantize_common import get_awq_config
 
 import modelopt.onnx.quantization.int4 as int4
