Skip to content

Commit 7c854c7

Browse files
committed
Implement function to detect tensor inputs to keep in high precision
Signed-off-by: gcunhase <[email protected]>
1 parent 00d6d1d commit 7c854c7

File tree

2 files changed

+33
-20
lines changed

2 files changed

+33
-20
lines changed

modelopt/onnx/autocast/convert.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,11 @@ def convert_to_f16(
236236
keep_io_types=keep_io_types,
237237
low_precision_type=low_precision_type,
238238
custom_ops=sanitizer.custom_ops,
239+
tensor_block_dict=tensor_block_dict,
239240
)
240241
high_precision_nodes = [node.name for node in model.graph.node if node.op_type in op_block_list]
241242
low_precision_nodes = [
242243
node.name for node in model.graph.node if node.op_type not in op_block_list
243244
]
244-
model_mod = precision_converter.convert(
245-
high_precision_nodes, low_precision_nodes, tensor_block_dict
246-
)
245+
model_mod = precision_converter.convert(high_precision_nodes, low_precision_nodes)
247246
return model_mod

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(
9999
min_opset: int = 13,
100100
max_ir_version: int | None = None,
101101
trt_plugins: list[str] | None = [],
102+
tensor_block_dict: dict[str, dict[str, list[int]]] = {},
102103
) -> None:
103104
"""Initialize PrecisionConverter.
104105
@@ -112,6 +113,10 @@ def __init__(
112113
init_conversion_max_bytes: Maximum size in bytes for initializer conversion. Larger initializers will be
113114
cast at runtime.
114115
custom_ops: List of custom ops.
116+
min_opset: Minimum opset for conversion.
117+
max_ir_version: Max IR version for conversion.
118+
trt_plugins: List of custom TensorRT plugin library paths in .so format (compiled shared library).
119+
tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
115120
"""
116121
self.model = deepcopy(model)
117122
self.value_info_map = value_info_map
@@ -148,18 +153,19 @@ def __init__(
148153
)
149154
)
150155

156+
# Custom mapping of op types to indices of inputs that should not be converted to low precision
157+
self.skip_inputs_map = self._create_skip_inputs_mapping(tensor_block_dict)
158+
151159
def convert(
152160
self,
153161
high_precision_nodes: list[str],
154162
low_precision_nodes: list[str],
155-
tensor_block_dict: dict[str, dict[str, list[int]]] = {},
156163
) -> onnx.ModelProto:
157164
"""Convert model to mixed precision.
158165
159166
Args:
160167
high_precision_nodes: List of node names to keep in high precision.
161168
low_precision_nodes: List of node names to convert to low precision.
162-
tensor_block_dict: Dictionary of tensors (operation type and I/O indices) that should remain in FP32.
163169
164170
Returns:
165171
onnx.ModelProto: The converted mixed precision model.
@@ -190,7 +196,7 @@ def convert(
190196
input.type.tensor_type.elem_type = self.low_precision_type.onnx_type
191197

192198
cast_down_tensors, cast_up_tensors, fp32_input_to_low_precision_node = (
193-
self._get_tensors_to_cast(low_precision_nodes, tensor_block_dict)
199+
self._get_tensors_to_cast(low_precision_nodes)
194200
)
195201
logger.debug(f"cast down (to {self.low_precision_type.str_full}): {cast_down_tensors}")
196202
logger.debug(f"cast up (to {self.high_precision_type.str_full}): {cast_up_tensors}")
@@ -483,11 +489,8 @@ def _get_tensors_to_cast(
483489
for node in self.model.graph.node:
484490
if node.name in low_precision_nodes:
485491
# Cast inputs to FP16 nodes down to FP16
486-
high_precision_tensor = high_precision_tensors.get(node.op_type, {})
487-
for idx, input in enumerate(node.input):
488-
if self._should_skip_low_precision_input_conversion(
489-
node, input
490-
) or idx in high_precision_tensor.get("inp", []):
492+
for input in node.input:
493+
if self._should_skip_low_precision_input_conversion(node, input):
491494
cast_to_fp32.append(input)
492495
fp32_input_to_low_precision_node[input].append(node)
493496
else:
@@ -1280,13 +1283,9 @@ def _sanitize_model(self):
12801283
graph_sanitizer.sanitize()
12811284
self.model = graph_sanitizer.model
12821285

1283-
def _should_skip_low_precision_input_conversion(
1284-
self, node: onnx.NodeProto, input_name: str
1285-
) -> bool:
1286-
"""Check if the input should be skipped for low precision conversion.
1287-
1288-
This is used for nodes that have inputs that MUST remain in FP32.
1289-
"""
1286+
def _create_skip_inputs_mapping(self, tensor_block_dict: dict[str, dict[str, list[int]]] = {}):
1287+
"""Create mapping of op types to indices of inputs that should not be converted to low precision."""
1288+
skip_inputs_map = {}
12901289
match self.low_precision_type.str_short:
12911290
case "fp16":
12921291
skip_inputs_map = SKIP_LOW_PRECISION_MAPPING_FP16
@@ -1295,12 +1294,27 @@ def _should_skip_low_precision_input_conversion(
12951294
case _:
12961295
raise ValueError(f"Unsupported low precision type: {self.low_precision_type}")
12971296

1298-
if node.op_type in skip_inputs_map:
1297+
# Update mapping with user-defined information
1298+
for op, tensor_map in tensor_block_dict.items():
1299+
high_precision_tensor = tensor_map.get("inp", [])
1300+
if high_precision_tensor:
1301+
skip_inputs_map.update({op: set(high_precision_tensor)})
1302+
1303+
return skip_inputs_map
1304+
1305+
def _should_skip_low_precision_input_conversion(
1306+
self, node: onnx.NodeProto, input_name: str
1307+
) -> bool:
1308+
"""Check if the input should be skipped for low precision conversion.
1309+
1310+
This is used for nodes that have inputs that MUST remain in FP32.
1311+
"""
1312+
if node.op_type in self.skip_inputs_map:
12991313
# Figure out the index of the input in the node input
13001314
inputs_lst = list(node.input)
13011315
if input_name not in inputs_lst:
13021316
raise ValueError(f"Input {input_name} not found in node {node.name}.")
13031317
input_index = inputs_lst.index(input_name)
13041318
# Check if we should skip this input for low precision conversion
1305-
return input_index in skip_inputs_map[node.op_type]
1319+
return input_index in self.skip_inputs_map[node.op_type]
13061320
return False

0 commit comments

Comments (0)