
Conversation

@ajrasane (Contributor) commented Sep 30, 2025

What does this PR do?

Type of change: Bug fix

Overview:

  • Infer shapes while loading the ONNX model
  • Automatically infer input shapes before the quantization process begins (a minimal sketch of this fallback follows the example command below)

Example command:

python -m modelopt.onnx.quantization --onnx_path=ViT-large_inpsize_1x3x1024x2048_opsetv_15.onnx --quantize_mode=int8 --calibration_eps=cuda:0 --calibrate_per_node --log_level=DEBUG
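When no calibration shapes are given, the input shapes are read directly from the ONNX graph. Below is a minimal sketch of that idea using plain onnx APIs; the helper name is illustrative, while the PR's actual helper is graph_utils.get_input_shapes:

import onnx

def infer_calibration_shapes(onnx_path: str) -> dict[str, list[int]]:
    """Read input tensor shapes from the ONNX graph; dynamic axes come back as 0."""
    model = onnx.load(onnx_path)
    return {
        inp.name: [d.dim_value for d in inp.type.tensor_type.shape.dim]
        for inp in model.graph.input
    }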

Testing

Unit and integration tests pass

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: No

Summary by CodeRabbit

  • New Features

    • Automatically derives calibration input shapes from ONNX models when not provided, simplifying quantization setup.
    • Accepts calibration shapes as either a dictionary or a string for greater flexibility.
    • Adds a utility to retrieve input tensor shapes from ONNX models.
    • Infers tensor shapes before ONNX/TensorRT processing, improving reliability.
  • Documentation

    • Simplified ONNX PTQ example by removing explicit per-node calibration shape specification.

@ajrasane requested a review from cjluo-nv on September 30, 2025 21:21
@ajrasane requested a review from a team as a code owner on September 30, 2025 21:21
@ajrasane requested a review from i-riyad on September 30, 2025 21:21

coderabbitai bot commented Sep 30, 2025

Walkthrough

Broadens calibration_shapes to accept dicts, adds get_input_shapes for automatic input-shape extraction when unspecified, updates graph_utils and quantize flows to handle str|dict|None, and runs ONNX shape inference before loading in TRT utilities. Signatures updated in fp8.py and int8.py.

Changes

  • Calibration shapes type expansion (modelopt/onnx/quantization/fp8.py, modelopt/onnx/quantization/int8.py): Updated quantize signatures; calibration_shapes type changed from str | None to str | dict | None. No other logic changes.
  • Graph utilities: new API and shape handling (modelopt/onnx/quantization/graph_utils.py): Added get_input_shapes(onnx_path) -> dict[str, list[int]]. Updated find_nodes_from_matmul_to_exclude(...) and _exclude_matmuls_by_symbolic_inference(...) to accept str | dict | None for calibration_shapes; shapes are parsed only when calibration_shapes is a string.
  • Automatic calibration shapes derivation (modelopt/onnx/quantization/quantize.py): When calibration_shapes is not provided, derive shapes via get_input_shapes(onnx_path) during quantization preprocessing and the main flow. Imported get_input_shapes.
  • TRT utilities shape inference (modelopt/onnx/trt_utils.py): Invoke onnx.shape_inference.infer_shapes_path(onnx_path) before loading the ONNX model so inferred shapes are available.
  • Docs/example update (examples/onnx_ptq/README.md): Removed the explicit per-node calibration_shapes flag (--calibration_shapes=input:1x3x224x224) from the README sample command.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Quantize as quantize.py
  participant GraphUtils as graph_utils.get_input_shapes
  participant ONNX as ONNX Model
  participant TRT as trt_utils

  User->>Quantize: quantize(model_path, calibration_shapes=None|str|dict)
  alt calibration_shapes provided
    Quantize->>Quantize: use provided calibration_shapes
  else not provided
    Quantize->>GraphUtils: get_input_shapes(model_path)
    GraphUtils->>ONNX: load model & read input tensor shapes
    GraphUtils-->>Quantize: return {input_name: [dims]}
    Quantize->>Quantize: use derived shapes
  end
  Note over Quantize: proceed with exclusion detection and quantization steps

  par Preload inference (separate step)
    TRT->>ONNX: infer_shapes_path(model_path)
    Note right of TRT: ensure shapes inferred before model load
  end
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I nibble bytes where tensors grow,
Shapes unfold like fields of snow—
A carrot compass points to dicts,
ONNX hums soft inferred scripts.
I hop, I thump: "Calibrate!"—then go. 🐇

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check: ✅ Passed. The title succinctly conveys the primary change, automatically inferring calibration shapes for per-node calibration, and directly reflects the PR's main enhancement. It is specific and clear, using terminology consistent with the code's domain. Although it includes a bracketed reference, this does not obscure the core intent.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which meets the required threshold of 80.00%.



@coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/graph_utils.py (1)

1062-1077: BUG: dict calibration_shapes are ignored; only strings are parsed

When calibration_shapes is a dict (as passed by the new fallback), it’s currently discarded, so GEMV exclusions won’t use user/model-provided shapes.

Apply:

-def _exclude_matmuls_by_symbolic_inference(
-    model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
-) -> list[str]:
+def _exclude_matmuls_by_symbolic_inference(
+    model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
+) -> list[str]:
@@
-    # Apply calibration shapes if provided
-    input_shapes = (
-        parse_shapes_spec(calibration_shapes)
-        if (calibration_shapes and isinstance(calibration_shapes, str))
-        else {}
-    )
+    # Apply calibration shapes if provided (accept str spec or dict[str, list[int]|str])
+    def _normalize_shapes(shapes) -> dict[str, list[int]]:
+        if shapes is None:
+            return {}
+        if isinstance(shapes, str):
+            return parse_shapes_spec(shapes)
+        if isinstance(shapes, dict):
+            norm = {}
+            for name, val in shapes.items():
+                if isinstance(val, str):
+                    norm[name] = list(map(int, val.split("x")))
+                else:
+                    norm[name] = list(val)
+            return norm
+        return {}
+    input_shapes = _normalize_shapes(calibration_shapes)
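For reference, here is a standalone sketch of the same normalization idea, assuming the name:1x3x224x224 spec format used by --calibration_shapes elsewhere in this PR (an illustrative helper, not the PR's actual code):

def normalize_calibration_shapes(shapes) -> dict[str, list[int]]:
    """Accept None, a spec string like "input:1x3x224x224,mask:1x224", or a dict of lists/strings."""
    if shapes is None:
        return {}
    if isinstance(shapes, str):
        parsed = {}
        for spec in shapes.split(","):
            name, dims = spec.rsplit(":", 1)
            parsed[name] = [int(d) for d in dims.split("x")]
        return parsed
    return {
        name: [int(d) for d in (val.split("x") if isinstance(val, str) else val)]
        for name, val in shapes.items()
    }

assert normalize_calibration_shapes("input:1x3x224x224") == {"input": [1, 3, 224, 224]}
assert normalize_calibration_shapes({"input": [1, 3, 224, 224]}) == {"input": [1, 3, 224, 224]}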
🧹 Nitpick comments (5)
modelopt/onnx/trt_utils.py (1)

269-271: Make shape inference deterministic and resilient

Write the inferred model back to the same path and don’t fail hard if inference isn’t possible (e.g., read-only paths, ops not supported). Wrap with try/except and log.

-    # Infer shapes
-    onnx.shape_inference.infer_shapes_path(onnx_path)
+    # Infer shapes (best‑effort): write back to the same path if possible
+    try:
+        # For ONNX >= 1.14, infer_shapes_path supports output path; this no-ops on some older versions.
+        onnx.shape_inference.infer_shapes_path(onnx_path, onnx_path)  # type: ignore[call-arg]
+    except TypeError:
+        # Fallback to legacy signature
+        try:
+            onnx.shape_inference.infer_shapes_path(onnx_path)
+        except Exception as e:
+            logger.warning(f"Shape inference skipped: {e}")
+    except Exception as e:
+        logger.warning(f"Shape inference skipped: {e}")
modelopt/onnx/quantization/quantize.py (3)

56-57: Avoid name confusion with utils.get_input_shapes

Alias the path-based helper to distinguish it from modelopt.onnx.utils.get_input_shapes(model).

-from modelopt.onnx.quantization.graph_utils import (
-    cast_custom_ops,
-    find_nodes_from_mha_to_exclude,
-    get_input_shapes,
-    print_stat,
-    remove_redundant_cast_nodes,
-    validate_op_types_spelling,
-)
+from modelopt.onnx.quantization.graph_utils import (
+    cast_custom_ops,
+    find_nodes_from_mha_to_exclude,
+    get_input_shapes as get_input_shapes_from_path,
+    print_stat,
+    remove_redundant_cast_nodes,
+    validate_op_types_spelling,
+)

444-451: Compute fallback calibration_shapes before constructing the DataReader

RandomDataProvider/CalibrationDataProvider may benefit from concrete shapes. Move the fallback ahead so both readers and GEMV exclusion use the same shapes.

-    # Use random scales if calibration data is not supplied
-    if calibration_data is None:
-        calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
-    else:
-        calibration_data_reader = CalibrationDataProvider(
-            onnx_path, calibration_data, calibration_shapes
-        )
+    # Derive calibration shapes from the model if not provided
+    if not calibration_shapes:
+        calibration_shapes = get_input_shapes_from_path(onnx_path)
+
+    # Use random scales if calibration data is not supplied
+    if calibration_data is None:
+        calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
+    else:
+        calibration_data_reader = CalibrationDataProvider(
+            onnx_path, calibration_data, calibration_shapes
+        )

473-475: Remove duplicated fallback (moved earlier)

This block becomes redundant after moving the fallback before reader construction.

-    if not calibration_shapes:
-        calibration_shapes = get_input_shapes(onnx_path)
+    # handled earlier before constructing calibration_data_reader
modelopt/onnx/quantization/graph_utils.py (1)

75-82: Reuse existing utility and sanitize unknown dims

Avoid duplicating logic and ensure unknown dims (0) don’t leak into calibration. Use the model-based helper and coerce 0→1 for stability.

-def get_input_shapes(onnx_path: str) -> dict[str, list[int]]:
-    """Returns the input shapes of the given ONNX model."""
-    onnx_model = onnx.load(onnx_path)
-    input_shape_dict = {}
-    for input in onnx_model.graph.input:
-        input_shape_dict[input.name] = [x.dim_value for x in input.type.tensor_type.shape.dim]
-    return input_shape_dict
+def get_input_shapes(onnx_path: str) -> dict[str, list[int]]:
+    """Return external input shapes of the given ONNX model (0→1 for unknown dims)."""
+    onnx_model = onnx.load(onnx_path, load_external_data=True)
+    from modelopt.onnx.utils import get_input_shapes as _get_input_shapes  # local import to avoid cycle
+    shapes = _get_input_shapes(onnx_model, external_inputs_only=True)
+    # Normalize unknowns to 1 for downstream symbolic inference and data generation
+    return {k: [d if isinstance(d, int) and d > 0 else 1 for d in v] for k, v in shapes.items()}
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 17439e6 and af23ccd.

📒 Files selected for processing (5)
  • modelopt/onnx/quantization/fp8.py (1 hunks)
  • modelopt/onnx/quantization/graph_utils.py (4 hunks)
  • modelopt/onnx/quantization/int8.py (1 hunks)
  • modelopt/onnx/quantization/quantize.py (3 hunks)
  • modelopt/onnx/trt_utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/graph_utils.py (1)
modelopt/onnx/utils.py (2)
  • get_input_shapes (219-226)
  • parse_shapes_spec (234-247)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/quantization/graph_utils.py (1)
  • get_input_shapes (75-81)
modelopt/onnx/utils.py (1)
  • get_input_shapes (219-226)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/onnx/quantization/int8.py (1)

118-118: Signature broadened — ensure end-to-end dict support

Accepting dict is fine here. Note that graph_utils._exclude_matmuls_by_symbolic_inference currently ignores dict calibration_shapes and only parses strings, so dicts passed from here won’t take effect for GEMV exclusion until that’s fixed. I’ve proposed a patch in graph_utils.py.

modelopt/onnx/quantization/fp8.py (1)

167-167: Signature broadened — align with symbolic GEMV exclusion

Same note as INT8: passing dict shapes is supported by this signature, but graph_utils._exclude_matmuls_by_symbolic_inference currently ignores dicts. See proposed fix in graph_utils.py.

modelopt/onnx/quantization/graph_utils.py (1)

935-936: API accepts dict — ensure downstream actually uses it

find_nodes_from_matmul_to_exclude correctly exposes calibration_shapes as str | dict | None. The symbolic path below must consume dicts; see patch on _exclude_matmuls_by_symbolic_inference.

@ajrasane force-pushed the ajrasane/input_shapes branch from af23ccd to 94975f1 on September 30, 2025 21:29

@coderabbitai bot left a comment


Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between af23ccd and 94975f1.

📒 Files selected for processing (6)
  • examples/onnx_ptq/README.md (1 hunks)
  • modelopt/onnx/quantization/fp8.py (1 hunks)
  • modelopt/onnx/quantization/graph_utils.py (4 hunks)
  • modelopt/onnx/quantization/int8.py (1 hunks)
  • modelopt/onnx/quantization/quantize.py (3 hunks)
  • modelopt/onnx/trt_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/onnx/quantization/graph_utils.py
  • modelopt/onnx/quantization/fp8.py
  • modelopt/onnx/quantization/int8.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/quantization/graph_utils.py (1)
  • get_input_shapes (75-81)
modelopt/onnx/utils.py (1)
  • get_input_shapes (219-226)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality

Comment on lines +473 to +475
if not calibration_shapes:
    calibration_shapes = get_input_shapes(onnx_path)

@coderabbitai bot commented Sep 30, 2025

⚠️ Potential issue | 🟠 Major

Guard against zero/unknown dims when auto-populating calibration shapes

graph_utils.get_input_shapes returns 0 for any dynamic axis (the common “batch” dim). Feeding that dict directly into the calibration readers produces zero-sized batches and the per-node calibrator errors out (“need at least one calibration sample”). We need to verify every inferred dimension is >0 (or fallback to a safe default) before adopting the auto-generated shapes; otherwise force the caller to supply explicit shapes.

🤖 Prompt for AI Agents
In modelopt/onnx/quantization/quantize.py around lines 473-475, the
auto-populated calibration_shapes from get_input_shapes may contain zero values
for dynamic axes (e.g., batch), which leads to zero-sized batches; validate
every inferred dimension and if any dimension is <= 0 either replace it with a
safe default (e.g., 1) and emit a warning or raise a clear ValueError forcing
the caller to provide explicit shapes; implement the check immediately after
calling get_input_shapes, update calibration_shapes with the fixed dimensions,
and log/raise appropriately so downstream calibration readers never receive
zero-sized dimensions.
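A minimal sketch of that guard, assuming non-positive dims mark dynamic axes and using warnings as a stand-in for the project's logger (illustrative, not the PR's code):

import warnings

def sanitize_calibration_shapes(shapes: dict[str, list[int]]) -> dict[str, list[int]]:
    """Replace non-positive (dynamic/unknown) dims with 1 so calibration never sees zero-sized batches."""
    fixed = {}
    for name, dims in shapes.items():
        if any(d <= 0 for d in dims):
            warnings.warn(f"Input '{name}' has dynamic dims {dims}; defaulting them to 1 for calibration.")
        fixed[name] = [d if d > 0 else 1 for d in dims]
    return fixed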

@gcunhase (Contributor) commented Oct 2, 2025

@ajrasane are there any risks with always setting calibration_shapes or should we only set it if --calibrate_per_node is given?

@coderabbitai bot replied:

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Contributor commented:

Also, we need to check whether dynamic axes come back as 1 or 0, as indicated by the CodeRabbit comment above.

Comment on lines +269 to +271
# Infer shapes
onnx.shape_inference.infer_shapes_path(onnx_path)

@coderabbitai bot commented Sep 30, 2025

⚠️ Potential issue | 🔴 Critical

Do not call infer_shapes_path on large/external-data models

onnx.shape_inference.infer_shapes_path rewrites the model in-place using onnx.save_model without save_as_external_data=True. For the >2 GB ONNX graphs we routinely quantize (Llama, ViT-H, etc.), this raises ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB and stops the run—this worked before the change. Please switch to running shape inference on an already loaded ModelProto (e.g., onnx.shape_inference.infer_shapes(onnx_model)) and persist it via our existing save_onnx, which already handles external data formats, instead of infer_shapes_path.

🤖 Prompt for AI Agents
In modelopt/onnx/trt_utils.py around lines 269 to 271, do not call
onnx.shape_inference.infer_shapes_path(onnx_path) because it rewrites and saves
the model without external-data support and will fail for >2GB models; instead
load the ModelProto first, run onnx.shape_inference.infer_shapes(onnx_model) on
that in-memory proto, then persist the result using our existing save_onnx
function (which handles external data) so shape inference is applied but saving
uses the external-data-aware saver.
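A sketch of the in-memory alternative described above, using plain onnx APIs; onnx.save_model with external-data flags stands in for the repo's save_onnx helper, whose exact signature is not shown here, and the path is illustrative:

import onnx

onnx_path = "model.onnx"  # illustrative path
model = onnx.load(onnx_path, load_external_data=True)  # load the full ModelProto
inferred = onnx.shape_inference.infer_shapes(model)    # run shape inference in memory
onnx.save_model(
    inferred,
    onnx_path,
    save_as_external_data=True,   # keep weights in external files for large graphs
    all_tensors_to_one_file=True,
)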

Contributor commented:

@ajrasane look into this suggestion

@coderabbitai bot replied:

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

@ajrasane (Contributor, Author) commented:

This comment is incorrect. The infer_shapes_path API is specifically designed to handle models larger than 2GB. It should also be able to handle smaller models if we pass their path: https://onnx.ai/onnx/api/shape_inference.html#infer-shapes-path
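For reference, the file-based usage documented there takes the model path and an optional output path (paths here are illustrative):

import onnx

# Infer shapes and write the result back to the same file,
onnx.shape_inference.infer_shapes_path("model.onnx")
# or write the inferred model to a separate output file.
onnx.shape_inference.infer_shapes_path("model.onnx", output_path="model_inferred.onnx")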


codecov bot commented Sep 30, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.80%. Comparing base (17439e6) to head (94975f1).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #394   +/-   ##
=======================================
  Coverage   73.79%   73.80%           
=======================================
  Files         171      171           
  Lines       17583    17592    +9     
=======================================
+ Hits        12975    12983    +8     
- Misses       4608     4609    +1     


@ajrasane merged commit 0d51156 into main on Oct 1, 2025 (30 of 32 checks passed)
@ajrasane deleted the ajrasane/input_shapes branch on October 1, 2025 17:07
kevalmorabia97 pushed a commit that referenced this pull request Oct 1, 2025