From e2c0d39105a15d66bcad6d2b07d5b826b4e0b06a Mon Sep 17 00:00:00 2001 From: Ali Boubezari Date: Wed, 22 Oct 2025 09:21:46 -0700 Subject: [PATCH 1/3] [Autocast] Optimize _convert_initializers runtime Signed-off-by: Ali Boubezari --- modelopt/onnx/autocast/precisionconverter.py | 374 +++++++++++++------ 1 file changed, 256 insertions(+), 118 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index defec2bd4..7ac9f25f0 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -21,8 +21,9 @@ through type checking and cleanup of redundant operations. """ -from collections import namedtuple +from collections import defaultdict, namedtuple from copy import deepcopy +from dataclasses import dataclass, field import ml_dtypes import numpy as np @@ -39,6 +40,23 @@ PrecisionTypes = namedtuple("PrecisionTypes", ["onnx_type", "numpy_type", "str_short", "str_full"]) + +@dataclass +class InputIndexTracker: + """A class that tracks the index of an input to a node.""" + + node: onnx.NodeProto + node_index: int + + +@dataclass +class InitializerConsumerTracker: + """A class that tracks the nodes that consume an initializer.""" + + low_precision_nodes: list[InputIndexTracker] = field(default_factory=list) + high_precision_nodes: list[InputIndexTracker] = field(default_factory=list) + + PRECISION_MAP = { "fp32": PrecisionTypes(TensorProto.FLOAT, np.float32, "fp32", "float32"), "fp16": PrecisionTypes(TensorProto.FLOAT16, np.float16, "fp16", "float16"), @@ -472,133 +490,253 @@ def _get_tensors_to_cast( def _convert_initializers( self, low_precision_nodes: list[str], high_precision_nodes: list[str] ) -> onnx.ModelProto: - def convert_initializer( - init: onnx.TensorProto, - node: onnx.NodeProto, - from_type: PrecisionTypes, - to_type: PrecisionTypes, - ): - if init.data_type != from_type.onnx_type: + """Convert model initializers to appropriate precision based on their 
consumer nodes. + + This method analyzes how each initializer is used by different precision nodes and converts + or duplicates initializers as needed to ensure type compatibility: + + 1. Maps each initializer to the high/low precision nodes that consume it + 2. For each initializer, applies one of these strategies: + - If only used by low precision nodes: convert to low precision + - If only used by high precision nodes: convert to high precision + - If used by both precision types: duplicate the initializer, creating separate + copies for each precision type and updating node references accordingly + 3. Skips conversion for non-float initializers or those already at correct precision + + The method handles special cases like bfloat16 conversion and provides warnings when + values are clamped or replaced due to precision limits. + + Args: + low_precision_nodes: List of node names that should use low precision initializers. + high_precision_nodes: List of node names that should use high precision initializers. + """ + assert self.init_conversion_max_bytes == np.inf, ( + "init_conversion_max_bytes not support for v2 version yet." + ) + + # 1. Compute a mapping from initiailizers to high precision nodes & low precision nodes that use them. + low_precision_nodes_set: set[str] = set(low_precision_nodes) + high_precision_nodes_set: set[str] = set(high_precision_nodes) + initializer_to_nodes: dict[str, InitializerConsumerTracker] = defaultdict( + lambda: InitializerConsumerTracker() + ) + for node in self.model.graph.node: + # Compute the mapping from initializers to low precision nodes that use them. + if node.name in low_precision_nodes_set: + for idx, input_name in enumerate(node.input): + if input_name in self.initializer_map: + if self._should_skip_low_precision_input_conversion(node, input_name): + # Handle low precision nodes that require certain high precision inputs. 
+ initializer_to_nodes[input_name].high_precision_nodes.append( + InputIndexTracker(node=node, node_index=idx) + ) + else: + initializer_to_nodes[input_name].low_precision_nodes.append( + InputIndexTracker(node=node, node_index=idx) + ) + # Compute the mapping from initializers to high precision nodes that use them. + elif node.name in high_precision_nodes_set: + for idx, input_name in enumerate(node.input): + if input_name in self.initializer_map: + initializer_to_nodes[input_name].high_precision_nodes.append( + InputIndexTracker(node=node, node_index=idx) + ) + + # 2. Convert initializers to appropriate precision based on their consumer nodes. + for init_name, tracker in initializer_to_nodes.items(): + # Get the initializer. + init = self.initializer_map[init_name] + # If not used, just skip. + if len(tracker.low_precision_nodes) == 0 and len(tracker.high_precision_nodes) == 0: + logger.debug(f"Initializer {init_name} is not used by any nodes, skipping") + continue + # If the initializer is not a float, then just skip. + if init.data_type not in { + self.high_precision_type.onnx_type, + self.low_precision_type.onnx_type, + }: + logger.debug(f"Initializer {init_name} is not a float, skipping") + continue + # If the initializer is only used by high precision nodes and is high precision, then just skip. + if ( + len(tracker.low_precision_nodes) == 0 + and init.data_type == self.high_precision_type.onnx_type + ): logger.debug( - f"Initializer {init.name} has data type {init.data_type}, and size {len(init.raw_data)}," - "skipping conversion" + f"Initializer {init_name} is already high precision and only used " + "by high precision nodes, skipping" ) - return False + continue + # If the initializer is only used by low precision nodes and is low precision, then just skip. 
+ if ( + len(tracker.high_precision_nodes) == 0 + and init.data_type == self.low_precision_type.onnx_type + ): + logger.debug( + f"Initializer {init_name} is already low precision and only used " + "by low precision nodes, skipping" + ) + continue - # If initializer is too large, skip conversion, perform cast instead - if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes: + # If the initializer is used by only one precision type, then convert it to the other precision type. + if len(tracker.high_precision_nodes) == 0 or len(tracker.low_precision_nodes) == 0: + if len(tracker.low_precision_nodes) > 0: + logger.debug( + f"Convert initializer {init_name} to " + f"{self.low_precision_type.str_short}, only used by low precision nodes" + ) + from_type = self.high_precision_type + to_type = self.low_precision_type + elif len(tracker.high_precision_nodes) > 0: + logger.debug( + f"Convert initializer {init_name} to " + f"{self.high_precision_type.str_short}, " + "only used by high precision nodes" + ) + from_type = self.low_precision_type + to_type = self.high_precision_type + else: + raise ValueError( + f"Unexpected: initializer {init_name} is not used by any " + "nodes and is not a float" + ) + + new_init = self._cast_initializer( + init=init, + from_type=from_type, + to_type=to_type, + low_precision_nodes=tracker.low_precision_nodes, + high_precision_nodes=tracker.high_precision_nodes, + ) + if new_init is not None: + self.model.graph.initializer.remove(init) + self.model.graph.initializer.extend([new_init]) + continue + + # This initializer is used by both high precision and low precision nodes, so we need + # to duplicate it for low precision nodes. 
+ assert len(tracker.low_precision_nodes) > 0 and len(tracker.high_precision_nodes) > 0 + if init.data_type == self.low_precision_type.onnx_type: logger.debug( - f"Initializer {init.name} is too large, skipping initializer conversion, cast in " - "runtime instead" + f"Convert initializer {init_name} to " + f"{self.high_precision_type.str_short}, " + "used by both high precision and low precision nodes" ) - exclude_consumers = ( - low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes + from_type = self.low_precision_type + to_type = self.high_precision_type + nodes_to_update = tracker.high_precision_nodes + elif init.data_type == self.high_precision_type.onnx_type: + logger.debug( + f"Convert initializer {init_name} to " + f"{self.low_precision_type.str_short}, " + "used by both high precision and low precision nodes" ) - self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers) - return True - try: - np_array = numpy_helper.to_array(init) - assert from_type.str_short in PRECISION_MAP - assert to_type.str_short in PRECISION_MAP - assert from_type.str_short != to_type.str_short - - if np_array.dtype == from_type.numpy_type: - consumers = [n.name for n in utils.get_consumer_nodes(self.model, init.name)] - should_duplicate = len(consumers) > 1 and set(consumers) & set( - high_precision_nodes - ) + from_type = self.high_precision_type + to_type = self.low_precision_type + nodes_to_update = tracker.low_precision_nodes + else: + raise ValueError(f"Unexpected: initializer {init_name} is not a float") + + new_init = self._cast_initializer( + init=init, + from_type=from_type, + to_type=to_type, + low_precision_nodes=tracker.low_precision_nodes, + high_precision_nodes=tracker.high_precision_nodes, + ) + if new_init is not None: + new_init_name = f"{init_name}_{to_type.str_short}" + new_init.name = new_init_name + for node in nodes_to_update: + node.node.input[node.node_index] = new_init_name + self.model.graph.initializer.extend([new_init]) + + 
def _cast_initializer( + self, + init: onnx.TensorProto, + from_type: PrecisionTypes, + to_type: PrecisionTypes, + low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto], + high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto], + ) -> onnx.TensorProto | None: + """Cast an initializer to a new precision based on its consumer nodes. - if should_duplicate: - # Create a new low precision copy with a different name - new_name = f"{init.name}_{to_type.str_short}" - logger.debug( - f"Initializer {init.name} is shared, creating {to_type.str_short} copy as {new_name} due " - f"to node {node.name}" - ) + This method converts an initializer to a new precision while handling special cases like bfloat16 conversion + and providing warnings when values are clamped or replaced due to precision limits. - # Update the node to use the new initializer - for i, input_name in enumerate(node.input): - if input_name == init.name: - node.input[i] = new_name - break + Args: + init: The initializer to cast. + from_type: The original precision of the initializer. + to_type: The new precision to cast the initializer to. 
- if init.name in initializer_converted_dup: - return False - initializer_converted_dup.append(init.name) - else: - if init.name in initializer_converted: - return False - new_name = init.name - logger.debug( - f"Converting initializer {new_name} to {to_type.str_short} due to node {node.name}" - ) - initializer_converted.append(init.name) - self.model.graph.initializer.remove(init) - - # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead - if self._is_bf16(to_type) and self._is_fp32(from_type): - new_init = onnx.TensorProto() - new_init.dims.extend(np_array.shape) - new_init.name = new_name - new_init.data_type = onnx.TensorProto.BFLOAT16 - bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16) - new_init.raw_data = bf16_bytes.tobytes() - else: - assert to_type.numpy_type is not None - data_max, data_lowest = ( - np.finfo(to_type.numpy_type).max, - np.finfo(to_type.numpy_type).smallest_subnormal, - ) - if np.any(np.abs(np_array) > data_max): - logger.warning( - f"Initializer {init.name} used by node {node.name} contains values larger than " - f"largest {to_type.str_short} value, values will be clamped to {data_max}." - ) - np_array = np.clip(np_array, -1 * data_max, data_max) - if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)): - logger.warning( - f"Initializer {init.name} used by node {node.name} contains values smaller than " - f"smallest {to_type.str_short} value, values will be replaced with {data_lowest:.1e}." - ) - np_array = np.where( - (np_array != 0.0) & (np.abs(np_array) < data_lowest), - data_lowest, - np_array, - ) - new_array = np_array.astype(to_type.numpy_type) - new_init = numpy_helper.from_array(new_array, new_name) - self.model.graph.initializer.extend([new_init]) - return True - return False - except Exception as e: - logger.error(f"Error converting initializer {init.name}: {e}") - return False + Returns: + onnx.TensorProto | None: The cast initializer, or None if it is cast at runtime instead. 
+ """ - initializer_converted = [] - initializer_converted_dup = [] - modified = False - for node in self.model.graph.node: - if node.name in low_precision_nodes: - for init in self.node_to_init_map[node.name]: - if self._should_skip_low_precision_input_conversion(node, init.name): - continue - modified |= convert_initializer( - init, - node, - from_type=self.high_precision_type, - to_type=self.low_precision_type, - ) - if modified: - _, _, self.node_to_init_map = utils.setup_mappings(self.model) - - if node.name in high_precision_nodes: - for init in self.node_to_init_map[node.name]: - convert_initializer( - init, - node, - from_type=self.low_precision_type, - to_type=self.high_precision_type, - ) + def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str: + """Get the name of a node or input index tracker.""" + if isinstance(node, onnx.NodeProto): + return node.name + elif isinstance(node, InputIndexTracker): + return node.node.name + else: + raise ValueError(f"Unexpected: {type(node)}") + + # Ensure the initializer is of the expected type + assert init.data_type == from_type.onnx_type, ( + f"Initializer {init.name} is not of type {from_type.str_short}" + ) + + if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes: + # The initializer is too large, so we need to convert it at runtime. 
+ logger.debug( + f"Initializer {init.name} is too large, skipping initializer conversion, cast in " + "runtime instead" + ) + exclude_consumers = ( + low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes + ) + exclude_consumers_names: list[str] = [] + + exclude_consumers_names = [_get_name(node) for node in exclude_consumers] + self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names) + return None + + np_array = numpy_helper.to_array(init) + # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead + if self._is_bf16(to_type) and self._is_fp32(from_type): + new_init = onnx.TensorProto() + new_init.dims.extend(np_array.shape) + new_init.name = init.name + new_init.data_type = onnx.TensorProto.BFLOAT16 + bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16) + new_init.raw_data = bf16_bytes.tobytes() + else: + assert to_type.numpy_type is not None + data_max, data_lowest = ( + np.finfo(to_type.numpy_type).max, + np.finfo(to_type.numpy_type).smallest_subnormal, + ) + if np.any(np.abs(np_array) > data_max): + logger.warning( + f"Initializer {init.name} contains values larger than largest " + f"{to_type.str_short} value, values will be clamped to {data_max}." + ) + np_array = np.clip(np_array, -1 * data_max, data_max) + if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)): + logger.warning( + f"Initializer {init.name} contains values smaller than smallest " + f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}." 
+ ) + np_array = np.where( + (np_array != 0.0) & (np.abs(np_array) < data_lowest), + data_lowest, + np_array, + ) + new_array = np_array.astype(to_type.numpy_type) + new_init = numpy_helper.from_array(new_array, init.name) + + return new_init def _replace_tensor_name( self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str From 2673939f6a1b2dbe252a1a7ba8286add3cdb8a87 Mon Sep 17 00:00:00 2001 From: Ali Boubezari Date: Wed, 22 Oct 2025 09:33:12 -0700 Subject: [PATCH 2/3] Remove hack Signed-off-by: Ali Boubezari --- modelopt/onnx/autocast/precisionconverter.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 7ac9f25f0..9803e752e 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -510,10 +510,6 @@ def _convert_initializers( low_precision_nodes: List of node names that should use low precision initializers. high_precision_nodes: List of node names that should use high precision initializers. """ - assert self.init_conversion_max_bytes == np.inf, ( - "init_conversion_max_bytes not support for v2 version yet." - ) - # 1. Compute a mapping from initiailizers to high precision nodes & low precision nodes that use them. 
low_precision_nodes_set: set[str] = set(low_precision_nodes) high_precision_nodes_set: set[str] = set(high_precision_nodes) From adad68de62dacfbb10db8eb7917edb8a7aebccc9 Mon Sep 17 00:00:00 2001 From: Ali Boubezari Date: Mon, 27 Oct 2025 08:40:31 -0700 Subject: [PATCH 3/3] Accept all float types Signed-off-by: Ali Boubezari --- modelopt/onnx/autocast/precisionconverter.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 9803e752e..b436af1b2 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -538,6 +538,7 @@ def _convert_initializers( InputIndexTracker(node=node, node_index=idx) ) + onnx_float_types = set(ONNX_TYPES) # 2. Convert initializers to appropriate precision based on their consumer nodes. for init_name, tracker in initializer_to_nodes.items(): # Get the initializer. @@ -547,10 +548,7 @@ def _convert_initializers( logger.debug(f"Initializer {init_name} is not used by any nodes, skipping") continue # If the initializer is not a float, then just skip. - if init.data_type not in { - self.high_precision_type.onnx_type, - self.low_precision_type.onnx_type, - }: + if init.data_type not in onnx_float_types: logger.debug(f"Initializer {init_name} is not a float, skipping") continue # If the initializer is only used by high precision nodes and is high precision, then just skip.