[5256037] Automatically infer calibration shapes for per node calibration #394
Conversation
Walkthrough

Broadens calibration_shapes to accept dicts, adds get_input_shapes for automatic input-shape extraction when unspecified, updates graph_utils and quantize flows to handle str | dict | None, and runs ONNX shape inference before loading in TRT utilities. Signatures updated in fp8.py and int8.py.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
actor User
participant Quantize as quantize.py
participant GraphUtils as graph_utils.get_input_shapes
participant ONNX as ONNX Model
participant TRT as trt_utils
User->>Quantize: quantize(model_path, calibration_shapes=None|str|dict)
alt calibration_shapes provided
Quantize->>Quantize: use provided calibration_shapes
else not provided
Quantize->>GraphUtils: get_input_shapes(model_path)
GraphUtils->>ONNX: load model & read input tensor shapes
GraphUtils-->>Quantize: return {input_name: [dims]}
Quantize->>Quantize: use derived shapes
end
Note over Quantize: proceed with exclusion detection and quantization steps
par Preload inference (separate step)
TRT->>ONNX: infer_shapes_path(model_path)
Note right of TRT: ensure shapes inferred before model load
    end
```
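A minimal usage sketch of the three accepted forms described in the walkthrough. The import path, the positional model argument, and the string-spec format are assumptions based on the files listed below, not confirmed API:

```python
# Illustrative only: calibration_shapes may now be a "name:dims" string spec,
# a dict of input names to shapes, or omitted entirely.
from modelopt.onnx.quantization import quantize  # assumed import path

# Existing behavior: explicit string spec
quantize("model.onnx", calibration_shapes="input:1x3x224x224")

# New: dict form
quantize("model.onnx", calibration_shapes={"input": [1, 3, 224, 224]})

# New: omit it and let the shapes be read from the model's graph inputs
quantize("model.onnx")
```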
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/graph_utils.py (1)
Lines 1062-1077: BUG: dict calibration_shapes are ignored; only strings are parsed

When calibration_shapes is a dict (as passed by the new fallback), it is currently discarded, so GEMV exclusions won't use user/model-provided shapes.
Apply:
```diff
-def _exclude_matmuls_by_symbolic_inference(
-    model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
-) -> list[str]:
+def _exclude_matmuls_by_symbolic_inference(
+    model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
+) -> list[str]:
@@
-    # Apply calibration shapes if provided
-    input_shapes = (
-        parse_shapes_spec(calibration_shapes)
-        if (calibration_shapes and isinstance(calibration_shapes, str))
-        else {}
-    )
+    # Apply calibration shapes if provided (accept str spec or dict[str, list[int]|str])
+    def _normalize_shapes(shapes) -> dict[str, list[int]]:
+        if shapes is None:
+            return {}
+        if isinstance(shapes, str):
+            return parse_shapes_spec(shapes)
+        if isinstance(shapes, dict):
+            norm = {}
+            for name, val in shapes.items():
+                if isinstance(val, str):
+                    norm[name] = list(map(int, val.split("x")))
+                else:
+                    norm[name] = list(val)
+            return norm
+        return {}
+
+    input_shapes = _normalize_shapes(calibration_shapes)
```
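For illustration, assuming a "name:1x3x224x224"-style spec string (the format parse_shapes_spec appears to accept), the proposed helper would behave roughly as follows; this is a sketch of the patch above, not tested code:

```python
# Hypothetical inputs/outputs for the proposed _normalize_shapes helper.
_normalize_shapes("input:1x3x224x224")          # {"input": [1, 3, 224, 224]} via parse_shapes_spec
_normalize_shapes({"input": "1x3x224x224"})     # {"input": [1, 3, 224, 224]}
_normalize_shapes({"input": [1, 3, 224, 224]})  # {"input": [1, 3, 224, 224]}
_normalize_shapes(None)                         # {}
```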
🧹 Nitpick comments (5)
modelopt/onnx/trt_utils.py (1)
Lines 269-271: Make shape inference deterministic and resilient

Write the inferred model back to the same path and don't fail hard if inference isn't possible (e.g., read-only paths, ops not supported). Wrap with try/except and log.
```diff
-    # Infer shapes
-    onnx.shape_inference.infer_shapes_path(onnx_path)
+    # Infer shapes (best-effort): write back to the same path if possible
+    try:
+        # For ONNX >= 1.14, infer_shapes_path supports output path; this no-ops on some older versions.
+        onnx.shape_inference.infer_shapes_path(onnx_path, onnx_path)  # type: ignore[call-arg]
+    except TypeError:
+        # Fallback to legacy signature
+        try:
+            onnx.shape_inference.infer_shapes_path(onnx_path)
+        except Exception as e:
+            logger.warning(f"Shape inference skipped: {e}")
+    except Exception as e:
+        logger.warning(f"Shape inference skipped: {e}")
```

modelopt/onnx/quantization/quantize.py (3)
Lines 56-57: Avoid name confusion with utils.get_input_shapes

Alias the path-based helper to distinguish it from modelopt.onnx.utils.get_input_shapes(model).
```diff
-from modelopt.onnx.quantization.graph_utils import (
-    cast_custom_ops,
-    find_nodes_from_mha_to_exclude,
-    get_input_shapes,
-    print_stat,
-    remove_redundant_cast_nodes,
-    validate_op_types_spelling,
-)
+from modelopt.onnx.quantization.graph_utils import (
+    cast_custom_ops,
+    find_nodes_from_mha_to_exclude,
+    get_input_shapes as get_input_shapes_from_path,
+    print_stat,
+    remove_redundant_cast_nodes,
+    validate_op_types_spelling,
+)
```
Lines 444-451: Compute fallback calibration_shapes before constructing the DataReader

RandomDataProvider/CalibrationDataProvider may benefit from concrete shapes. Move the fallback ahead so both readers and GEMV exclusion use the same shapes.
```diff
-    # Use random scales if calibration data is not supplied
-    if calibration_data is None:
-        calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
-    else:
-        calibration_data_reader = CalibrationDataProvider(
-            onnx_path, calibration_data, calibration_shapes
-        )
+    # Derive calibration shapes from the model if not provided
+    if not calibration_shapes:
+        calibration_shapes = get_input_shapes_from_path(onnx_path)
+
+    # Use random scales if calibration data is not supplied
+    if calibration_data is None:
+        calibration_data_reader = RandomDataProvider(onnx_path, calibration_shapes)
+    else:
+        calibration_data_reader = CalibrationDataProvider(
+            onnx_path, calibration_data, calibration_shapes
+        )
```
Lines 473-475: Remove duplicated fallback (moved earlier)

This block becomes redundant after moving the fallback before reader construction.
```diff
-    if not calibration_shapes:
-        calibration_shapes = get_input_shapes(onnx_path)
+    # handled earlier before constructing calibration_data_reader
```

modelopt/onnx/quantization/graph_utils.py (1)
Lines 75-82: Reuse existing utility and sanitize unknown dims

Avoid duplicating logic and ensure unknown dims (0) don't leak into calibration. Use the model-based helper and coerce 0→1 for stability.
```diff
-def get_input_shapes(onnx_path: str) -> dict[str, list[int]]:
-    """Returns the input shapes of the given ONNX model."""
-    onnx_model = onnx.load(onnx_path)
-    input_shape_dict = {}
-    for input in onnx_model.graph.input:
-        input_shape_dict[input.name] = [x.dim_value for x in input.type.tensor_type.shape.dim]
-    return input_shape_dict
+def get_input_shapes(onnx_path: str) -> dict[str, list[int]]:
+    """Return external input shapes of the given ONNX model (0→1 for unknown dims)."""
+    onnx_model = onnx.load(onnx_path, load_external_data=True)
+    from modelopt.onnx.utils import get_input_shapes as _get_input_shapes  # local import to avoid cycle
+    shapes = _get_input_shapes(onnx_model, external_inputs_only=True)
+    # Normalize unknowns to 1 for downstream symbolic inference and data generation
+    return {k: [d if isinstance(d, int) and d > 0 else 1 for d in v] for k, v in shapes.items()}
```
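As a usage note for the rewrite above (model name and dims are hypothetical), a dynamic batch axis no longer leaks a 0 into calibration:

```python
# Before normalization a dynamic batch dim is reported as 0;
# the proposed helper coerces it to 1.
shapes = get_input_shapes("model.onnx")
print(shapes)  # e.g. {"input_ids": [1, 128]} instead of {"input_ids": [0, 128]}
```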
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- modelopt/onnx/quantization/fp8.py (1 hunks)
- modelopt/onnx/quantization/graph_utils.py (4 hunks)
- modelopt/onnx/quantization/int8.py (1 hunks)
- modelopt/onnx/quantization/quantize.py (3 hunks)
- modelopt/onnx/trt_utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/onnx/quantization/graph_utils.py (1)
modelopt/onnx/utils.py (2)
- get_input_shapes (219-226)
- parse_shapes_spec (234-247)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/quantization/graph_utils.py (1)
- get_input_shapes (75-81)

modelopt/onnx/utils.py (1)

- get_input_shapes (219-226)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/onnx/quantization/int8.py (1)
Line 118: Signature broadened — ensure end-to-end dict support

Accepting dict is fine here. Note that graph_utils._exclude_matmuls_by_symbolic_inference currently ignores dict calibration_shapes and only parses strings, so dicts passed from here won't take effect for GEMV exclusion until that's fixed. I've proposed a patch in graph_utils.py.
modelopt/onnx/quantization/fp8.py (1)
Line 167: Signature broadened — align with symbolic GEMV exclusion

Same note as INT8: passing dict shapes is supported by this signature, but graph_utils._exclude_matmuls_by_symbolic_inference currently ignores dicts. See proposed fix in graph_utils.py.
modelopt/onnx/quantization/graph_utils.py (1)
Lines 935-936: API accepts dict — ensure downstream actually uses it

find_nodes_from_matmul_to_exclude correctly exposes calibration_shapes as str | dict | None. The symbolic path below must consume dicts; see the patch on _exclude_matmuls_by_symbolic_inference.
…tion Signed-off-by: ajrasane <[email protected]>
Force-pushed from af23ccd to 94975f1 (Compare)
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/onnx_ptq/README.md (1 hunks)
- modelopt/onnx/quantization/fp8.py (1 hunks)
- modelopt/onnx/quantization/graph_utils.py (4 hunks)
- modelopt/onnx/quantization/int8.py (1 hunks)
- modelopt/onnx/quantization/quantize.py (3 hunks)
- modelopt/onnx/trt_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- modelopt/onnx/quantization/graph_utils.py
- modelopt/onnx/quantization/fp8.py
- modelopt/onnx/quantization/int8.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/quantization/graph_utils.py (1)
- get_input_shapes (75-81)

modelopt/onnx/utils.py (1)

- get_input_shapes (219-226)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
```python
if not calibration_shapes:
    calibration_shapes = get_input_shapes(onnx_path)
```
Guard against zero/unknown dims when auto-populating calibration shapes
graph_utils.get_input_shapes returns 0 for any dynamic axis (the common "batch" dim). Feeding that dict directly into the calibration readers produces zero-sized batches and the per-node calibrator errors out ("need at least one calibration sample"). We need to verify every inferred dimension is >0 (or fall back to a safe default) before adopting the auto-generated shapes; otherwise force the caller to supply explicit shapes.
🤖 Prompt for AI Agents
In modelopt/onnx/quantization/quantize.py around lines 473-475, the
auto-populated calibration_shapes from get_input_shapes may contain zero values
for dynamic axes (e.g., batch), which leads to zero-sized batches; validate
every inferred dimension and if any dimension is <= 0 either replace it with a
safe default (e.g., 1) and emit a warning or raise a clear ValueError forcing
the caller to provide explicit shapes; implement the check immediately after
calling get_input_shapes, update calibration_shapes with the fixed dimensions,
and log/raise appropriately so downstream calibration readers never receive
zero-sized dimensions.
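A minimal sketch of the warn-and-default option described above; the helper name, logger, and placement are assumptions, not the actual patch:

```python
import logging

logger = logging.getLogger(__name__)

def _sanitize_calibration_shapes(shapes: dict[str, list[int]]) -> dict[str, list[int]]:
    """Replace non-positive (dynamic/unknown) dims with 1 and warn."""
    sanitized = {}
    for name, dims in shapes.items():
        fixed = [d if isinstance(d, int) and d > 0 else 1 for d in dims]
        if fixed != list(dims):
            logger.warning("Input '%s' has dynamic dims %s; defaulting them to 1.", name, dims)
        sanitized[name] = fixed
    return sanitized

# Hypothetical placement right after the auto-population step:
# if not calibration_shapes:
#     calibration_shapes = _sanitize_calibration_shapes(get_input_shapes(onnx_path))
```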
@ajrasane are there any risks with always setting calibration_shapes, or should we only set it if --calibrate_per_node is given?
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
Also, we need to check which dimension value is reported for dynamic shapes, 1 or 0, as indicated by the CodeRabbit comment above.
```python
# Infer shapes
onnx.shape_inference.infer_shapes_path(onnx_path)
```
Do not call infer_shapes_path on large/external-data models

onnx.shape_inference.infer_shapes_path rewrites the model in-place using onnx.save_model without save_as_external_data=True. For the >2 GB ONNX graphs we routinely quantize (Llama, ViT-H, etc.), this raises ValueError: Message onnx.ModelProto exceeds maximum protobuf size of 2GB and stops the run—this worked before the change. Please switch to running shape inference on an already loaded ModelProto (e.g., onnx.shape_inference.infer_shapes(onnx_model)) and persist it via our existing save_onnx, which already handles external data formats, instead of infer_shapes_path.
🤖 Prompt for AI Agents
In modelopt/onnx/trt_utils.py around lines 269 to 271, do not call
onnx.shape_inference.infer_shapes_path(onnx_path) because it rewrites and saves
the model without external-data support and will fail for >2GB models; instead
load the ModelProto first, run onnx.shape_inference.infer_shapes(onnx_model) on
that in-memory proto, then persist the result using our existing save_onnx
function (which handles external data) so shape inference is applied but saving
uses the external-data-aware saver.
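A rough sketch of that alternative; save_onnx is the existing repo helper named in the comment, and its exact signature here is an assumption:

```python
import onnx

# Run shape inference on the in-memory ModelProto instead of rewriting the file on disk.
onnx_model = onnx.load(onnx_path)
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)

# Persist via the external-data-aware saver (argument names assumed).
save_onnx(onnx_model, onnx_path, save_as_external_data=True)
```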
@ajrasane look into this suggestion
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
This comment is incorrect. The infer_shapes_path API is specifically designed to handle models larger than 2GB. It should also be able to handle smaller models if we pass their path: https://onnx.ai/onnx/api/shape_inference.html#infer-shapes-path
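For reference, a sketch of the path-based API; the in-place behavior when the output path is omitted is my reading of the ONNX docs and worth double-checking:

```python
import onnx

# Streams the model from disk and writes the shape-inferred result back,
# sidestepping the 2 GB in-memory protobuf limit.
onnx.shape_inference.infer_shapes_path("model.onnx")                         # write back to the same file
onnx.shape_inference.infer_shapes_path("model.onnx", "model_inferred.onnx")  # or to a separate output
```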
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@           Coverage Diff           @@
##            main     #394    +/-  ##
=======================================
  Coverage   73.79%   73.80%
=======================================
  Files         171      171
  Lines       17583    17592     +9
=======================================
+ Hits        12975    12983     +8
- Misses       4608     4609     +1
```

☔ View full report in Codecov by Sentry.
…tion (#394) Signed-off-by: ajrasane <[email protected]>
What does this PR do?
Type of change: Bug fix
Overview:
Testing
Unit and integration tests pass
Before your PR is "Ready for review"
Summary by CodeRabbit
New Features
Documentation