215 changes: 161 additions & 54 deletions modelopt/onnx/autocast/precisionconverter.py
@@ -67,9 +67,6 @@ class InitializerConsumerTracker:

OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"]

# Temporarily block these ops in low precision, as they are not supported yet
OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop"])

# Mapping of op types to indices of inputs that should not be converted to low precision.
SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}}
SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}}
@@ -240,8 +237,8 @@ def convert(
tensor_to_producers=tensor_to_producers,
)

# Convert initializers to correct precision according to the consumer nodes
self._convert_initializers(
# Convert initializers to correct precision according to the consumer nodes (main graph + subgraphs)
self._convert_initializers_recursive(
low_precision_nodes=low_precision_nodes, high_precision_nodes=high_precision_nodes
)

@@ -250,17 +247,8 @@ def convert(
# Populate type information with inferred types
self.model = self._propagate_types_shapes_custom_ops(self.model)
else:
# Clear type/shape information for intermediates and outputs
for vi in self.model.graph.value_info:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
for out in self.model.graph.output:
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(out.type.tensor_type.shape.dim):
if d.dim_value:
out.type.tensor_type.shape.dim[idx].dim_param = "unk"
# Clear type/shape information for intermediates and outputs (including subgraphs)
self._clear_types_and_shapes_recursive(self.model.graph)
# Populate type information with inferred types
self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False)
self._ensure_types_are_defined()
@@ -285,6 +273,47 @@ def _ensure_types_are_defined(self):
if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED:
vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type

def _clear_types_and_shapes_recursive(
self, graph: onnx.GraphProto, is_subgraph: bool = False
) -> None:
"""Recursively clear type/shape information for a graph and all its subgraphs.

This is necessary for control flow operators (Scan, If, Loop), which carry subgraphs.

Args:
graph: The ONNX graph to clear types and shapes for.
is_subgraph: Whether this is a subgraph (True) or the main graph (False).
"""

def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> None:
logger.debug(
f"Clearing types/shapes in {'subgraph' if is_sub else 'main graph'}: {g.name}"
)

# Clear type/shape information for inputs (only for subgraphs, not main graph inputs)
if is_sub:
for inp in g.input:
if inp.type.HasField("tensor_type"):
inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(inp.type.tensor_type.shape.dim):
if d.dim_value:
inp.type.tensor_type.shape.dim[idx].dim_param = "unk"

# Clear type/shape information for intermediates and outputs
for vi in g.value_info:
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
if d.dim_value:
vi.type.tensor_type.shape.dim[idx].dim_param = "unk"

for out in g.output:
out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(out.type.tensor_type.shape.dim):
if d.dim_value:
out.type.tensor_type.shape.dim[idx].dim_param = "unk"

utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph)
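
A minimal sketch of what the clearing callback does to a single tensor, assuming a hypothetical value info "y" with static shape (1, 8); once cleared, shape inference can re-derive both type and shape:

import onnx
from onnx import helper

vi = helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [1, 8])
vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
for idx, d in enumerate(vi.type.tensor_type.shape.dim):
    if d.dim_value:
        vi.type.tensor_type.shape.dim[idx].dim_param = "unk"
# elem_type is now UNDEFINED and each static dim is the symbolic "unk"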

def _propagate_types_shapes_custom_ops(self, model):
"""Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications."""
logger.info("Propagating tensor shapes and types in model with custom ops.")
@@ -682,6 +711,118 @@ def _convert_initializers(
node.node.input[node.node_index] = new_init_name
self.model.graph.initializer.extend([new_init])

def _convert_initializers_recursive(
self, low_precision_nodes: list[str], high_precision_nodes: list[str]
) -> None:
"""Convert initializers in main graph and all subgraphs to appropriate precision.

For the main graph, per-consumer tracking determines each initializer's precision.
For subgraphs, precision is inherited from the parent control flow node and all
float initializers are converted to that precision (no runtime casts are inserted).

Args:
low_precision_nodes: List of node names in main graph that are low precision.
high_precision_nodes: List of node names in main graph that are high precision.
"""
# Convert main graph initializers with full consumer tracking
self._convert_initializers(low_precision_nodes, high_precision_nodes)

# Convert subgraph initializers - walk all subgraphs and convert based on parent node precision
low_precision_nodes_set = set(low_precision_nodes)

def _convert_subgraph_callback(
graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool
) -> None:
if not is_subgraph or parent is None:
return

# Inherit precision from parent control flow node
target_type = (
self.low_precision_type
if parent.name in low_precision_nodes_set
else self.high_precision_type
)

# Convert all float initializers to target precision
for init in graph.initializer:
if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type:
continue

from_type = (
self.high_precision_type
if init.data_type == self.high_precision_type.onnx_type
else self.low_precision_type
if init.data_type == self.low_precision_type.onnx_type
else None
)

if from_type is None:
logger.debug(
f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}"
)
continue

new_init = self._convert_initializer_data(init, from_type, target_type)
init.CopyFrom(new_init)

utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback)
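
A small sketch of what the subgraph pass touches, assuming a hypothetical model file: the initializers that live inside control-flow branches (one nesting level shown here; walk_subgraphs_recursive handles arbitrary depth):

import onnx

model = onnx.load("model_with_if.onnx")  # hypothetical path
for node in model.graph.node:
    for attr in node.attribute:
        if attr.type == onnx.AttributeProto.GRAPH:
            for init in attr.g.initializer:
                print(f"{node.op_type} '{node.name}' owns initializer '{init.name}'")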

def _convert_initializer_data(
self,
init: onnx.TensorProto,
from_type: PrecisionTypes,
to_type: PrecisionTypes,
) -> onnx.TensorProto:
"""Convert initializer data to a new precision.

This is the core conversion logic, extracted for reuse. It handles bfloat16
conversion and warns when values are clamped or replaced due to precision limits.

Args:
init: The initializer to convert.
from_type: The original precision of the initializer.
to_type: The new precision to cast the initializer to.

Returns:
onnx.TensorProto: The converted initializer.
"""
np_array = numpy_helper.to_array(init)

# NumPy has no native bfloat16 dtype; build the raw data via ml_dtypes instead
if self._is_bf16(to_type) and self._is_fp32(from_type):
new_init = onnx.TensorProto()
new_init.dims.extend(np_array.shape)
new_init.name = init.name
new_init.data_type = onnx.TensorProto.BFLOAT16
bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
new_init.raw_data = bf16_bytes.tobytes()
else:
assert to_type.numpy_type is not None
data_max, data_lowest = (
np.finfo(to_type.numpy_type).max,
np.finfo(to_type.numpy_type).smallest_subnormal,
)
if np.any(np.abs(np_array) > data_max):
logger.warning(
f"Initializer '{init.name}' contains values larger than largest "
f"{to_type.str_short} value, values will be clamped to {data_max}."
)
np_array = np.clip(np_array, -1 * data_max, data_max)
if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
logger.warning(
f"Initializer '{init.name}' contains values smaller than smallest "
f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
)
np_array = np.where(
(np_array != 0.0) & (np.abs(np_array) < data_lowest),
data_lowest,
np_array,
)
new_array = np_array.astype(to_type.numpy_type)
new_init = numpy_helper.from_array(new_array, init.name)

return new_init
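
A standalone sketch of the bfloat16 branch above, with a hypothetical initializer name "w". NumPy has no native bfloat16 dtype, so the FP32 data is converted via ml_dtypes and stored as raw bytes:

import ml_dtypes
import numpy as np
import onnx

arr = np.array([1.0, 3.14159, 65504.0], dtype=np.float32)
init = onnx.TensorProto()
init.name = "w"  # hypothetical name
init.dims.extend(arr.shape)
init.data_type = onnx.TensorProto.BFLOAT16
# bfloat16 keeps FP32's 8 exponent bits but only 7 mantissa bits
init.raw_data = arr.astype(ml_dtypes.bfloat16).view(np.uint16).tobytes()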

def _cast_initializer(
self,
init: onnx.TensorProto,
@@ -699,9 +840,11 @@ def _cast_initializer(
init: The initializer to cast.
from_type: The original precision of the initializer.
to_type: The new precision to cast the initializer to.
low_precision_nodes: Low precision nodes that consume this initializer.
high_precision_nodes: High precision nodes that consume this initializer.

Returns:
onnx.TensorProto: The casted initializer.
onnx.TensorProto | None: The converted initializer, or None if a runtime Cast node was inserted instead.
"""

def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
@@ -727,47 +870,11 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
exclude_consumers = (
low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
)
exclude_consumers_names: list[str] = []

exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
return None

np_array = numpy_helper.to_array(init)
# Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
if self._is_bf16(to_type) and self._is_fp32(from_type):
new_init = onnx.TensorProto()
new_init.dims.extend(np_array.shape)
new_init.name = init.name
new_init.data_type = onnx.TensorProto.BFLOAT16
bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
new_init.raw_data = bf16_bytes.tobytes()
else:
assert to_type.numpy_type is not None
data_max, data_lowest = (
np.finfo(to_type.numpy_type).max,
np.finfo(to_type.numpy_type).smallest_subnormal,
)
if np.any(np.abs(np_array) > data_max):
logger.warning(
f"Initializer {init.name} contains values larger than largest "
f"{to_type.str_short} value, values will be clamped to {data_max}."
)
np_array = np.clip(np_array, -1 * data_max, data_max)
if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
logger.warning(
f"Initializer {init.name} contains values smaller than smallest "
f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
)
np_array = np.where(
(np_array != 0.0) & (np.abs(np_array) < data_lowest),
data_lowest,
np_array,
)
new_array = np_array.astype(to_type.numpy_type)
new_init = numpy_helper.from_array(new_array, init.name)

return new_init
return self._convert_initializer_data(init, from_type, to_type)

def _replace_tensor_name(
self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
36 changes: 36 additions & 0 deletions modelopt/onnx/autocast/utils.py
@@ -23,6 +23,7 @@

import logging
from collections import defaultdict
from collections.abc import Callable

import onnx

@@ -122,6 +123,41 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int:
raise ValueError("Cast node does not have 'to' attribute")


def walk_subgraphs_recursive(
graph: onnx.GraphProto,
callback: Callable[[onnx.GraphProto, onnx.NodeProto | None, bool], None],
parent_node: onnx.NodeProto | None = None,
is_subgraph: bool = False,
) -> None:
"""Recursively walk through a graph and all its subgraphs, applying a callback.

This utility function traverses an ONNX graph and all nested subgraphs by examining
graph attributes in nodes. It works with standard control flow operators (Scan, If, Loop)
as well as custom operators that define subgraphs using ONNX graph attributes.

Args:
graph: The graph to walk.
callback: Function to call for each graph. Signature: callback(graph, parent_node, is_subgraph).
parent_node: The parent node containing this subgraph (None for main graph).
is_subgraph: Whether this is a subgraph (True) or the main graph (False).

Note:
Works with any node that has attributes of type AttributeProto.GRAPH or
AttributeProto.GRAPHS, including custom operators.
"""
# Apply callback to current graph
callback(graph, parent_node, is_subgraph)

# Recursively process subgraphs in control flow nodes
for node in graph.node:
for attr in node.attribute:
if attr.type == onnx.AttributeProto.GRAPH:
walk_subgraphs_recursive(attr.g, callback, parent_node=node, is_subgraph=True)
elif attr.type == onnx.AttributeProto.GRAPHS:
for subgraph in attr.graphs:
walk_subgraphs_recursive(subgraph, callback, parent_node=node, is_subgraph=True)
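
A minimal usage sketch, assuming model is an already-loaded onnx.ModelProto: a callback that reports the node count of every graph, nested or not:

def _count_nodes(g: onnx.GraphProto, parent: onnx.NodeProto | None, is_subgraph: bool) -> None:
    owner = f"subgraph of {parent.op_type} '{parent.name}'" if is_subgraph else "main graph"
    print(f"{owner} '{g.name}': {len(g.node)} nodes")

walk_subgraphs_recursive(model.graph, _count_nodes)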


def get_op_types_not_supported_in_low_precision(
model: onnx.ModelProto,
min_opset: int,