Skip to content

Commit 0410a67

Browse files
committed
Moved unsupported op detection out of constant
Signed-off-by: gcunhase <[email protected]>
1 parent ce57978 commit 0410a67

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ def __init__(
138138
self.min_opset = min_opset
139139
self.max_ir_version = max_ir_version
140140
self.trt_plugins = trt_plugins
141-
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(
142-
utils.get_ops_without_low_precision_support(
141+
self.op_types_not_supported_in_low_precision = OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION + (
142+
utils.get_op_types_not_supported_in_low_precision(
143143
self.model, self.low_precision_type.str_full, self.min_opset
144144
)
145145
)
@@ -451,7 +451,7 @@ def _filter_unsupported_op_types(
451451
# precision so we need to set Resize and Upsample to high precision
452452
for node in self.model.graph.node:
453453
if (
454-
node.op_type in OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION
454+
node.op_type in self.op_types_not_supported_in_low_precision
455455
and node.name in low_precision_nodes
456456
):
457457
low_precision_nodes.remove(node.name)

modelopt/onnx/autocast/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,20 +120,20 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
120120
raise ValueError("Cast node does not have 'to' attribute")
121121

122122

123-
def get_ops_without_low_precision_support(
123+
def get_op_types_not_supported_in_low_precision(
124124
model: onnx.ModelProto,
125125
low_precision_type: str,
126126
min_opset: int,
127127
) -> list[str]:
128-
"""Get a list of ops without low precision support for the current opset version.
128+
"""Get a list of ops not supported in low precision for the current opset version.
129129
130130
Args:
131131
model: ONNX model.
132132
low_precision_type: Target precision to reduce to ('float16' or 'bfloat16').
133133
min_opset: Minimum opset version.
134134
135135
Returns:
136-
ops_without_support: List of ops without low precision support for the current opset version.
136+
ops_without_support: List of ops not supported in low precision for the current opset version.
137137
"""
138138
# Obtain the current model's opset version
139139
ai_onnx_domain = [

0 commit comments

Comments
 (0)