
Conversation

@gcunhase
Contributor

@gcunhase gcunhase commented Oct 20, 2025

What does this PR do?

Type of change: Bug fix

Overview: Fixed an issue where using --calibration_shapes failed with a SymbolicShapeInference error. Replaced it with the infer_shapes function from modelopt.onnx.utils.

Usage

$ python -m modelopt.onnx.quantization --onnx_path=$MODEL_NAME.onnx

Testing

Model in bug 5597849.

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: No
  • Did you update Changelog?: No

Summary by CodeRabbit

  • Refactor

    • Improved shape inference used during ONNX model optimization for more reliable exclusion of certain operations.
    • Calibration input shapes are now initialized only when per-node calibration is active, reducing unnecessary automatic shape loading.
  • Chores

    • Internal imports and implementation updated; documentation clarified to reflect the new shape inference workflow.

@gcunhase gcunhase requested a review from a team as a code owner October 20, 2025 22:16
@gcunhase gcunhase requested a review from ajrasane October 20, 2025 22:16
@coderabbitai
Contributor

coderabbitai bot commented Oct 20, 2025

Walkthrough

Renamed the helper _exclude_matmuls_by_symbolic_inference to _exclude_matmuls_by_shape_inference, replaced SymbolicShapeInference.infer_shapes() with infer_shapes() from modelopt.onnx.utils, adjusted input_shapes initialization, and gated automatic loading of model input shapes to run only when calibrate_per_node is true.
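
In code terms, the swap amounts to roughly the following. The SymbolicShapeInference import path shown is the usual onnxruntime location and is an assumption for illustration; only modelopt.onnx.utils.infer_shapes is taken from this change.

# Before (assumed prior usage of ONNX Runtime's symbolic shape inference):
# from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
# model = SymbolicShapeInference.infer_shapes(model)

# After: the in-repo helper that wraps ONNX shape inference.
from modelopt.onnx.utils import infer_shapes

model = infer_shapes(model)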

Changes

  • Shape inference & matmul exclusion (modelopt/onnx/quantization/graph_utils.py): Renamed _exclude_matmuls_by_symbolic_inference to _exclude_matmuls_by_shape_inference; removed the SymbolicShapeInference import; added an infer_shapes import from modelopt.onnx.utils; replaced calls to SymbolicShapeInference.infer_shapes(model) with infer_shapes(model); changed input_shapes initialization to {}, then populated it conditionally for str/dict; updated the docstring and call site.
  • Calibration shapes gating (modelopt/onnx/quantization/quantize.py): get_input_shapes is now called to load model input shapes only when calibrate_per_node is True and calibration_shapes is not provided, eliminating the prior unconditional loading when calibration_shapes was missing, as sketched below.
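
A minimal sketch of that gated flow, under the assumption that the surrounding variables are already in scope (only get_input_shapes, calibrate_per_node, and calibration_shapes are named in this PR; everything else is illustrative):

# Illustrative only, not the actual quantize() implementation.
if calibrate_per_node and calibration_shapes is None:
    # Per-node calibration needs concrete input shapes, so auto-load them
    # from the model only in this case.
    calibration_shapes = get_input_shapes(model)
# Otherwise no automatic shape loading happens; downstream logic can still
# fall back to the calibration_data_reader path.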

Sequence Diagram(s)

sequenceDiagram
    participant Finder as find_nodes_from_matmul_to_exclude
    participant Excluder as _exclude_matmuls_by_shape_inference
    participant Infer as infer_shapes
    Finder->>Excluder: call(matmul_nodes, calibration_shapes, model)
    alt calibration_shapes provided
        Excluder->>Infer: infer_shapes(model, input_shapes)
        Infer-->>Excluder: model_with_inferred_shapes
        Excluder->>Excluder: evaluate shapes -> collect exclusions
    else calibration_shapes absent
        Excluder->>Excluder: use existing value_info -> collect exclusions
    end
    Excluder-->>Finder: return excluded_matmul_node_names
sequenceDiagram
    participant Quantize as quantize()
    participant Loader as get_input_shapes
    Quantize->>Quantize: has calibrate_per_node?
    alt calibrate_per_node == true and calibration_shapes missing
        Quantize->>Loader: get_input_shapes(model)
        Loader-->>Quantize: input_shapes
    else
        Quantize->>Quantize: skip automatic input shape loading
    end
    Quantize-->>Quantize: proceed with calibration using available shapes

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐰 I hopped through code with a nimble bound,

renamed a helper and smoothed the ground,
shapes now inferred by a lighter call,
calibration waits when per-node is all,
nibble the bugs, then bounce around! 🥕

Pre-merge checks

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The PR title "[5597849] Fix for 'SymbolicShapeInference' error" accurately describes the core changes made in the pull request. The changeset's main modifications (renaming internal functions, replacing SymbolicShapeInference.infer_shapes with infer_shapes from modelopt.onnx.utils, and adjusting calibration shape handling) are all directly part of fixing the SymbolicShapeInference compatibility issue identified in bug 5597849. The title is specific enough that a teammate reviewing history would quickly understand it addresses a bug fix related to shape inference, and the bug reference makes it traceable.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 543c2ca and 663d9f0.

📒 Files selected for processing (2)
  • modelopt/onnx/quantization/graph_utils.py (5 hunks)
  • modelopt/onnx/quantization/quantize.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/graph_utils.py (1)
modelopt/onnx/utils.py (2)
  • infer_shapes (723-736)
  • parse_shapes_spec (234-247)
🔇 Additional comments (5)
modelopt/onnx/quantization/graph_utils.py (4)

38-38: LGTM: Import addition is necessary and improves large model handling.

The infer_shapes utility from modelopt.onnx.utils properly handles large models (>2GB) by using external data format, which is an improvement over the previous approach.
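
For context, a wrapper with that behavior typically falls back to path-based inference with external data once a model exceeds the 2GB protobuf limit. A minimal sketch, assuming onnx's standard shape_inference and external-data APIs (this is not the actual modelopt.onnx.utils.infer_shapes implementation):

import os
import tempfile

import onnx
from onnx import checker, shape_inference


def infer_shapes_sketch(model: onnx.ModelProto) -> onnx.ModelProto:
    """Sketch: run ONNX shape inference, spilling >2GB models to disk."""
    if model.ByteSize() < checker.MAXIMUM_PROTOBUF:
        return shape_inference.infer_shapes(model)
    # Too large for in-memory protobuf inference: save with external data
    # and use the path-based API instead.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, "model.onnx")
        inferred_path = os.path.join(tmp_dir, "model_inferred.onnx")
        onnx.save(model, model_path, save_as_external_data=True)
        shape_inference.infer_shapes_path(model_path, inferred_path)
        return onnx.load(inferred_path)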


969-971: LGTM: Function rename improves clarity.

The rename from _exclude_matmuls_by_symbolic_inference to _exclude_matmuls_by_shape_inference better reflects the updated implementation. Since this is a private function, there's no public API impact.


1061-1064: LGTM: Function signature and docstring are clear.

The updated docstring accurately describes the function's purpose using shape inference.


1073-1094: Implementation improvements look good; consider adding error handling.

The changes improve robustness:

  1. Lines 1073-1079: Explicit initialization of input_shapes as an empty dict is safer and prevents potential None issues.
  2. Line 1092: Using infer_shapes(model) instead of SymbolicShapeInference.infer_shapes(model) is an improvement as it handles large models properly.
  3. Line 1094: Adding model outputs to value_info_map ensures MatMul nodes whose outputs are model outputs can be properly analyzed.

However, the infer_shapes(model) call might fail for some models (e.g., models with unsupported ops or invalid graphs). Consider wrapping it in a try-except block with a clear error message.

try:
    model = infer_shapes(model)
except Exception as e:
    raise RuntimeError(
        f"Shape inference failed for model. This is required to determine MatMul exclusions. "
        f"Error: {e}"
    ) from e
modelopt/onnx/quantization/quantize.py (1)

473-474: Backward compatibility concern addressed by intentional design.

The gating condition change (restricting auto-loading to calibrate_per_node=True) is intentional and supported by robust fallback logic. When calibrate_per_node=False and calibration_shapes is not provided, find_nodes_from_matmul_to_exclude() properly falls back to actual inference via calibration_data_reader (graph_utils.py:968-976), rather than symbolic shape inference. Both paths are valid approaches.

No tests or documented examples were found indicating users relied on the old unconditional auto-loading behavior when calibrate_per_node=False. The change fixes the SymbolicShapeInference error mentioned in the PR description while maintaining functional correctness through the existing fallback mechanism.



Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
modelopt/onnx/quantization/graph_utils.py (2)

968-976: Use explicit None-check to select the shape-inference path.

Truthiness check can misroute when calibration_shapes is {} or "".

-    if calibration_shapes:
+    if calibration_shapes is not None:
         nodes_to_exclude = _exclude_matmuls_by_shape_inference(
             model, matmul_nodes, calibration_shapes
         )

1061-1113: Shape inference path: add robustness (output lookup, input validation, minor wording).

  • Include graph outputs in shape lookup; current code only checks value_info and may falsely error if a MatMul/Gemm output is a graph output.
  • Validate calibration_shapes dims are positive integers.
  • Minor: adjust comment to “shape inference” for consistency.
 def _exclude_matmuls_by_shape_inference(
     model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
 ) -> list[str]:
     """Use shape inference to find MatMuls with dimension 1."""
-    # Prepare model for symbolic inference
+    # Prepare model for shape inference
     for graph_input in model.graph.input:
         for dim in graph_input.type.tensor_type.shape.dim:
             if dim.HasField("dim_param"):
                 dim.Clear()
                 dim.dim_value = 1

     # Apply calibration shapes if provided
     input_shapes = {}
     if calibration_shapes:
         input_shapes = (
             parse_shapes_spec(calibration_shapes)
             if isinstance(calibration_shapes, str)
             else calibration_shapes
         )
     for graph_input in model.graph.input:
         if graph_input.name in input_shapes:
             input_shape = input_shapes[graph_input.name]
             tensor_shape = graph_input.type.tensor_type.shape.dim
             if len(tensor_shape) != len(input_shape):
                 raise ValueError(
                     f"{graph_input.name} expects shape of rank {len(tensor_shape)}, "
                     f"but calibration shape of rank {len(input_shape)} was passed."
                 )
-            for dim, new_dim_value in zip(tensor_shape, input_shape):
-                dim.dim_value = new_dim_value
+            for dim, new_dim_value in zip(tensor_shape, input_shape):
+                # Require positive, concrete dims to keep ONNX inference stable.
+                if int(new_dim_value) <= 0:
+                    raise ValueError(
+                        f"calibration_shapes for '{graph_input.name}' must be > 0; got {new_dim_value}"
+                    )
+                dim.dim_value = int(new_dim_value)

-    model = infer_shapes(model)
-    value_info_map = {vi.name: vi for vi in model.graph.value_info}
+    model = infer_shapes(model)

     nodes_to_exclude = []
     for matmul_node in matmul_nodes:
         output_name = matmul_node.outputs[0].name
-        value_info = value_info_map.get(output_name)
-        if not value_info:
-            raise RuntimeError(f"Shape inference did not find shape for {output_name}.")
+        # Search inputs, outputs, or value_info
+        value_info = get_tensor_from_name(model.graph, output_name)
+        if value_info is None:
+            raise RuntimeError(f"Shape inference did not find shape for {output_name}.")

         dims = value_info.type.tensor_type.shape.dim
         if all(isinstance(inp, Variable) for inp in matmul_node.inputs):
             if len(dims) < 2:
                 raise RuntimeError(f"Shape for {output_name} is incorrect.")
             if dims[-1].dim_value == 1 or dims[-2].dim_value == 1:
                 nodes_to_exclude.append(matmul_node.name)
         elif len(dims) < 3 and any(out.dim_value == 1 for out in dims):
             nodes_to_exclude.append(matmul_node.name)

     return nodes_to_exclude
🧹 Nitpick comments (1)
modelopt/onnx/quantization/graph_utils.py (1)

949-951: Docstring wording: “symbolic shape inference” → “shape inference”.

This path now uses ONNX shape inference via infer_shapes, not ORT SymbolicShapeInference. Update the docstring for accuracy.

-        calibration_shapes: Model input shapes for inference. If provided, symbolic shape inference will be used
-            instead of calibration_data_reader.
+        calibration_shapes: Model input shapes for inference. If provided, ONNX shape inference will be used
+            instead of calibration_data_reader.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bffe2ff and 7cf0d8f.

📒 Files selected for processing (1)
  • modelopt/onnx/quantization/graph_utils.py (5 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/graph_utils.py (1)
modelopt/onnx/utils.py (2)
  • infer_shapes (723-736)
  • parse_shapes_spec (234-247)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/onnx/quantization/graph_utils.py (1)

38-41: Good switch to the internal ONNX shape inference helper.

Using modelopt.onnx.utils.infer_shapes avoids the ORT SymbolicShapeInference dependency and handles >2GB models. The change in graph_utils.py successfully removes the ORT dependency without leaving dangling references. ✓

@gcunhase gcunhase force-pushed the dev/gcunhasergio/calib_shapes_fix_5597849 branch from 7cf0d8f to 543c2ca Compare October 20, 2025 22:39
@gcunhase gcunhase enabled auto-merge (squash) October 20, 2025 22:39
@codecov

codecov bot commented Oct 21, 2025

Codecov Report

❌ Patch coverage is 25.00000% with 6 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.39%. Comparing base (bffe2ff) to head (663d9f0).
⚠️ Report is 1 commit behind head on main.

Files with missing lines:
  • modelopt/onnx/quantization/graph_utils.py: 14.28% patch coverage, 6 lines missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #453      +/-   ##
==========================================
- Coverage   73.42%   73.39%   -0.04%     
==========================================
  Files         180      180              
  Lines       17975    17976       +1     
==========================================
- Hits        13199    13193       -6     
- Misses       4776     4783       +7     

☔ View full report in Codecov by Sentry.

@gcunhase gcunhase merged commit 4476f21 into NVIDIA:main Oct 21, 2025
27 checks passed