
Conversation

@aboubezari
Contributor

@aboubezari aboubezari commented Oct 22, 2025

What does this PR do?

Type of change: Bug fix

Overview: Optimize the runtime of _convert_initializers. On our internal workloads, this change reduced total autocast runtime from 2 hours to 1 hour (a 50% reduction).

Implementation summary:
The current implementation of _convert_initializers is a greedy approach that loops through the graph node by node. Every time an initializer is cast, the initializer-to-node mappings are re-computed, which is extremely inefficient: the recomputation itself loops through all nodes, making the overall runtime O(n^2).
Instead, we can loop through the initializers and pre-compute the low- and high-precision nodes that consume each one. For each initializer, it boils down to two cases:

  1. The initializer is used only by low-precision nodes or only by high-precision nodes.
  2. The initializer is used by both high- and low-precision nodes.

For case (1), if the initializer's precision doesn't match its consuming nodes, cast it and replace it in place. For case (2), create a duplicate in the other precision and point the mismatching consumer nodes to the new copy. There is no need to greedily re-compute mappings, because the initializer-to-consumer relationship is unique and won't ever cause downstream conflicts.
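
A minimal sketch of the single-pass idea, assuming a plain FP32-to-FP16 conversion and none of the BF16, clamping, or large-initializer runtime-cast handling in the real implementation (function and variable names are illustrative):

from collections import defaultdict

import numpy as np
from onnx import numpy_helper


def convert_initializers_single_pass(graph, low_precision_nodes, high_precision_nodes):
    """Illustrative single pass over initializers instead of rescanning the graph per cast."""
    low_names = {n.name for n in low_precision_nodes}
    high_names = {n.name for n in high_precision_nodes}
    init_by_name = {init.name: init for init in graph.initializer}

    # Pre-compute initializer -> (low-precision consumers, high-precision consumers)
    # in one sweep over the nodes.
    consumers = defaultdict(lambda: ([], []))
    for node in graph.node:
        for idx, inp in enumerate(node.input):
            if inp not in init_by_name:
                continue
            if node.name in low_names:
                consumers[inp][0].append((node, idx))
            elif node.name in high_names:
                consumers[inp][1].append((node, idx))

    for name, (low, high) in consumers.items():
        init = init_by_name[name]
        arr = numpy_helper.to_array(init)
        if low and not high:
            # Case 1: only low-precision consumers -> cast the initializer in place.
            graph.initializer.remove(init)
            graph.initializer.append(numpy_helper.from_array(arr.astype(np.float16), name))
        elif low and high:
            # Case 2: shared initializer -> keep the original for high-precision consumers,
            # add a low-precision duplicate and repoint only the low-precision inputs.
            dup_name = f"{name}_fp16"
            graph.initializer.append(numpy_helper.from_array(arr.astype(np.float16), dup_name))
            for node, idx in low:
                node.input[idx] = dup_name
        # Only high-precision consumers: nothing to cast in this FP32-source sketch.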

Before: [profiling screenshot]
After: [profiling screenshot]

Testing

Unit tests, internal workload testing

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No

Additional Information

Summary by CodeRabbit

  • Improvements
    • Enhanced precision conversion handling for ONNX models with improved support for mixed-precision scenarios.
    • Optimized processing of large model weights through deferred runtime casting.
    • Extended support for FP32, FP16, and BF16 precision conversions with better value handling.

@aboubezari aboubezari requested a review from a team as a code owner October 22, 2025 16:32
@aboubezari aboubezari requested a review from i-riyad October 22, 2025 16:32
@copy-pr-bot

copy-pr-bot bot commented Oct 22, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Oct 22, 2025

Walkthrough

Introduces data-tracking dataclasses and substantially reworks precision conversion initialization logic to handle mixed-precision consumption cases, distinguishing between low-precision, high-precision, and dual-consumption scenarios with support for large initializer runtime casting.

Changes

Cohort / File(s) | Summary
Precision Converter Enhancement (modelopt/onnx/autocast/precisionconverter.py) | Added two new dataclasses (InputIndexTracker, InitializerConsumerTracker) for tracking initializer consumption patterns. Extended imports (defaultdict, dataclass, field). Reworked _convert_initializers to build consumer mappings, skip incompatible initializers, and handle three consumption cases (low-only, high-only, mixed). Added _cast_initializer helper method supporting FP32/FP16/BF16 conversions with value clamping warnings, large initializer deferred casting via runtime, and ONNX TensorProto creation. Expanded in-code documentation for edge case handling.
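
Based on the field names referenced elsewhere in this review (node/node_index and low_precision_nodes/high_precision_nodes), the two trackers are roughly shaped as follows; the exact annotations and defaults are assumptions, not the literal source:

from dataclasses import dataclass, field

import onnx


@dataclass
class InputIndexTracker:
    """Tracks which input slot of a node refers to a given initializer."""

    node: onnx.NodeProto
    node_index: int  # position within node.input that names the initializer


@dataclass
class InitializerConsumerTracker:
    """Groups the consumers of one initializer by target precision."""

    low_precision_nodes: list[InputIndexTracker] = field(default_factory=list)
    high_precision_nodes: list[InputIndexTracker] = field(default_factory=list)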

Sequence Diagram(s)

sequenceDiagram
    participant Conv as _convert_initializers
    participant Track as Consumer Tracking
    participant Route as Route by Case
    participant CastInit as _cast_initializer
    participant Runtime as Runtime (_add_cast)

    Conv->>Track: Build initializer→consumers mapping
    Track->>Track: Classify: low-precision, high-precision
    Conv->>Route: Route by consumption pattern
    
    alt Low-precision only
        Route->>CastInit: Cast down (FP32→FP16/BF16)
        CastInit->>CastInit: Clamp values, warn if adjusted
        CastInit->>Runtime: Large initializer? Defer to runtime
    else High-precision only
        Route->>CastInit: Cast up (FP16/BF16→FP32)
        CastInit->>CastInit: Create new TensorProto
    else Mixed consumption
        Route->>CastInit: Duplicate & adjust (cast one variant)
        CastInit->>CastInit: Replace original, update node inputs
    end
    
    CastInit->>Conv: Return casted TensorProto or None

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

The changes introduce multi-case branching logic, new data structures, and per-initializer handling with multiple edge cases (mixed consumption, large initializers, BF16 handling). While the changes are contained to one file, the substantial rework of core conversion logic and introduction of new helper methods with nuanced behavior require careful reasoning across multiple scenarios.

Poem

🐰✨ Precision paths now split with care,
Low and high consumers everywhere!
Mixed-precision mixed no more—
Each initializer finds its door,
With casting wisdom, data flows fair!

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name | Status | Explanation
Title Check | ✅ Passed | The pull request title "[Autocast] Optimize _convert_initializers runtime" is directly related to the main change in the changeset. The PR's primary objective is to optimize the runtime of the _convert_initializers function, which is exactly what the title conveys. The title is specific and clear: it identifies the exact function being optimized and indicates that runtime performance is the focus, avoiding vague terms or noise. A teammate scanning commit history would immediately understand that this PR improves the execution speed of a specific function. The changes (new dataclasses, refactored algorithm, helper methods) all support and implement this stated optimization goal.
Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%.
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.


Signed-off-by: Ali Boubezari <[email protected]>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)

99-103: Avoid mutable default argument for trt_plugins

Using [] as a default leaks state across instances.

Apply:

-        trt_plugins: list[str] | None = [],
+        trt_plugins: list[str] | None = None,
     ) -> None:
@@
-        self.trt_plugins = trt_plugins
+        self.trt_plugins = trt_plugins or []

388-403: Type annotation/doc mismatch in _get_node_initializers_map

Function docstring claims list[str], but it returns list[onnx.TensorProto].

Apply:

-    def _get_node_initializers_map(self) -> dict[str, list[str]]:
+    def _get_node_initializers_map(self) -> dict[str, list[onnx.TensorProto]]:
@@
-            dict[str, list[str]]: Mapping from node names to lists of initializer names.
+            dict[str, list[onnx.TensorProto]]: Mapping from node names to initializer protos.
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

44-58: Dataclasses look good; consider slots to reduce per-instance overhead

These trackers may be created per input edge; slots can lower memory/GC pressure in large graphs. Optional.

Example:

 @dataclass
 class InputIndexTracker:
+    __slots__ = ("node", "node_index")
     """A class that tracks the index of an input to a node."""
@@
 @dataclass
 class InitializerConsumerTracker:
+    __slots__ = ("low_precision_nodes", "high_precision_nodes")
     """A class that tracks the nodes that consume an initializer."""
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between d0e83ed and 2673939.

📒 Files selected for processing (1)
  • modelopt/onnx/autocast/precisionconverter.py (3 hunks)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)

636-649: Mixed-consumption path looks correct; verify node name uniqueness for exclude-consumers logic

Runtime-cast relies on matching consumer node names. If models contain empty/duplicate node names, exclusion may fail. Ensure GraphSanitizer enforces unique names before this pass.

You can verify uniqueness is enforced by checking sanitizer code or asserting invariants pre-pass. If needed, I can add a precondition check that auto-assigns unique names to unnamed nodes.
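
For illustration only, such a precondition check could look roughly like this (the helper name and placement are hypothetical and not part of this PR):

import onnx


def ensure_unique_node_names(graph: onnx.GraphProto) -> None:
    """Hypothetical check: give unnamed nodes deterministic names and reject duplicates."""
    seen: set[str] = set()
    for i, node in enumerate(graph.node):
        if not node.name:
            node.name = f"{node.op_type}_{i}"  # deterministic fallback for unnamed nodes
        if node.name in seen:
            raise ValueError(f"Duplicate node name detected: {node.name}")
        seen.add(node.name)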

Comment on lines 550 to 555
if init.data_type not in {
    self.high_precision_type.onnx_type,
    self.low_precision_type.onnx_type,
}:
    logger.debug(f"Initializer {init_name} is not a float, skipping")
    continue
Contributor


⚠️ Potential issue | 🟠 Major

Float-type check incorrectly skips BF16 when low_precision_type != 'bf16'

The condition treats BF16 initializers as “not a float” if low_precision_type is fp16, so such inits are never considered for conversion.

Apply:

-            if init.data_type not in {
-                self.high_precision_type.onnx_type,
-                self.low_precision_type.onnx_type,
-            }:
-                logger.debug(f"Initializer {init_name} is not a float, skipping")
+            supported_float_types = {
+                onnx.TensorProto.FLOAT,
+                onnx.TensorProto.FLOAT16,
+                onnx.TensorProto.BFLOAT16,
+            }
+            if init.data_type not in supported_float_types:
+                logger.debug(f"Initializer {init_name} is not a supported float type, skipping")
                 continue

Optionally, add conversion support for FP16 <-> BF16 in _cast_initializer later.
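
If that support is added later, a minimal FP16 -> BF16 cast could be sketched with ml_dtypes along the lines of the BF16 path already in this PR (illustrative, not part of the change):

import ml_dtypes
import numpy as np
import onnx
from onnx import numpy_helper


def fp16_to_bf16_initializer(init: onnx.TensorProto) -> onnx.TensorProto:
    """Illustrative FP16 -> BF16 cast; numpy has no bfloat16, so store the raw uint16 bits."""
    arr = numpy_helper.to_array(init).astype(np.float32).astype(ml_dtypes.bfloat16)
    out = onnx.TensorProto()
    out.name = init.name
    out.data_type = onnx.TensorProto.BFLOAT16
    out.dims.extend(arr.shape)
    out.raw_data = arr.view(np.uint16).tobytes()
    return out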

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
supported_float_types = {
    onnx.TensorProto.FLOAT,
    onnx.TensorProto.FLOAT16,
    onnx.TensorProto.BFLOAT16,
}
if init.data_type not in supported_float_types:
    logger.debug(f"Initializer {init_name} is not a supported float type, skipping")
    continue
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 550-555, the
current float-type check skips BF16 initializers when low_precision_type is fp16
because it only allows the two configured types; change the check to detect any
floating-point initializer (e.g., test init.data_type against the set of ONNX
float types like FP32, FP16, BF16 or use an existing helper that identifies
floating dtypes) so BF16 inits are not incorrectly skipped, then proceed with
conversion logic (and optionally add FP16<->BF16 handling in _cast_initializer
later).

Comment on lines +658 to +670
"""Cast an initializer to a new precision based on its consumer nodes.
if should_duplicate:
# Create a new low precision copy with a different name
new_name = f"{init.name}_{to_type.str_short}"
logger.debug(
f"Initializer {init.name} is shared, creating {to_type.str_short} copy as {new_name} due "
f"to node {node.name}"
)
This method converts an initializer to a new precision while handling special cases like bfloat16 conversion
and providing warnings when values are clamped or replaced due to precision limits.
# Update the node to use the new initializer
for i, input_name in enumerate(node.input):
if input_name == init.name:
node.input[i] = new_name
break
Args:
init: The initializer to cast.
from_type: The original precision of the initializer.
to_type: The new precision to cast the initializer to.
if init.name in initializer_converted_dup:
return False
initializer_converted_dup.append(init.name)
else:
if init.name in initializer_converted:
return False
new_name = init.name
logger.debug(
f"Converting initializer {new_name} to {to_type.str_short} due to node {node.name}"
)
initializer_converted.append(init.name)
self.model.graph.initializer.remove(init)

# Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
if self._is_bf16(to_type) and self._is_fp32(from_type):
new_init = onnx.TensorProto()
new_init.dims.extend(np_array.shape)
new_init.name = new_name
new_init.data_type = onnx.TensorProto.BFLOAT16
bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
new_init.raw_data = bf16_bytes.tobytes()
else:
assert to_type.numpy_type is not None
data_max, data_lowest = (
np.finfo(to_type.numpy_type).max,
np.finfo(to_type.numpy_type).smallest_subnormal,
)
if np.any(np.abs(np_array) > data_max):
logger.warning(
f"Initializer {init.name} used by node {node.name} contains values larger than "
f"largest {to_type.str_short} value, values will be clamped to {data_max}."
)
np_array = np.clip(np_array, -1 * data_max, data_max)
if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
logger.warning(
f"Initializer {init.name} used by node {node.name} contains values smaller than "
f"smallest {to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
)
np_array = np.where(
(np_array != 0.0) & (np.abs(np_array) < data_lowest),
data_lowest,
np_array,
)
new_array = np_array.astype(to_type.numpy_type)
new_init = numpy_helper.from_array(new_array, new_name)
self.model.graph.initializer.extend([new_init])
return True
return False
except Exception as e:
logger.error(f"Error converting initializer {init.name}: {e}")
return False
Returns:
onnx.TensorProto: The casted initializer.
"""
Contributor


⚠️ Potential issue | 🟡 Minor

Docstring return type is outdated

The function also returns None (runtime-cast path). Update docstring to reflect this.

Apply:

-        Returns:
-            onnx.TensorProto: The casted initializer.
+        Returns:
+            onnx.TensorProto | None: The casted initializer, or None when casting is deferred to runtime.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"""Cast an initializer to a new precision based on its consumer nodes.

This method converts an initializer to a new precision while handling special cases like bfloat16 conversion
and providing warnings when values are clamped or replaced due to precision limits.

Args:
    init: The initializer to cast.
    from_type: The original precision of the initializer.
    to_type: The new precision to cast the initializer to.

Returns:
    onnx.TensorProto | None: The casted initializer, or None when casting is deferred to runtime.
"""
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 658 to 670, the
docstring's Returns section incorrectly states it only returns onnx.TensorProto
despite the function sometimes returning None for runtime-cast paths; update the
Returns docstring to accurately reflect the possible return values (e.g.,
"onnx.TensorProto or None: The casted initializer, or None if casting is
deferred to runtime") and adjust wording to match project docstring style.

Comment on lines +686 to +700
if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
    # The initializer is too large, so we need to convert it at runtime.
    logger.debug(
        f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
        "runtime instead"
    )
    exclude_consumers = (
        low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
    )
    exclude_consumers_names: list[str] = []

    exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
    self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
    return None

Contributor


⚠️ Potential issue | 🟠 Major

Initializer size heuristic should not rely on raw_data only; handle float_data and EXTERNAL

Large tensors stored via float_data or external data will bypass the size check and be materialized in memory. Compute nbytes from dims/data_type and treat EXTERNAL as large.

Apply:

-        if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
+        # Robust size computation without loading into NumPy
+        def _elem_size(dt: int) -> int:
+            return {
+                onnx.TensorProto.FLOAT: 4,
+                onnx.TensorProto.FLOAT16: 2,
+                onnx.TensorProto.BFLOAT16: 2,
+                onnx.TensorProto.DOUBLE: 8,
+            }.get(dt, 0)
+
+        is_external = (
+            hasattr(init, "data_location")
+            and init.data_location == onnx.TensorProto.EXTERNAL
+        )
+        num_elems = int(np.prod(init.dims)) if len(init.dims) > 0 else 0
+        approx_nbytes = len(init.raw_data) if init.raw_data else num_elems * _elem_size(init.data_type)
+
+        if is_external or approx_nbytes > self.init_conversion_max_bytes:
             # The initializer is too large, so we need to convert it at runtime.
             logger.debug(
-                f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
-                "runtime instead"
+                f"Initializer {init.name} (~{approx_nbytes}B, external={is_external}) is large; "
+                "skip static conversion, insert runtime Cast instead"
             )

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 686-700, the
initializer-size heuristic currently only checks init.raw_data which misses
large tensors stored in float_data or as EXTERNAL; update the check to compute
the initializer byte-size from its dims and data_type (or from float_data length
when present) and treat ONNX external data as large: if
init.HasField("data_location") and data_location==EXTERNAL consider it exceeding
the threshold, otherwise compute nbytes = product(dims) * sizeof(element_type)
(use the ONNX tensor data_type -> element size mapping or numpy dtype mapping)
and use that nbytes in the comparison against self.init_conversion_max_bytes;
keep the existing branch behavior (log, build exclude_consumers_names and call
_add_cast) when the computed size exceeds the limit.

Comment on lines +650 to +657
def _cast_initializer(
    self,
    init: onnx.TensorProto,
    from_type: PrecisionTypes,
    to_type: PrecisionTypes,
    low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
    high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
) -> onnx.TensorProto | None:
Contributor Author


This function is copied from the previous implementation

@codecov

codecov bot commented Oct 23, 2025

Codecov Report

❌ Patch coverage is 79.41176% with 21 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.39%. Comparing base (d0e83ed) to head (adad68d).
⚠️ Report is 10 commits behind head on main.

Files with missing lines | Patch % | Lines
modelopt/onnx/autocast/precisionconverter.py | 79.41% | 21 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #459      +/-   ##
==========================================
- Coverage   73.41%   73.39%   -0.03%     
==========================================
  Files         180      180              
  Lines       18011    18110      +99     
==========================================
+ Hits        13223    13291      +68     
- Misses       4788     4819      +31     

☔ View full report in Codecov by Sentry.

@galagam
Contributor

galagam commented Oct 23, 2025

Thanks for the contribution @aboubezari !
Overall this looks great, and well overdue. I'll review in depth early next week.

Contributor

@galagam galagam left a comment


@aboubezari Looks good. Added some minor comments. Thank you!

Signed-off-by: Ali Boubezari <[email protected]>
@aboubezari
Contributor Author

@galagam We should be all good to test & merge. Thank you!

@galagam
Contributor

galagam commented Oct 27, 2025

/ok to test e2c0d39

@copy-pr-bot

copy-pr-bot bot commented Oct 27, 2025

/ok to test e2c0d39

@galagam, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@galagam
Contributor

galagam commented Oct 27, 2025

/ok to test adad68d

@galagam galagam enabled auto-merge (squash) October 27, 2025 17:49
@galagam galagam merged commit 41de55f into NVIDIA:main Oct 27, 2025
26 checks passed
kevalmorabia97 pushed a commit that referenced this pull request Oct 30, 2025
