Skip to content

Commit 6262096

Browse files
committed
Generalize & automate skipping inputs; only skip index 2 for bfloat16
Signed-off-by: Ali Boubezari <[email protected]>
1 parent f3bcfd1 commit 6262096

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@
5151
# Temporarily block these ops in low precision, as they are not supported yet
5252
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop", "LSTM"])
5353

54+
# Mapping of op types to indices of inputs that should not be converted to low precision.
55+
SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {1}}
56+
SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}
57+
5458

5559
class PrecisionConverter:
5660
"""Precision conversion module for ONNX models.
@@ -1079,19 +1083,18 @@ def _should_skip_low_precision_input_conversion(
10791083
10801084
This is used for nodes that have inputs that MUST remain in FP32.
10811085
"""
1082-
assert isinstance(node, onnx.NodeProto), f"node must be an onnx.NodeProto, got {type(node)}"
1083-
assert isinstance(input_name, str), f"input_name must be a string, got {type(input_name)}"
1084-
if node.op_type == "Resize":
1085-
if input_name not in node.input:
1086-
raise KeyError(
1087-
f"Input {input_name} not found in node {node.name} input, not expected!"
1088-
)
1086+
match self.low_precision_type:
1087+
case "fp16":
1088+
skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
1089+
case "bf16":
1090+
skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_BF16
1091+
case _:
1092+
raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")
1093+
1094+
if node.op_type in skip_inputs_map:
10891095
# Figure out the index of the input in the node input
10901096
inputs_lst = list(node.input)
10911097
input_index = inputs_lst.index(input_name)
1092-
# The second input does not support bfloat16, so leave it in FP32.
1093-
# The third input of Resize must remain in FP32.
1094-
# Ref: https://onnx.ai/onnx/operators/onnx__Resize.html#inputs
1095-
return input_index in {1, 2}
1096-
1098+
# Check if we should skip this input for low precision conversion
1099+
return input_index in skip_inputs_map[node.op_type]
10971100
return False

0 commit comments

Comments
 (0)