Skip to content

Commit 02f729c

Browse files
committed
bugfixes
Signed-off-by: Ali Boubezari <[email protected]>
1 parent 6262096 commit 02f729c

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1083,7 +1083,7 @@ def _should_skip_low_precision_input_conversion(
10831083
10841084
This is used for nodes that have inputs that MUST remain in FP32.
10851085
"""
1086-
match self.low_precision_type:
1086+
match self.low_precision_type.str_short:
10871087
case "fp16":
10881088
skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
10891089
case "bf16":
@@ -1094,6 +1094,8 @@ def _should_skip_low_precision_input_conversion(
10941094
if node.op_type in skip_inputs_map:
10951095
# Figure out the index of the input in the node input
10961096
inputs_lst = list(node.input)
1097+
if input_name not in inputs_lst:
1098+
raise ValueError(f"Input {input_name} not found in node {node.name}.")
10971099
input_index = inputs_lst.index(input_name)
10981100
# Check if we should skip this input for low precision conversion
10991101
return input_index in skip_inputs_map[node.op_type]

0 commit comments

Comments
 (0)