[Autocast] Optimize _convert_initializers runtime
#459
Conversation
Signed-off-by: Ali Boubezari <[email protected]>
Walkthrough

Introduces data-tracking dataclasses and substantially reworks the precision conversion initialization logic to handle mixed-precision consumption cases, distinguishing between low-precision, high-precision, and dual-consumption scenarios, with support for casting large initializers at runtime.
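For reference, a minimal sketch of the two tracking dataclasses mentioned above; the class and field names follow the review comments further down in this thread, while the exact field types and the helper method are assumptions:

```python
from dataclasses import dataclass, field

import onnx


@dataclass
class InputIndexTracker:
    """Tracks which input slot of a consumer node refers to a given initializer."""

    node: onnx.NodeProto
    node_index: int  # position of the initializer name in node.input


@dataclass
class InitializerConsumerTracker:
    """Tracks the low- and high-precision nodes that consume an initializer."""

    low_precision_nodes: list[InputIndexTracker] = field(default_factory=list)
    high_precision_nodes: list[InputIndexTracker] = field(default_factory=list)

    def is_mixed(self) -> bool:
        # Consumed in both precisions -> the initializer must be duplicated.
        return bool(self.low_precision_nodes) and bool(self.high_precision_nodes)
```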
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Conv as _convert_initializers
participant Track as Consumer Tracking
participant Route as Route by Case
participant CastInit as _cast_initializer
participant Runtime as Runtime (_add_cast)
Conv->>Track: Build initializer→consumers mapping
Track->>Track: Classify: low-precision, high-precision
Conv->>Route: Route by consumption pattern
alt Low-precision only
Route->>CastInit: Cast down (FP32→FP16/BF16)
CastInit->>CastInit: Clamp values, warn if adjusted
CastInit->>Runtime: Large initializer? Defer to runtime
else High-precision only
Route->>CastInit: Cast up (FP16/BF16→FP32)
CastInit->>CastInit: Create new TensorProto
else Mixed consumption
Route->>CastInit: Duplicate & adjust (cast one variant)
CastInit->>CastInit: Replace original, update node inputs
end
CastInit->>Conv: Return casted TensorProto or None
```
Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

The changes introduce multi-case branching logic, new data structures, and per-initializer handling with multiple edge cases (mixed consumption, large initializers, BF16 handling). While the changes are contained to one file, the substantial rework of core conversion logic and the introduction of new helper methods with nuanced behavior require careful reasoning across multiple scenarios.
Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Signed-off-by: Ali Boubezari <[email protected]>
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/onnx/autocast/precisionconverter.py (2)
99-103: Avoid mutable default argument for trt_plugins

Using [] as a default leaks state across instances.
Apply:
```diff
-        trt_plugins: list[str] | None = [],
+        trt_plugins: list[str] | None = None,
     ) -> None:
@@
-        self.trt_plugins = trt_plugins
+        self.trt_plugins = trt_plugins or []
```
388-403: Type annotation/doc mismatch in _get_node_initializers_map

Function docstring claims list[str], but it returns list[onnx.TensorProto].
Apply:
```diff
-    def _get_node_initializers_map(self) -> dict[str, list[str]]:
+    def _get_node_initializers_map(self) -> dict[str, list[onnx.TensorProto]]:
@@
-            dict[str, list[str]]: Mapping from node names to lists of initializer names.
+            dict[str, list[onnx.TensorProto]]: Mapping from node names to initializer protos.
```
🧹 Nitpick comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
44-58: Dataclasses look good; consider slots to reduce per-instance overhead

These trackers may be created per input edge; slots can lower memory/GC pressure in large graphs. Optional.
Example:
```diff
 @dataclass
 class InputIndexTracker:
+    __slots__ = ("node", "node_index")
     """A class that tracks the index of an input to a node."""
@@
 @dataclass
 class InitializerConsumerTracker:
+    __slots__ = ("low_precision_nodes", "high_precision_nodes")
     """A class that tracks the nodes that consume an initializer."""
```
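As a side note, on Python 3.10+ the same memory saving can come from the dataclass decorator itself instead of a hand-written `__slots__`; a small sketch (the field list mirrors the names used in this review and is otherwise an assumption):

```python
from dataclasses import dataclass

import onnx


@dataclass(slots=True)  # generates __slots__ from the fields (Python 3.10+)
class InputIndexTracker:
    """A class that tracks the index of an input to a node."""

    node: onnx.NodeProto
    node_index: int
```

A manually declared `__slots__` works too, but it must list every field and can drift out of sync when fields change.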
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
modelopt/onnx/autocast/precisionconverter.py (3 hunks)
🔇 Additional comments (1)
modelopt/onnx/autocast/precisionconverter.py (1)
636-649: Mixed-consumption path looks correct; verify node name uniqueness for exclude-consumers logic

Runtime-cast relies on matching consumer node names. If models contain empty/duplicate node names, exclusion may fail. Ensure GraphSanitizer enforces unique names before this pass.
You can verify uniqueness is enforced by checking sanitizer code or asserting invariants pre-pass. If needed, I can add a precondition check that auto-assigns unique names to unnamed nodes.
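If such a precondition check is added, it could look roughly like this sketch; the function name and where it runs (e.g., in GraphSanitizer) are assumptions:

```python
import onnx


def ensure_unique_node_names(graph: onnx.GraphProto) -> None:
    """Assign deterministic unique names to unnamed or duplicate-named nodes, in place."""
    seen: set[str] = set()
    for idx, node in enumerate(graph.node):
        name = node.name
        if not name or name in seen:
            # Fall back to a name derived from the op type and node index.
            name = f"{node.op_type}_{idx}"
            while name in seen:
                name += "_dup"
            node.name = name
        seen.add(name)
```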
```python
        if init.data_type not in {
            self.high_precision_type.onnx_type,
            self.low_precision_type.onnx_type,
        }:
            logger.debug(f"Initializer {init_name} is not a float, skipping")
            continue
```
Float-type check incorrectly skips BF16 when low_precision_type != 'bf16'
The condition treats BF16 initializers as “not a float” if low_precision_type is fp16, so such inits are never considered for conversion.
Apply:
```diff
-        if init.data_type not in {
-            self.high_precision_type.onnx_type,
-            self.low_precision_type.onnx_type,
-        }:
-            logger.debug(f"Initializer {init_name} is not a float, skipping")
+        supported_float_types = {
+            onnx.TensorProto.FLOAT,
+            onnx.TensorProto.FLOAT16,
+            onnx.TensorProto.BFLOAT16,
+        }
+        if init.data_type not in supported_float_types:
+            logger.debug(f"Initializer {init_name} is not a supported float type, skipping")
             continue
```

Optionally, add conversion support for FP16 <-> BF16 in `_cast_initializer` later.
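If FP16 <-> BF16 support is added to `_cast_initializer` later, the conversion itself could be sketched with ml_dtypes along these lines; this is an illustrative helper, not the PR's implementation:

```python
import ml_dtypes
import numpy as np
import onnx
from onnx import numpy_helper


def fp16_to_bf16_initializer(init: onnx.TensorProto) -> onnx.TensorProto:
    """Return a BF16 copy of an FP16 initializer (numpy lacks bfloat16, so go through ml_dtypes)."""
    fp16_array = numpy_helper.to_array(init).astype(np.float32)
    new_init = onnx.TensorProto()
    new_init.name = init.name
    new_init.dims.extend(fp16_array.shape)
    new_init.data_type = onnx.TensorProto.BFLOAT16
    # Store BF16 as a raw uint16 payload, mirroring the FP32 -> BF16 path in the PR.
    new_init.raw_data = fp16_array.astype(ml_dtypes.bfloat16).view(np.uint16).tobytes()
    return new_init
```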
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 550-555, the
current float-type check skips BF16 initializers when low_precision_type is fp16
because it only allows the two configured types; change the check to detect any
floating-point initializer (e.g., test init.data_type against the set of ONNX
float types like FP32, FP16, BF16 or use an existing helper that identifies
floating dtypes) so BF16 inits are not incorrectly skipped, then proceed with
conversion logic (and optionally add FP16<->BF16 handling in _cast_initializer
later).
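One way to express the "existing helper that identifies floating dtypes" idea from the prompt above is a small predicate like the following sketch; the helper name and the inclusion of DOUBLE are assumptions:

```python
import onnx

# ONNX float dtypes that autocast-style conversion may care about (assumed set).
_FLOAT_TENSOR_DTYPES = frozenset(
    {
        onnx.TensorProto.FLOAT,
        onnx.TensorProto.FLOAT16,
        onnx.TensorProto.BFLOAT16,
        onnx.TensorProto.DOUBLE,
    }
)


def is_float_initializer(init: onnx.TensorProto) -> bool:
    """Return True if the initializer stores a floating-point dtype."""
    return init.data_type in _FLOAT_TENSOR_DTYPES
```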
| """Cast an initializer to a new precision based on its consumer nodes. | ||
| if should_duplicate: | ||
| # Create a new low precision copy with a different name | ||
| new_name = f"{init.name}_{to_type.str_short}" | ||
| logger.debug( | ||
| f"Initializer {init.name} is shared, creating {to_type.str_short} copy as {new_name} due " | ||
| f"to node {node.name}" | ||
| ) | ||
| This method converts an initializer to a new precision while handling special cases like bfloat16 conversion | ||
| and providing warnings when values are clamped or replaced due to precision limits. | ||
| # Update the node to use the new initializer | ||
| for i, input_name in enumerate(node.input): | ||
| if input_name == init.name: | ||
| node.input[i] = new_name | ||
| break | ||
| Args: | ||
| init: The initializer to cast. | ||
| from_type: The original precision of the initializer. | ||
| to_type: The new precision to cast the initializer to. | ||
| if init.name in initializer_converted_dup: | ||
| return False | ||
| initializer_converted_dup.append(init.name) | ||
| else: | ||
| if init.name in initializer_converted: | ||
| return False | ||
| new_name = init.name | ||
| logger.debug( | ||
| f"Converting initializer {new_name} to {to_type.str_short} due to node {node.name}" | ||
| ) | ||
| initializer_converted.append(init.name) | ||
| self.model.graph.initializer.remove(init) | ||
|
|
||
| # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead | ||
| if self._is_bf16(to_type) and self._is_fp32(from_type): | ||
| new_init = onnx.TensorProto() | ||
| new_init.dims.extend(np_array.shape) | ||
| new_init.name = new_name | ||
| new_init.data_type = onnx.TensorProto.BFLOAT16 | ||
| bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16) | ||
| new_init.raw_data = bf16_bytes.tobytes() | ||
| else: | ||
| assert to_type.numpy_type is not None | ||
| data_max, data_lowest = ( | ||
| np.finfo(to_type.numpy_type).max, | ||
| np.finfo(to_type.numpy_type).smallest_subnormal, | ||
| ) | ||
| if np.any(np.abs(np_array) > data_max): | ||
| logger.warning( | ||
| f"Initializer {init.name} used by node {node.name} contains values larger than " | ||
| f"largest {to_type.str_short} value, values will be clamped to {data_max}." | ||
| ) | ||
| np_array = np.clip(np_array, -1 * data_max, data_max) | ||
| if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)): | ||
| logger.warning( | ||
| f"Initializer {init.name} used by node {node.name} contains values smaller than " | ||
| f"smallest {to_type.str_short} value, values will be replaced with {data_lowest:.1e}." | ||
| ) | ||
| np_array = np.where( | ||
| (np_array != 0.0) & (np.abs(np_array) < data_lowest), | ||
| data_lowest, | ||
| np_array, | ||
| ) | ||
| new_array = np_array.astype(to_type.numpy_type) | ||
| new_init = numpy_helper.from_array(new_array, new_name) | ||
| self.model.graph.initializer.extend([new_init]) | ||
| return True | ||
| return False | ||
| except Exception as e: | ||
| logger.error(f"Error converting initializer {init.name}: {e}") | ||
| return False | ||
| Returns: | ||
| onnx.TensorProto: The casted initializer. | ||
| """ |
Docstring return type is outdated
The function also returns None (runtime-cast path). Update docstring to reflect this.
Apply:
```diff
-        Returns:
-            onnx.TensorProto: The casted initializer.
+        Returns:
+            onnx.TensorProto | None: The casted initializer, or None when casting is deferred to runtime.
```
+ onnx.TensorProto | None: The casted initializer, or None when casting is deferred to runtime.📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| """Cast an initializer to a new precision based on its consumer nodes. | |
| if should_duplicate: | |
| # Create a new low precision copy with a different name | |
| new_name = f"{init.name}_{to_type.str_short}" | |
| logger.debug( | |
| f"Initializer {init.name} is shared, creating {to_type.str_short} copy as {new_name} due " | |
| f"to node {node.name}" | |
| ) | |
| This method converts an initializer to a new precision while handling special cases like bfloat16 conversion | |
| and providing warnings when values are clamped or replaced due to precision limits. | |
| # Update the node to use the new initializer | |
| for i, input_name in enumerate(node.input): | |
| if input_name == init.name: | |
| node.input[i] = new_name | |
| break | |
| Args: | |
| init: The initializer to cast. | |
| from_type: The original precision of the initializer. | |
| to_type: The new precision to cast the initializer to. | |
| if init.name in initializer_converted_dup: | |
| return False | |
| initializer_converted_dup.append(init.name) | |
| else: | |
| if init.name in initializer_converted: | |
| return False | |
| new_name = init.name | |
| logger.debug( | |
| f"Converting initializer {new_name} to {to_type.str_short} due to node {node.name}" | |
| ) | |
| initializer_converted.append(init.name) | |
| self.model.graph.initializer.remove(init) | |
| # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead | |
| if self._is_bf16(to_type) and self._is_fp32(from_type): | |
| new_init = onnx.TensorProto() | |
| new_init.dims.extend(np_array.shape) | |
| new_init.name = new_name | |
| new_init.data_type = onnx.TensorProto.BFLOAT16 | |
| bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16) | |
| new_init.raw_data = bf16_bytes.tobytes() | |
| else: | |
| assert to_type.numpy_type is not None | |
| data_max, data_lowest = ( | |
| np.finfo(to_type.numpy_type).max, | |
| np.finfo(to_type.numpy_type).smallest_subnormal, | |
| ) | |
| if np.any(np.abs(np_array) > data_max): | |
| logger.warning( | |
| f"Initializer {init.name} used by node {node.name} contains values larger than " | |
| f"largest {to_type.str_short} value, values will be clamped to {data_max}." | |
| ) | |
| np_array = np.clip(np_array, -1 * data_max, data_max) | |
| if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)): | |
| logger.warning( | |
| f"Initializer {init.name} used by node {node.name} contains values smaller than " | |
| f"smallest {to_type.str_short} value, values will be replaced with {data_lowest:.1e}." | |
| ) | |
| np_array = np.where( | |
| (np_array != 0.0) & (np.abs(np_array) < data_lowest), | |
| data_lowest, | |
| np_array, | |
| ) | |
| new_array = np_array.astype(to_type.numpy_type) | |
| new_init = numpy_helper.from_array(new_array, new_name) | |
| self.model.graph.initializer.extend([new_init]) | |
| return True | |
| return False | |
| except Exception as e: | |
| logger.error(f"Error converting initializer {init.name}: {e}") | |
| return False | |
| Returns: | |
| onnx.TensorProto: The casted initializer. | |
| """ | |
| """Cast an initializer to a new precision based on its consumer nodes. | |
| This method converts an initializer to a new precision while handling special cases like bfloat16 conversion | |
| and providing warnings when values are clamped or replaced due to precision limits. | |
| Args: | |
| init: The initializer to cast. | |
| from_type: The original precision of the initializer. | |
| to_type: The new precision to cast the initializer to. | |
| Returns: | |
| onnx.TensorProto | None: The casted initializer, or None when casting is deferred to runtime. | |
| """ |
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 658 to 670, the
docstring's Returns section incorrectly states it only returns onnx.TensorProto
despite the function sometimes returning None for runtime-cast paths; update the
Returns docstring to accurately reflect the possible return values (e.g.,
"onnx.TensorProto or None: The casted initializer, or None if casting is
deferred to runtime") and adjust wording to match project docstring style.
```python
            if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
                # The initializer is too large, so we need to convert it at runtime.
                logger.debug(
                    f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
                    "runtime instead"
                )
                exclude_consumers = (
                    low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
                )
                exclude_consumers_names: list[str] = []

                exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
                self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
                return None
```
Initializer size heuristic should not rely on raw_data only; handle float_data and EXTERNAL
Large tensors stored via float_data or external data will bypass the size check and be materialized in memory. Compute nbytes from dims/data_type and treat EXTERNAL as large.
Apply:
```diff
-        if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
+        # Robust size computation without loading into NumPy
+        def _elem_size(dt: int) -> int:
+            return {
+                onnx.TensorProto.FLOAT: 4,
+                onnx.TensorProto.FLOAT16: 2,
+                onnx.TensorProto.BFLOAT16: 2,
+                onnx.TensorProto.DOUBLE: 8,
+            }.get(dt, 0)
+
+        is_external = (
+            hasattr(init, "data_location")
+            and init.data_location == onnx.TensorProto.EXTERNAL
+        )
+        num_elems = int(np.prod(init.dims)) if len(init.dims) > 0 else 0
+        approx_nbytes = len(init.raw_data) if init.raw_data else num_elems * _elem_size(init.data_type)
+
+        if is_external or approx_nbytes > self.init_conversion_max_bytes:
             # The initializer is too large, so we need to convert it at runtime.
             logger.debug(
-                f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
-                "runtime instead"
+                f"Initializer {init.name} (~{approx_nbytes}B, external={is_external}) is large; "
+                "skip static conversion, insert runtime Cast instead"
             )
```
)Committable suggestion skipped: line range outside the PR's diff.
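An alternative size estimate can lean on ONNX's own helpers rather than a hand-rolled element-size table; a sketch, assuming a reasonably recent onnx package that provides `external_data_helper.uses_external_data` and `helper.tensor_dtype_to_np_dtype`:

```python
import numpy as np
import onnx
from onnx import external_data_helper, helper


def approx_initializer_nbytes(init: onnx.TensorProto) -> int | None:
    """Approximate in-memory size of an initializer; None means external data (treat as large)."""
    if external_data_helper.uses_external_data(init):
        return None  # stored on disk; avoid materializing it just to measure
    if init.raw_data:
        return len(init.raw_data)
    num_elems = int(np.prod(init.dims)) if init.dims else 1
    if init.data_type == onnx.TensorProto.BFLOAT16:
        itemsize = 2  # bfloat16 is not a native numpy dtype on older stacks
    else:
        itemsize = np.dtype(helper.tensor_dtype_to_np_dtype(init.data_type)).itemsize
    return num_elems * itemsize
```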
🤖 Prompt for AI Agents
In modelopt/onnx/autocast/precisionconverter.py around lines 686-700, the
initializer-size heuristic currently only checks init.raw_data which misses
large tensors stored in float_data or as EXTERNAL; update the check to compute
the initializer byte-size from its dims and data_type (or from float_data length
when present) and treat ONNX external data as large: if
init.HasField("data_location") and data_location==EXTERNAL consider it exceeding
the threshold, otherwise compute nbytes = product(dims) * sizeof(element_type)
(use the ONNX tensor data_type -> element size mapping or numpy dtype mapping)
and use that nbytes in the comparison against self.init_conversion_max_bytes;
keep the existing branch behavior (log, build exclude_consumers_names and call
_add_cast) when the computed size exceeds the limit.
```python
    def _cast_initializer(
        self,
        init: onnx.TensorProto,
        from_type: PrecisionTypes,
        to_type: PrecisionTypes,
        low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
        high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
    ) -> onnx.TensorProto | None:
```
This function is copied from the previous implementation
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@            Coverage Diff             @@
##             main     #459      +/-   ##
==========================================
- Coverage   73.41%   73.39%   -0.03%
==========================================
  Files         180      180
  Lines       18011    18110      +99
==========================================
+ Hits        13223    13291      +68
- Misses       4788     4819      +31
```

☔ View full report in Codecov by Sentry.
Thanks for the contribution @aboubezari!
@aboubezari Looks good. Added some minor comments. Thank you!
Signed-off-by: Ali Boubezari <[email protected]>
@galagam We should be all good to test & merge. Thank you!

/ok to test e2c0d39

@galagam, there was an error processing your request: See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

/ok to test adad68d
Signed-off-by: Ali Boubezari <[email protected]>
What does this PR do?
Type of change: Bug fix
Overview: Optimize the runtime of `_convert_initializers`. On our internal workloads, this change reduced total autocast runtime from 2 hours to 1 hour (a 50% reduction).

Implementation summary:

The current implementation of `_convert_initializers` is a greedy approach that loops through the graph node by node. Every time an initializer is cast, the initializer-to-node mappings are re-computed, which is extremely inefficient since that function loops through all nodes internally, causing overall runtime to be O(n^2).

Instead, we can loop through the initializers and pre-compute the low- and high-precision nodes that consume them. For each initializer, it boils down to 2 cases:

1. The initializer is consumed only by nodes of a single precision that differs from its own.
2. The initializer is consumed by both low-precision and high-precision nodes (mixed consumption).

For case (1), if the initializer doesn't match the consuming nodes, cast it and replace it. For case (2), create a duplicate and point the mismatching consumer nodes to the new one. There is no need to greedily re-compute mappings, as the initializer-to-consumer relationship is unique and won't ever cause downstream conflicts.
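A self-contained sketch of the classification step described above; the real PR stores InputIndexTracker objects rather than plain counts, so treat this as an illustration of the single-pass idea only:

```python
import onnx


def classify_initializer_consumption(
    graph: onnx.GraphProto, low_precision_nodes: set[str]
) -> dict[str, str]:
    """One pass over nodes: map each initializer name to 'low', 'high', or 'mixed' consumption."""
    init_names = {init.name for init in graph.initializer}
    low_users = dict.fromkeys(init_names, 0)
    high_users = dict.fromkeys(init_names, 0)

    for node in graph.node:  # single pass, instead of re-deriving mappings after every cast
        bucket = low_users if node.name in low_precision_nodes else high_users
        for inp in node.input:
            if inp in init_names:
                bucket[inp] += 1

    cases = {}
    for name in init_names:
        if low_users[name] and high_users[name]:
            cases[name] = "mixed"  # case (2): duplicate and repoint the mismatching consumers
        elif low_users[name] or high_users[name]:
            cases[name] = "low" if low_users[name] else "high"  # case (1): cast in place if needed
    return cases
```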
Before and After profiling screenshots were attached (images not reproduced here).
Testing
Unit tests, internal workload testing.
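As an illustration of the kind of unit-level check involved, here is a self-contained sketch that builds a tiny graph with a shared initializer using plain onnx APIs; it does not call the project's converter, and all names in it are hypothetical:

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper


def build_shared_initializer_model() -> onnx.ModelProto:
    """Tiny graph where one FP32 initializer feeds two Add nodes."""
    w = numpy_helper.from_array(np.ones((4,), dtype=np.float32), name="w")
    x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])
    y0 = helper.make_tensor_value_info("y0", TensorProto.FLOAT, [4])
    y1 = helper.make_tensor_value_info("y1", TensorProto.FLOAT, [4])
    add0 = helper.make_node("Add", ["x", "w"], ["y0"], name="add0")
    add1 = helper.make_node("Add", ["x", "w"], ["y1"], name="add1")
    graph = helper.make_graph([add0, add1], "shared_init", [x], [y0, y1], initializer=[w])
    return helper.make_model(graph)


def test_shared_initializer_has_two_consumers():
    model = build_shared_initializer_model()
    consumers = [n.name for n in model.graph.node if "w" in n.input]
    # A converter handling mixed consumption would duplicate "w" if add0 and add1
    # end up in different precisions; here we only assert the shared-consumer setup.
    assert consumers == ["add0", "add1"]
```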