@@ -99,6 +99,7 @@ def __init__(
         min_opset: int = 13,
         max_ir_version: int | None = None,
         trt_plugins: list[str] | None = [],
+        tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     ) -> None:
         """Initialize PrecisionConverter.

@@ -112,6 +113,10 @@ def __init__(
             init_conversion_max_bytes: Maximum size in bytes for initializer conversion. Larger initializers will be
                 cast at runtime.
             custom_ops: List of custom ops.
+            min_opset: Minimum ONNX opset version for conversion.
+            max_ir_version: Maximum IR version for conversion.
+            trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared libraries).
+            tensor_block_dict: Dictionary mapping op types to the input/output indices that should remain in FP32.
         """
         self.model = deepcopy(model)
         self.value_info_map = value_info_map
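
For reference, a minimal sketch of the shape `tensor_block_dict` is expected to have, inferred from the docstring above and the `tensor_map.get("inp", [])` lookup later in this diff; the op name and indices are illustrative, not taken from this PR:

```python
# Hypothetical example: keep inputs 1 and 2 of every Resize node in FP32.
# Only the "inp" key is read by the code in this diff; the docstring's
# "I/O indices" wording suggests an analogous "out" key may also be accepted.
tensor_block_dict = {
    "Resize": {"inp": [1, 2]},
}
```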
@@ -148,18 +153,19 @@ def __init__(
             )
         )

+        # Custom mapping of op types to indices of inputs that should not be converted to low precision
+        self.skip_inputs_map = self._create_skip_inputs_mapping(tensor_block_dict)
+
     def convert(
         self,
         high_precision_nodes: list[str],
         low_precision_nodes: list[str],
-        tensor_block_dict: dict[str, dict[str, list[int]]] = {},
     ) -> onnx.ModelProto:
157164 """Convert model to mixed precision.
158165
159166 Args:
160167 high_precision_nodes: List of node names to keep in high precision.
161168 low_precision_nodes: List of node names to convert to low precision.
162- tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
163169
164170 Returns:
165171 onnx.ModelProto: The converted mixed precision model.
@@ -190,7 +196,7 @@ def convert(
                 input.type.tensor_type.elem_type = self.low_precision_type.onnx_type

         cast_down_tensors, cast_up_tensors, fp32_input_to_low_precision_node = (
-            self._get_tensors_to_cast(low_precision_nodes, tensor_block_dict)
+            self._get_tensors_to_cast(low_precision_nodes)
         )
         logger.debug(f"cast down (to {self.low_precision_type.str_full}): {cast_down_tensors}")
         logger.debug(f"cast up (to {self.high_precision_type.str_full}): {cast_up_tensors}")
@@ -483,11 +489,8 @@ def _get_tensors_to_cast(
         for node in self.model.graph.node:
             if node.name in low_precision_nodes:
                 # Cast inputs to FP16 nodes down to FP16
-                high_precision_tensor = high_precision_tensors.get(node.op_type, {})
-                for idx, input in enumerate(node.input):
-                    if self._should_skip_low_precision_input_conversion(
-                        node, input
-                    ) or idx in high_precision_tensor.get("inp", []):
+                for input in node.input:
+                    if self._should_skip_low_precision_input_conversion(node, input):
                         cast_to_fp32.append(input)
                         fp32_input_to_low_precision_node[input].append(node)
                     else:
@@ -1280,13 +1283,9 @@ def _sanitize_model(self):
         graph_sanitizer.sanitize()
         self.model = graph_sanitizer.model

-    def _should_skip_low_precision_input_conversion(
-        self, node: onnx.NodeProto, input_name: str
-    ) -> bool:
-        """Check if the input should be skipped for low precision conversion.
-
-        This is used for nodes that have inputs that MUST remain in FP32.
-        """
+    def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}):
+        """Create a mapping of op types to indices of inputs that should not be converted to low precision."""
+        skip_inputs_map = {}
         match self.low_precision_type.str_short:
             case "fp16":
                 skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
@@ -1295,12 +1294,27 @@ def _should_skip_low_precision_input_conversion(
             case _:
                 raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")

-        if node.op_type in skip_inputs_map:
+        # Merge user-defined overrides into a copy so the module-level default mapping is not mutated
+        skip_inputs_map = dict(skip_inputs_map)
+        for op, tensor_map in tensor_block_dict.items():
+            if tensor_map.get("inp"):
+                skip_inputs_map[op] = set(tensor_map["inp"])
+
+        return skip_inputs_map
+
+    def _should_skip_low_precision_input_conversion(
+        self, node: onnx.NodeProto, input_name: str
+    ) -> bool:
+        """Check if the input should be skipped for low precision conversion.
+
+        This is used for nodes that have inputs that MUST remain in FP32.
+        """
+        if node.op_type in self.skip_inputs_map:
             # Figure out the index of the input in the node input
             inputs_lst = list(node.input)
             if input_name not in inputs_lst:
                 raise ValueError(f"Input {input_name} not found in node {node.name}.")
             input_index = inputs_lst.index(input_name)
             # Check if we should skip this input for low precision conversion
-            return input_index in skip_inputs_map[node.op_type]
+            return input_index in self.skip_inputs_map[node.op_type]
         return False
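
Note that the merge loop in `_create_skip_inputs_mapping` replaces an op's default index set rather than union-ing with it. A self-contained sketch of that semantics, using made-up defaults (the real contents of `SKIP_LOW_PRECISION_MAPPING_FP16` are not shown in this diff):

```python
# Stand-in for SKIP_LOW_PRECISION_MAPPING_FP16; values are illustrative only.
defaults = {"Resize": {1, 2}}
# User-supplied tensor_block_dict.
user = {"Resize": {"inp": [3]}, "NonMaxSuppression": {"inp": [0, 1]}}

skip_inputs_map = dict(defaults)
for op, tensor_map in user.items():
    if tensor_map.get("inp"):
        skip_inputs_map[op] = set(tensor_map["inp"])  # replaces, does not extend

print(skip_inputs_map)  # {'Resize': {3}, 'NonMaxSuppression': {0, 1}}
```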