368 changes: 250 additions & 118 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -21,8 +21,9 @@
through type checking and cleanup of redundant operations.
"""

-from collections import namedtuple
+from collections import defaultdict, namedtuple
 from copy import deepcopy
+from dataclasses import dataclass, field

import ml_dtypes
import numpy as np
@@ -39,6 +40,23 @@

PrecisionTypes = namedtuple("PrecisionTypes", ["onnx_type", "numpy_type", "str_short", "str_full"])


+@dataclass
+class InputIndexTracker:
+    """A class that tracks the index of an input to a node."""
+
+    node: onnx.NodeProto
+    node_index: int
+
+
+@dataclass
+class InitializerConsumerTracker:
+    """A class that tracks the nodes that consume an initializer."""
+
+    low_precision_nodes: list[InputIndexTracker] = field(default_factory=list)
+    high_precision_nodes: list[InputIndexTracker] = field(default_factory=list)
+
+
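For intuition, here is a minimal standalone sketch of how these trackers get populated; bucket_initializer_consumers and its arguments are illustrative names, not part of this diff:

from collections import defaultdict

def bucket_initializer_consumers(graph, low_names, high_names, init_names):
    """Group each initializer's consumers by the precision bucket of the consuming node."""
    trackers = defaultdict(InitializerConsumerTracker)  # all fields have defaults, so no factory lambda needed
    for node in graph.node:
        for idx, input_name in enumerate(node.input):
            if input_name not in init_names:
                continue  # this input is an activation, not an initializer
            entry = InputIndexTracker(node=node, node_index=idx)
            if node.name in low_names:
                trackers[input_name].low_precision_nodes.append(entry)
            elif node.name in high_names:
                trackers[input_name].high_precision_nodes.append(entry)
    return trackers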
PRECISION_MAP = {
"fp32": PrecisionTypes(TensorProto.FLOAT, np.float32, "fp32", "float32"),
"fp16": PrecisionTypes(TensorProto.FLOAT16, np.float16, "fp16", "float16"),
@@ -472,133 +490,247 @@ def _get_tensors_to_cast(
def _convert_initializers(
self, low_precision_nodes: list[str], high_precision_nodes: list[str]
) -> onnx.ModelProto:
-def convert_initializer(
-    init: onnx.TensorProto,
-    node: onnx.NodeProto,
-    from_type: PrecisionTypes,
-    to_type: PrecisionTypes,
-):
-    if init.data_type != from_type.onnx_type:
"""Convert model initializers to appropriate precision based on their consumer nodes.

This method analyzes how each initializer is used by different precision nodes and converts
or duplicates initializers as needed to ensure type compatibility:

1. Maps each initializer to the high/low precision nodes that consume it
2. For each initializer, applies one of these strategies:
- If only used by low precision nodes: convert to low precision
- If only used by high precision nodes: convert to high precision
- If used by both precision types: duplicate the initializer, creating separate
copies for each precision type and updating node references accordingly
3. Skips conversion for non-float initializers or those already at correct precision

The method handles special cases like bfloat16 conversion and provides warnings when
values are clamped or replaced due to precision limits.

Args:
low_precision_nodes: List of node names that should use low precision initializers.
high_precision_nodes: List of node names that should use high precision initializers.
"""
# 1. Compute a mapping from initializers to the high precision and low precision nodes that use them.
low_precision_nodes_set: set[str] = set(low_precision_nodes)
high_precision_nodes_set: set[str] = set(high_precision_nodes)
initializer_to_nodes: dict[str, InitializerConsumerTracker] = defaultdict(
lambda: InitializerConsumerTracker()
)
for node in self.model.graph.node:
# Compute the mapping from initializers to low precision nodes that use them.
if node.name in low_precision_nodes_set:
for idx, input_name in enumerate(node.input):
if input_name in self.initializer_map:
if self._should_skip_low_precision_input_conversion(node, input_name):
# Handle low precision nodes that require certain high precision inputs.
initializer_to_nodes[input_name].high_precision_nodes.append(
InputIndexTracker(node=node, node_index=idx)
)
else:
initializer_to_nodes[input_name].low_precision_nodes.append(
InputIndexTracker(node=node, node_index=idx)
)
# Compute the mapping from initializers to high precision nodes that use them.
elif node.name in high_precision_nodes_set:
for idx, input_name in enumerate(node.input):
if input_name in self.initializer_map:
initializer_to_nodes[input_name].high_precision_nodes.append(
InputIndexTracker(node=node, node_index=idx)
)

onnx_float_types = set(ONNX_TYPES)
# 2. Convert initializers to appropriate precision based on their consumer nodes.
for init_name, tracker in initializer_to_nodes.items():
# Get the initializer.
init = self.initializer_map[init_name]
# If not used, just skip.
if len(tracker.low_precision_nodes) == 0 and len(tracker.high_precision_nodes) == 0:
logger.debug(f"Initializer {init_name} is not used by any nodes, skipping")
continue
# If the initializer is not a float, then just skip.
if init.data_type not in onnx_float_types:
logger.debug(f"Initializer {init_name} is not a float, skipping")
continue
# If the initializer is only used by high precision nodes and is high precision, then just skip.
if (
len(tracker.low_precision_nodes) == 0
and init.data_type == self.high_precision_type.onnx_type
):
 logger.debug(
-    f"Initializer {init.name} has data type {init.data_type}, and size {len(init.raw_data)},"
-    "skipping conversion"
+    f"Initializer {init_name} is already high precision and only used "
+    "by high precision nodes, skipping"
 )
-return False
+continue
# If the initializer is only used by low precision nodes and is low precision, then just skip.
if (
len(tracker.high_precision_nodes) == 0
and init.data_type == self.low_precision_type.onnx_type
):
logger.debug(
f"Initializer {init_name} is already low precision and only used "
"by low precision nodes, skipping"
)
continue

# If the initializer is used by only one precision type, convert it to that precision type.
if len(tracker.high_precision_nodes) == 0 or len(tracker.low_precision_nodes) == 0:
if len(tracker.low_precision_nodes) > 0:
logger.debug(
f"Convert initializer {init_name} to "
f"{self.low_precision_type.str_short}, only used by low precision nodes"
)
from_type = self.high_precision_type
to_type = self.low_precision_type
elif len(tracker.high_precision_nodes) > 0:
logger.debug(
f"Convert initializer {init_name} to "
f"{self.high_precision_type.str_short}, "
"only used by high precision nodes"
)
from_type = self.low_precision_type
to_type = self.high_precision_type
else:
raise ValueError(
f"Unexpected: initializer {init_name} is not used by any "
"nodes and is not a float"
)

new_init = self._cast_initializer(
init=init,
from_type=from_type,
to_type=to_type,
low_precision_nodes=tracker.low_precision_nodes,
high_precision_nodes=tracker.high_precision_nodes,
)
if new_init is not None:
self.model.graph.initializer.remove(init)
self.model.graph.initializer.extend([new_init])
continue

-# If initializer is too large, skip conversion, perform cast instead
-if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
-    logger.debug(
-        f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
-        "runtime instead"
-    )
-    exclude_consumers = (
-        low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
-    )
-    self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers)
-    return True
-try:
-    np_array = numpy_helper.to_array(init)
-    assert from_type.str_short in PRECISION_MAP
-    assert to_type.str_short in PRECISION_MAP
-    assert from_type.str_short != to_type.str_short
-
-    if np_array.dtype == from_type.numpy_type:
-        consumers = [n.name for n in utils.get_consumer_nodes(self.model, init.name)]
-        should_duplicate = len(consumers) > 1 and set(consumers) & set(
-            high_precision_nodes
-        )
+# This initializer is used by both high precision and low precision nodes, so we need
+# to duplicate it for low precision nodes.
+assert len(tracker.low_precision_nodes) > 0 and len(tracker.high_precision_nodes) > 0
+if init.data_type == self.low_precision_type.onnx_type:
+    logger.debug(
+        f"Convert initializer {init_name} to "
+        f"{self.high_precision_type.str_short}, "
+        "used by both high precision and low precision nodes"
+    )
+    from_type = self.low_precision_type
+    to_type = self.high_precision_type
+    nodes_to_update = tracker.high_precision_nodes
+elif init.data_type == self.high_precision_type.onnx_type:
+    logger.debug(
+        f"Convert initializer {init_name} to "
+        f"{self.low_precision_type.str_short}, "
+        "used by both high precision and low precision nodes"
+    )
+    from_type = self.high_precision_type
+    to_type = self.low_precision_type
+    nodes_to_update = tracker.low_precision_nodes
+else:
+    raise ValueError(f"Unexpected: initializer {init_name} is not a float")
+
+new_init = self._cast_initializer(
+    init=init,
+    from_type=from_type,
+    to_type=to_type,
+    low_precision_nodes=tracker.low_precision_nodes,
+    high_precision_nodes=tracker.high_precision_nodes,
+)
+if new_init is not None:
+    new_init_name = f"{init_name}_{to_type.str_short}"
+    new_init.name = new_init_name
+    for node in nodes_to_update:
+        node.node.input[node.node_index] = new_init_name
+    self.model.graph.initializer.extend([new_init])
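To make the branch structure above easier to scan, here is the same decision logic as a tiny self-contained sketch (a simplification; Action and plan_conversion are illustrative names, not the PR's API):

from enum import Enum

class Action(Enum):
    SKIP = "skip"
    TO_LOW = "convert_to_low"
    TO_HIGH = "convert_to_high"
    DUPLICATE = "duplicate_per_precision"

def plan_conversion(used_by_low: bool, used_by_high: bool, is_low: bool) -> Action:
    """Decide what to do with one float initializer, mirroring the branches above."""
    if used_by_low and used_by_high:
        return Action.DUPLICATE  # shared: keep one copy per precision
    if used_by_low:
        return Action.SKIP if is_low else Action.TO_LOW
    if used_by_high:
        return Action.TO_HIGH if is_low else Action.SKIP
    return Action.SKIP  # unused initializers are left alone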

def _cast_initializer(
self,
init: onnx.TensorProto,
from_type: PrecisionTypes,
to_type: PrecisionTypes,
low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
) -> onnx.TensorProto | None:
Comment on lines +648 to +655 (Contributor Author):

This function is copied from the previous implementation.

"""Cast an initializer to a new precision based on its consumer nodes.

if should_duplicate:
# Create a new low precision copy with a different name
new_name = f"{init.name}_{to_type.str_short}"
logger.debug(
f"Initializer {init.name} is shared, creating {to_type.str_short} copy as {new_name} due "
f"to node {node.name}"
)
This method converts an initializer to a new precision while handling special cases like bfloat16 conversion
and providing warnings when values are clamped or replaced due to precision limits.

# Update the node to use the new initializer
for i, input_name in enumerate(node.input):
if input_name == init.name:
node.input[i] = new_name
break
Args:
init: The initializer to cast.
from_type: The original precision of the initializer.
to_type: The new precision to cast the initializer to.

if init.name in initializer_converted_dup:
return False
initializer_converted_dup.append(init.name)
else:
if init.name in initializer_converted:
return False
new_name = init.name
logger.debug(
f"Converting initializer {new_name} to {to_type.str_short} due to node {node.name}"
)
initializer_converted.append(init.name)
self.model.graph.initializer.remove(init)

# Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
if self._is_bf16(to_type) and self._is_fp32(from_type):
new_init = onnx.TensorProto()
new_init.dims.extend(np_array.shape)
new_init.name = new_name
new_init.data_type = onnx.TensorProto.BFLOAT16
bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
new_init.raw_data = bf16_bytes.tobytes()
else:
assert to_type.numpy_type is not None
data_max, data_lowest = (
np.finfo(to_type.numpy_type).max,
np.finfo(to_type.numpy_type).smallest_subnormal,
)
if np.any(np.abs(np_array) > data_max):
logger.warning(
f"Initializer {init.name} used by node {node.name} contains values larger than "
f"largest {to_type.str_short} value, values will be clamped to {data_max}."
)
np_array = np.clip(np_array, -1 * data_max, data_max)
if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
logger.warning(
f"Initializer {init.name} used by node {node.name} contains values smaller than "
f"smallest {to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
)
np_array = np.where(
(np_array != 0.0) & (np.abs(np_array) < data_lowest),
data_lowest,
np_array,
)
new_array = np_array.astype(to_type.numpy_type)
new_init = numpy_helper.from_array(new_array, new_name)
self.model.graph.initializer.extend([new_init])
return True
return False
except Exception as e:
logger.error(f"Error converting initializer {init.name}: {e}")
return False
Returns:
onnx.TensorProto: The casted initializer.
"""
Comment on lines +656 to +668 (Contributor):

⚠️ Potential issue | 🟡 Minor

Docstring return type is outdated

The function also returns None (runtime-cast path). Update docstring to reflect this.

Apply:

-        Returns:
-            onnx.TensorProto: The casted initializer.
+        Returns:
+            onnx.TensorProto | None: The casted initializer, or None when casting is deferred to runtime.
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 658 to 670, the
docstring's Returns section incorrectly states it only returns onnx.TensorProto
despite the function sometimes returning None for runtime-cast paths; update the
Returns docstring to accurately reflect the possible return values (e.g.,
"onnx.TensorProto or None: The casted initializer, or None if casting is
deferred to runtime") and adjust wording to match project docstring style.


-initializer_converted = []
-initializer_converted_dup = []
-modified = False
-for node in self.model.graph.node:
-    if node.name in low_precision_nodes:
-        for init in self.node_to_init_map[node.name]:
-            if self._should_skip_low_precision_input_conversion(node, init.name):
-                continue
-            modified |= convert_initializer(
-                init,
-                node,
-                from_type=self.high_precision_type,
-                to_type=self.low_precision_type,
-            )
-    if node.name in high_precision_nodes:
-        for init in self.node_to_init_map[node.name]:
-            convert_initializer(
-                init,
-                node,
-                from_type=self.low_precision_type,
-                to_type=self.high_precision_type,
-            )
-if modified:
-    _, _, self.node_to_init_map = utils.setup_mappings(self.model)
def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
"""Get the name of a node or input index tracker."""
if isinstance(node, onnx.NodeProto):
return node.name
elif isinstance(node, InputIndexTracker):
return node.node.name
else:
raise ValueError(f"Unexpected: {type(node)}")

# Ensure the initializer is of the expected type
assert init.data_type == from_type.onnx_type, (
f"Initializer {init.name} is not of type {from_type.str_short}"
)

if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
# The initializer is too large, so we need to convert it at runtime.
logger.debug(
f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
"runtime instead"
)
exclude_consumers = (
low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
)
exclude_consumers_names: list[str] = [_get_name(node) for node in exclude_consumers]
self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
return None

Comment on lines +684 to +698 (Contributor):

⚠️ Potential issue | 🟠 Major

Initializer size heuristic should not rely on raw_data only; handle float_data and EXTERNAL

Large tensors stored via float_data or external data will bypass the size check and be materialized in memory. Compute nbytes from dims/data_type and treat EXTERNAL as large.

Apply:

-        if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
+        # Robust size computation without loading into NumPy
+        def _elem_size(dt: int) -> int:
+            return {
+                onnx.TensorProto.FLOAT: 4,
+                onnx.TensorProto.FLOAT16: 2,
+                onnx.TensorProto.BFLOAT16: 2,
+                onnx.TensorProto.DOUBLE: 8,
+            }.get(dt, 0)
+
+        is_external = (
+            hasattr(init, "data_location")
+            and init.data_location == onnx.TensorProto.EXTERNAL
+        )
+        num_elems = int(np.prod(init.dims)) if len(init.dims) > 0 else 0
+        approx_nbytes = len(init.raw_data) if init.raw_data else num_elems * _elem_size(init.data_type)
+
+        if is_external or approx_nbytes > self.init_conversion_max_bytes:
             # The initializer is too large, so we need to convert it at runtime.
             logger.debug(
-                f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
-                "runtime instead"
+                f"Initializer {init.name} (~{approx_nbytes}B, external={is_external}) is large; "
+                "skip static conversion, insert runtime Cast instead"
             )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 686-700, the
initializer-size heuristic currently only checks init.raw_data which misses
large tensors stored in float_data or as EXTERNAL; update the check to compute
the initializer byte-size from its dims and data_type (or from float_data length
when present) and treat ONNX external data as large: if
init.HasField("data_location") and data_location==EXTERNAL consider it exceeding
the threshold, otherwise compute nbytes = product(dims) * sizeof(element_type)
(use the ONNX tensor data_type -> element size mapping or numpy dtype mapping)
and use that nbytes in the comparison against self.init_conversion_max_bytes;
keep the existing branch behavior (log, build exclude_consumers_names and call
_add_cast) when the computed size exceeds the limit.
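For reference, the proposed heuristic as a standalone sketch; approx_initializer_nbytes and _ELEM_SIZE are illustrative names, and the table only covers the float types relevant here:

import numpy as np
import onnx

_ELEM_SIZE = {
    onnx.TensorProto.FLOAT: 4,
    onnx.TensorProto.FLOAT16: 2,
    onnx.TensorProto.BFLOAT16: 2,
    onnx.TensorProto.DOUBLE: 8,
}

def approx_initializer_nbytes(init: onnx.TensorProto) -> int | None:
    """Approximate an initializer's size without materializing it in NumPy.

    Returns None for externally stored tensors, which callers should treat as
    too large to convert statically.
    """
    if init.data_location == onnx.TensorProto.EXTERNAL:
        return None
    if init.raw_data:
        return len(init.raw_data)  # fast path: the byte size is already known exactly
    num_elems = int(np.prod(init.dims)) if init.dims else 0
    return num_elems * _ELEM_SIZE.get(init.data_type, 0)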

np_array = numpy_helper.to_array(init)
# Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
if self._is_bf16(to_type) and self._is_fp32(from_type):
new_init = onnx.TensorProto()
new_init.dims.extend(np_array.shape)
new_init.name = init.name
new_init.data_type = onnx.TensorProto.BFLOAT16
bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
new_init.raw_data = bf16_bytes.tobytes()
else:
assert to_type.numpy_type is not None
data_max, data_lowest = (
np.finfo(to_type.numpy_type).max,
np.finfo(to_type.numpy_type).smallest_subnormal,
)
if np.any(np.abs(np_array) > data_max):
logger.warning(
f"Initializer {init.name} contains values larger than largest "
f"{to_type.str_short} value, values will be clamped to {data_max}."
)
np_array = np.clip(np_array, -1 * data_max, data_max)
if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
logger.warning(
f"Initializer {init.name} contains values smaller than smallest "
f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
)
np_array = np.where(
(np_array != 0.0) & (np.abs(np_array) < data_lowest),
data_lowest,
np_array,
)
new_array = np_array.astype(to_type.numpy_type)
new_init = numpy_helper.from_array(new_array, init.name)

return new_init
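As a quick illustration of the clamping and subnormal handling above, a hypothetical fp32-to-fp16 example with values chosen to hit both warning paths:

import numpy as np

f16 = np.finfo(np.float16)
np_array = np.array([1e5, 1e-8, 0.0, -7.0], dtype=np.float32)

# 1e5 exceeds the fp16 max (65504), so it gets clamped.
np_array = np.clip(np_array, -1 * f16.max, f16.max)
# 1e-8 is nonzero but below the smallest fp16 subnormal (~6e-8),
# so it is bumped up instead of silently flushing to zero.
np_array = np.where(
    (np_array != 0.0) & (np.abs(np_array) < f16.smallest_subnormal),
    f16.smallest_subnormal,
    np_array,
)
print(np_array.astype(np.float16))  # approximately [65504., 6e-08, 0., -7.]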

def _replace_tensor_name(
self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str