|
51 | 51 | # Temporarily block these ops in low precision, as they are not supported yet |
52 | 52 | OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop", "LSTM"]) |
53 | 53 |
|
| 54 | +# Mapping of op types to indices of inputs that should not be converted to low precision. |
| 55 | +SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {1}} |
| 56 | +SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}} |
| 57 | + |
54 | 58 |
|
55 | 59 | class PrecisionConverter: |
56 | 60 | """Precision conversion module for ONNX models. |
@@ -1079,19 +1083,18 @@ def _should_skip_low_precision_input_conversion( |
1079 | 1083 |
|
1080 | 1084 | This is used for nodes that have inputs that MUST remain in FP32. |
1081 | 1085 | """ |
1082 | | - assert isinstance(node, onnx.NodeProto), f"node must be an onnx.NodeProto, got {type(node)}" |
1083 | | - assert isinstance(input_name, str), f"input_name must be a string, got {type(input_name)}" |
1084 | | - if node.op_type == "Resize": |
1085 | | - if input_name not in node.input: |
1086 | | - raise KeyError( |
1087 | | - f"Input {input_name} not found in node {node.name} input, not expected!" |
1088 | | - ) |
| 1086 | + match self.low_precision_type: |
| 1087 | + case "fp16": |
| 1088 | + skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16 |
| 1089 | + case "bf16": |
| 1090 | + skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_BF16 |
| 1091 | + case _: |
| 1092 | + raise ValueError(f"Unsupported low precision type: {self.low_precision_type}") |
| 1093 | + |
| 1094 | + if node.op_type in skip_inputs_map: |
1089 | 1095 | # Figure out the index of the input in the node input |
1090 | 1096 | inputs_lst = list(node.input) |
1091 | 1097 | input_index = inputs_lst.index(input_name) |
1092 | | - # The second input does not support bfloat16, so leave it in FP32. |
1093 | | - # The third input of Resize must remain in FP32. |
1094 | | - # Ref: https://onnx.ai/onnx/operators/onnx__Resize.html#inputs |
1095 | | - return input_index in {1, 2} |
1096 | | - |
| 1098 | + # Check if we should skip this input for low precision conversion |
| 1099 | + return input_index in skip_inputs_map[node.op_type] |
1097 | 1100 | return False |
0 commit comments