Making stronglyTyped default for modelopt evaluation #287
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Walkthrough
Replaces the CLI quantize-mode selection with an engine-precision setting that defaults to stronglyTyped (see the changes and diagram below).

Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant U as User / CI
participant Eval as examples/onnx_ptq/evaluate.py
participant Deploy as Deployment config
participant EB as TRT EngineBuilder
participant TRT as TensorRT
U->>Eval: run evaluate (--engine_precision)
Eval->>Deploy: build deployment (precision = args.engine_precision)
Deploy->>EB: build_engine(trt_mode = precision)
alt trt_mode == STRONGLY_TYPED
Note right of EB #b3e5fc: treated as low-bit → opt_level = 4
EB->>TRT: create engine(opt_level=4)
else other modes
EB->>TRT: create engine(opt_level=builder_optimization_level)
end
TRT-->>Eval: engine ready / run inference
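To make the flow above concrete, here is a minimal, self-contained sketch of an argparse flag plus deployment dict along the lines the diagram describes. The flag name and the stronglyTyped default come from this PR's review comments; the exact choices list and the helper name are assumptions for illustration, not the actual evaluate.py code.

import argparse

def build_deployment() -> dict:
    # Hedged sketch only; not the real evaluate.py implementation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--engine_precision",
        default="stronglyTyped",
        choices=["stronglyTyped", "best", "fp16"],  # assumed choices, not confirmed
        help="TensorRT engine precision; stronglyTyped is recommended.",
    )
    args = parser.parse_args()
    # Mirrors the deployment dict shown in the diffs later in this review.
    return {"runtime": "TRT", "precision": args.engine_precision}

if __name__ == "__main__":
    print(build_deployment())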
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~30 minutes

Pre-merge checks (2 passed, 1 warning)
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
12911e6 to 97cd184
Codecov Report
❌ Patch coverage is

Additional details and impacted files

@@           Coverage Diff           @@
##             main     #287   +/-   ##
=======================================
  Coverage   73.86%   73.87%
=======================================
  Files         172      172
  Lines       17415    17416    +1
=======================================
+ Hits        12864    12866    +2
+ Misses       4551     4550    -1

☔ View full report in Codecov by Sentry.
97cd184 to 22dde2f
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/onnx_ptq/README.md (1)
119-125: Remove unsupported flag from docs
The evaluate.py CLI no longer accepts --quantize_mode (or any precision flag); drop --quantize_mode=stronglyTyped from the ONNX PTQ snippet (lines 119–125) and the later LLM evaluation example in examples/onnx_ptq/README.md.

modelopt/onnx/quantization/int8.py (1)
156-172: Fix None deref: nodes_to_exclude may be None before .extend().
nodes_to_exclude.extend(...) will crash when the param is not provided. Normalize optionals up front.

  logger.info("Detecting GEMV patterns for TRT optimization")
  matmul_nodes_to_exclude = find_nodes_from_matmul_to_exclude(
      onnx_path,
      use_external_data_format,
      intermediate_generated_files,
      calibration_data_reader,
      calibration_eps,
      calibration_shapes,
  )
- nodes_to_exclude.extend(matmul_nodes_to_exclude)  # type: ignore[union-attr]
+ # Normalize optionals before use
+ nodes_to_exclude = list(nodes_to_exclude or [])
+ intermediate_generated_files = list(intermediate_generated_files or [])
+ custom_ops_to_quantize = list(custom_ops_to_quantize or [])
+ nodes_to_exclude.extend(matmul_nodes_to_exclude)
🧹 Nitpick comments (8)
examples/onnx_ptq/evaluation.py (1)
29-34: Defaulting to stronglyTyped is good; consider a lightweight override.
Keeping stronglyTyped as default aligns with TRT guidance. As a convenience, allow an env override (no CLI churn) so users can quickly A/B test.

  import torch
  import torchvision.transforms as transforms
  from torchvision.datasets import ImageNet
  from tqdm import tqdm
+ import os
  ...
  deployment = {
      "runtime": "TRT",
      "accelerator": "GPU",
-     "precision": "stronglyTyped",
+     "precision": os.getenv("TRT_PRECISION", "stronglyTyped"),
      "onnx_opset": "21",
  }

Please confirm the minimal TRT version that supports strongly typed networks here and mention it in the README "Evaluate" section.
modelopt/onnx/quantization/int8.py (2)
127-135: Validate or warn on unsupported high_precision_dtype values.
You only handle {"fp16", "bf16"}; other strings silently no-op. Emit a warning for unexpected values to aid debugging.

- if high_precision_dtype in ["fp16", "bf16"]:
+ if high_precision_dtype in ["fp16", "bf16"]:
      ...
+ else:
+     logger.warning(
+         "Unknown high_precision_dtype '%s'; skipping float downcast. Expected one of {'fp16','bf16'}.",
+         high_precision_dtype,
+     )

Also applies to: 275-285
113-135: Avoid mutable defaults for list parameters (future-proofing).
intermediate_generated_files and custom_ops_to_quantize have mutable defaults. You mitigated at runtime above; consider following up to switch to None defaults in a future API clean-up.

modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py (1)
103-111: Confirm adding STRONGLY_TYPED to the "low-bit" bucket (forces opt level 4).
Treating STRONGLY_TYPED as low-bit changes builderOptimizationLevel to 4 and overrides user-provided levels. Is that intentional for all stronglyTyped builds, including fp16/fp32-typed graphs? If yes, consider renaming _is_low_bit_mode to reflect the broader intent, or add an inline comment to avoid confusion. Otherwise, gate the opt-level bump more narrowly.

Suggested inline doc tweak:

  def _is_low_bit_mode(trt_mode: str) -> bool:
-     return trt_mode in [
+     # Modes that benefit from max builderOptimizationLevel (4), not just "low-bit".
+     return trt_mode in [
          TRTMode.INT8,
          TRTMode.INT4,
          TRTMode.FLOAT8,
          TRTMode.BEST,
          TRTMode.STRONGLY_TYPED,
      ]

tests/examples/test_onnx_ptq.sh (2)
160-177: Fix array-membership test (ShellCheck SC2199/SC2076).
[[ " ${latency_models[@]} " =~ " $model_name " ]] is brittle. Use a loop or regex without quotes.

Apply:

- if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then
+ in_latency_set=false
+ for m in "${latency_models[@]}"; do
+     if [[ "$m" == "$model_name" ]]; then in_latency_set=true; break; fi
+ done
+ if $in_latency_set; then
48-56: Consider removing int8_iq from modes to avoid confusion.
You skip quantization for int8_iq and evaluate FP16 with precision="int8". If IQ is deprecated, drop it from quant_modes and the evaluation mapping to simplify.

Proposed minimal:

- if [ $cuda_capability -ge 89 ]; then
-     quant_modes=("fp8" "int8" "int8_iq")
+ if [ $cuda_capability -ge 89 ]; then
+     quant_modes=("fp8" "int8")
  else
      echo "CUDA capability is less than 89, skipping fp8 mode!"
-     quant_modes=("int8" "int8_iq")
+     quant_modes=("int8")
  fi
- all_modes=("${base_modes[@]}" "${quant_modes[@]}")
+ all_modes=("${base_modes[@]}" "${quant_modes[@]}")
222-222
: API tightening may break external callers; offer soft-landing.Changing
high_precision_dtype
from optional to required can break programmatic users. Consider accepting Optional at the type level and defaulting at runtime for back-compat.-def quantize( +def quantize( @@ - high_precision_dtype: str = "fp16", + high_precision_dtype: str | None = "fp16", @@ ) -> None:And near the top of the body:
- configure_logging(log_level.upper(), log_file) + configure_logging(log_level.upper(), log_file) + if high_precision_dtype is None: + high_precision_dtype = "fp16"
289-296: Docstring reads well; minor clarity tweak optional.
Suggest noting explicitly that no conversion occurs if the input is already fp16/bf16.

- and the input model is of dtype fp32, model's weight and activation will be converted to
- 'fp16' or 'bf16'.
+ and the input model is of dtype fp32, weights and activations are converted accordingly.
+ If the input is already fp16/bf16, no conversion is applied.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- examples/onnx_ptq/README.md (1 hunks)
- examples/onnx_ptq/evaluate.py (1 hunks)
- examples/onnx_ptq/evaluation.py (1 hunks)
- modelopt/onnx/quantization/__main__.py (2 hunks)
- modelopt/onnx/quantization/int8.py (1 hunks)
- modelopt/onnx/quantization/qdq_utils.py (2 hunks)
- modelopt/onnx/quantization/quantize.py (3 hunks)
- modelopt/torch/_deploy/_runtime/tensorrt/constants.py (0 hunks)
- modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py (1 hunks)
- modelopt/torch/_deploy/_runtime/trt_client.py (0 hunks)
- tests/_test_utils/onnx_quantization/utils.py (1 hunks)
- tests/examples/test_onnx_ptq.sh (2 hunks)
- tests/unit/onnx/test_qdq_rules_int8.py (1 hunks)
💤 Files with no reviewable changes (2)
- modelopt/torch/_deploy/_runtime/tensorrt/constants.py
- modelopt/torch/_deploy/_runtime/trt_client.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py (1)
modelopt/torch/_deploy/_runtime/tensorrt/constants.py (1)
TRTMode (83-91)
🪛 Shellcheck (0.10.0)
tests/examples/test_onnx_ptq.sh
[error] 176-176: Arrays implicitly concatenate in [[ ]]. Use a loop (or explicit * instead of @).
(SC2199)
[warning] 176-176: Remove quotes from right-hand side of =~ to match as a regex rather than literally.
(SC2076)
🔇 Additional comments (5)
tests/unit/onnx/test_qdq_rules_int8.py (1)
44-47: LGTM: safer guard before checking producer op.
Requiring inp.inputs avoids probing graph inputs/constants. Good stability improvement.

modelopt/onnx/quantization/qdq_utils.py (1)
998-1000: In-place node list clear is fine.
del graph.node[:] preserves references while rebuilding the node list; the change looks good.

modelopt/onnx/quantization/__main__.py (2)
183-189: Good default: --high_precision_dtype now fp16 with clear help text.
285-286: Passing dtype through unconditionally looks right.
Directly forwarding args.high_precision_dtype aligns with the stricter API and removes hidden fallbacks.

modelopt/onnx/quantization/quantize.py (1)
465-466: Good: dtype forwarded into int8/fp8 paths (removes hidden fallback).
examples/onnx_ptq/evaluate.py (Outdated)

+ # Note. stronglyTyped is recommended, all other modes have been deprecated in TensorRT
  deployment = {
      "runtime": "TRT",
-     "precision": args.quantize_mode,
+     "precision": "stronglyTyped",
  }
evaluate.py drops --quantize_mode but tests still pass it → argparse error.
Current tests invoke evaluate.py with --quantize_mode; without accepting it here, argparse will fail. Easiest fix: accept and ignore it (deprecated), preserving back-compat.
Apply:
  def main():
      parser = argparse.ArgumentParser()
  @@
      args = parser.parse_args()
-     # Note. stronglyTyped is recommended, all other modes have been deprecated in TensorRT
+     # Back-compat: accept deprecated --quantize_mode but ignore its value.
+     # StronglyTyped is recommended; other modes are deprecated in TensorRT.
+     if hasattr(args, "quantize_mode"):
+         pass
+
+     # Note. stronglyTyped is recommended, all other modes have been deprecated in TensorRT
      deployment = {
          "runtime": "TRT",
          "precision": "stronglyTyped",
      }
If you prefer to remove the flag from callers instead, I can update the test script calls too (see test comment).
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/onnx_ptq/evaluate.py around lines 57 to 61, the script no longer
accepts the deprecated --quantize_mode argument so tests that pass it cause an
argparse error; update the argument parser to add a --quantize_mode option
(accepting the same expected values or a string) and simply ignore its value at
runtime to preserve backward compatibility, and mark it in the help text as
deprecated; no further behavior changes are needed.
This is a terrible idea. Need to fix the tests instead.
🧩 Analysis chain
@galagam You're absolutely right! Maintaining backward compatibility for a deprecated argument goes against the clean API approach. Let me check the test files and provide fixes for them instead.
🏁 Script executed:
#!/bin/bash
# Find all files that call evaluate.py with --quantize_mode
echo "=== Searching for files that call evaluate.py with --quantize_mode ==="
rg -n --type=sh --type=py -C3 "evaluate\.py.*--quantize_mode"
echo -e "\n=== Looking at test files that might need updating ==="
fd -e sh -e py test | xargs rg -l "evaluate\.py"
Length of output: 158
🏁 Script executed:
#!/bin/bash
# Examine the test file that calls evaluate.py
echo "=== Content of tests/examples/test_onnx_ptq.sh ==="
cat tests/examples/test_onnx_ptq.sh
echo -e "\n=== Searching for any --quantize_mode usage in test files ==="
rg -n --type=sh --type=py "quantize_mode" tests/
Length of output: 11217
Perfect! I found the issue. The test script is still passing --quantize_mode to evaluate.py on lines 180 and 189, but evaluate.py no longer accepts this argument. Here's the fix:

The fix is to simply remove --quantize_mode=$precision from both evaluation calls in the test script, since evaluate.py now hardcodes "stronglyTyped" precision.
Fix for tests/examples/test_onnx_ptq.sh:
if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
--onnx_path=$eval_model_path \
--model_name="${timm_model_name[$model_name]}" \
- --quantize_mode=$precision \
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
else
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
--onnx_path=$eval_model_path \
--imagenet_path=$imagenet_path \
--eval_data_size=$calib_size \
--batch_size $batch_size \
--model_name="${timm_model_name[$model_name]}" \
- --quantize_mode=$precision \
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
This removes the deprecated argument from both evaluation paths and aligns the tests with the new API, where evaluate.py always uses "stronglyTyped" precision.
🏁 Script executed:
#!/bin/bash
# Show the specific lines that need to be fixed
echo "=== Lines that need to be fixed in tests/examples/test_onnx_ptq.sh ==="
sed -n '175,195p' tests/examples/test_onnx_ptq.sh
Length of output: 1097
Exactly! Here's the precise fix needed for tests/examples/test_onnx_ptq.sh:
Remove --quantize_mode=$precision from both evaluate.py calls:
Line ~180 (latency models branch):
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
--onnx_path=$eval_model_path \
--model_name="${timm_model_name[$model_name]}" \
- --quantize_mode=$precision \
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
Line ~189 (regular evaluation branch):
CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
--onnx_path=$eval_model_path \
--imagenet_path=$imagenet_path \
--eval_data_size=$calib_size \
--batch_size $batch_size \
--model_name="${timm_model_name[$model_name]}" \
- --quantize_mode=$precision \
--results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
Optional cleanup: Since $precision is no longer used, you could also remove the precision assignment logic (lines ~166-174), but that's not strictly necessary for fixing the immediate issue.
This approach is much cleaner than maintaining deprecated arguments - thanks for steering toward the right solution! 🎯
-     MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default.
      If quantize_mode == 'fp8' and mha_accumulation_dtype == 'fp32', Cast nodes will be added to
      MHA's bmm1 and bmm2's input and output tensors.
+     MHA accumulation dtype. One of ['fp32', 'fp16']. If quantize_mode == 'fp8' and
Suggested change:
-     MHA accumulation dtype. One of ['fp32', 'fp16']. If quantize_mode == 'fp8' and
+     MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default. If quantize_mode == 'fp8' and
22dde2f to f2c5ff9
Signed-off-by: Riyad Islam <[email protected]>
f2c5ff9 to 9123d26
Actionable comments posted: 1
♻️ Duplicate comments (2)
tests/_test_utils/onnx_quantization/utils.py (1)
23-29: Fix Cast hop indexing and add None-guards to avoid AttributeError.
After stepping into a Cast, its input index is 0. Reusing inp_idx is wrong and there's no guard for missing producers. This can raise on models where inputs aren't quantized or Cast is terminal.

- qnode = node
- # After quantization, the quantized node can be casted
- if qnode.i(inp_idx).op == "Cast":
-     qnode = qnode.i(inp_idx)
- assert qnode.i(inp_idx).op == "DequantizeLinear", (
-     f"Input '{inp.name}' of node '{qnode.name}' is not quantized but should be!"
- )
+ producer = node.i(inp_idx)
+ # Quantized path may include a Cast right after DQ
+ if producer and producer.op == "Cast":
+     producer = producer.i(0)
+ assert producer and producer.op == "DequantizeLinear", (
+     f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
+ )

modelopt/onnx/quantization/qdq_utils.py (1)
529-536: Guard against None when skipping a Cast consumer.
quantized_node.op_type is accessed before checking for None, and the second hop after Cast isn't validated. This can crash on graphs where DQ feeds only a Cast or Cast has no consumer.

- quantized_node = tensor_consumers.get(dq_node.output[0], [None])[0]
- if quantized_node.op_type == "Cast":
-     quantized_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
-
- if not quantized_node:
-     raise ValueError(f"No consumer found for {dq_node.name}")
+ quantized_node = tensor_consumers.get(dq_node.output[0], [None])[0]
+ if not quantized_node:
+     raise ValueError(f"No consumer found for {dq_node.name}")
+ if quantized_node.op_type == "Cast":
+     next_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
+     if not next_node:
+         raise ValueError(f"No consumer found after Cast for {quantized_node.name}")
+     quantized_node = next_node
🧹 Nitpick comments (3)
modelopt/onnx/quantization/qdq_utils.py (1)
998-1000: Prefer explicit protobuf clearing for readability.
Minor: graph.ClearField("node") is clearer than slicing deletion on a protobuf repeated field.

- del graph.node[:]
- graph.node.extend(new_nodes)
+ graph.ClearField("node")
+ graph.node.extend(new_nodes)

tests/examples/test_onnx_ptq.sh (2)
176-181: Fix array membership test (ShellCheck SC2199/SC2076).
The regex-like test on arrays is brittle. Use a loop flag for exact membership.

- if [[ " ${latency_models[@]} " =~ " $model_name " ]]; then
+ should_eval_latency=false
+ for m in "${latency_models[@]}"; do
+     if [[ "$m" == "$model_name" ]]; then
+         should_eval_latency=true
+         break
+     fi
+ done
+ if $should_eval_latency; then
      CUDA_VISIBLE_DEVICES=$gpu_id python evaluate.py \
          --onnx_path=$eval_model_path \
          --model_name="${timm_model_name[$model_name]}" \
-         --engine_precision=$precision \
+         --engine_precision=$precision \
          --results_path=$model_dir/$quant_mode/${model_name}_${quant_mode}.csv &
189-190: Mirror the membership fix in the non-latency branch invocation block if you refactor the conditional.
No functional change here; just ensure consistency after the conditional refactor above.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (13)
- examples/onnx_ptq/README.md (2 hunks)
- examples/onnx_ptq/evaluate.py (1 hunks)
- examples/onnx_ptq/evaluation.py (1 hunks)
- modelopt/onnx/quantization/__main__.py (2 hunks)
- modelopt/onnx/quantization/int8.py (1 hunks)
- modelopt/onnx/quantization/qdq_utils.py (2 hunks)
- modelopt/onnx/quantization/quantize.py (3 hunks)
- modelopt/torch/_deploy/_runtime/tensorrt/constants.py (0 hunks)
- modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py (1 hunks)
- modelopt/torch/_deploy/_runtime/trt_client.py (0 hunks)
- tests/_test_utils/onnx_quantization/utils.py (1 hunks)
- tests/examples/test_onnx_ptq.sh (2 hunks)
- tests/unit/onnx/test_qdq_rules_int8.py (1 hunks)
💤 Files with no reviewable changes (2)
- modelopt/torch/_deploy/_runtime/trt_client.py
- modelopt/torch/_deploy/_runtime/tensorrt/constants.py
🚧 Files skipped from review as they are similar to previous changes (7)
- modelopt/torch/_deploy/_runtime/tensorrt/engine_builder.py
- examples/onnx_ptq/evaluation.py
- examples/onnx_ptq/README.md
- tests/unit/onnx/test_qdq_rules_int8.py
- modelopt/onnx/quantization/quantize.py
- modelopt/onnx/quantization/int8.py
- modelopt/onnx/quantization/__main__.py
🧰 Additional context used
🪛 Shellcheck (0.10.0)
tests/examples/test_onnx_ptq.sh
[error] 176-176: Arrays implicitly concatenate in [[ ]]. Use a loop (or explicit * instead of @).
(SC2199)
[warning] 176-176: Remove quotes from right-hand side of =~ to match as a regex rather than literally.
(SC2076)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (2)
examples/onnx_ptq/evaluate.py (1)
52-58: CLI switch to --engine_precision looks good.
Choices and default align with the PR goal of stronglyTyped by default.
tests/examples/test_onnx_ptq.sh (1)
164-173: Approve mode-to-precision mapping. Runtime honors "best" for int8_iq and maps other modes to "stronglyTyped".
822ee09 to 262a623
Signed-off-by: Riyad Islam <[email protected]>
262a623 to 9e7cd86
  # Remove transpose and reshape nodes
  new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
- graph.node.clear()
+ del graph.node[:]
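For context, a small self-contained sketch of the rebuild pattern used in the diff above; the toy graph below is synthetic, not modelopt code.

from onnx import TensorProto, helper

# Build a tiny two-node graph with named nodes.
x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 4])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 4])
keep = helper.make_node("Identity", ["x"], ["t"], name="keep_me")
drop = helper.make_node("Identity", ["t"], ["y"], name="drop_me")
graph = helper.make_graph([keep, drop], "demo", [x], [y])

# Filter out unwanted nodes, then rebuild the repeated field in place.
nodes_to_remove = {"drop_me"}
new_nodes = [node for node in graph.node if node.name not in nodes_to_remove]
del graph.node[:]  # slice deletion works on protobuf repeated fields
graph.node.extend(new_nodes)
print([node.name for node in graph.node])  # -> ['keep_me']

The list built beforehand holds references to the node messages, so extending the cleared field restores them in order.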
Is there any reason we use this over graph.node.clear()?
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (4)
60-63: Return type mismatch (bytes vs bytearray).
Annotated as bytes but returns bytearray. Simplify and drop the type ignore.

  def get_engine_bytes(engine: trt.tensorrt.ICudaEngine) -> bytes:
      """Return serialized TensorRT engine bytes."""
-     return bytearray(engine.serialize())  # type: ignore[return-value]
+     return bytes(engine.serialize())
65-74: Incorrect function signature: function returns a tuple but is annotated as engine only.
This will confuse type checkers and callers.

- from tensorrt import Logger
+ from typing import Optional, Tuple
  @@
- def load_engine(buffer: bytes, log_level: int = trt.Logger.ERROR) -> trt.tensorrt.ICudaEngine:
+ def load_engine(buffer: bytes, log_level: int = trt.Logger.ERROR) -> Tuple[Optional[trt.tensorrt.ICudaEngine], str]:
  @@
-         return runtime.deserialize_cuda_engine(buffer), ""
+         return runtime.deserialize_cuda_engine(buffer), ""
      except Exception as e:
          logging.exception(str(e))
          return None, str(e)
110-131: get_output_shapes returns TRT dims objects, not List[List[int]] as annotated.
Materialize to Python lists for a stable, JSON-safe structure.

- output_shapes = []
+ output_shapes: list[list[int]] = []
  @@
- if not engine.binding_is_input(binding_index):
-     shape = context.get_binding_shape(binding_index)
-     output_shapes.append(shape)
+ if not engine.binding_is_input(binding_index):
+     dims = context.get_binding_shape(binding_index)
+     output_shapes.append(list(dims))
170-180: Hashing logic hashes the payload twice; fix and clean up docstring typo.
Current code computes SHA256(engine_bytes || engine_bytes). The docstring says hash of engine bytes only.

  def prepend_hash_to_bytes(engine_bytes: bytes) -> bytes:
      """Prepend the engine bytes with the SHA256 hash of the engine bytes
-     This has will serve as a unique identifier for the engine and will be used to manage
+     This hash will serve as a unique identifier for the engine and will be used to manage
      TRTSessions in the TRTClient.
      """
-     hash_object = hashlib.sha256(engine_bytes)
-     hash_object.update(engine_bytes)
-     hash_bytes = hash_object.digest()
+     hash_bytes = hashlib.sha256(engine_bytes).digest()
      engine_bytes = hash_bytes + engine_bytes
      return engine_bytes
🧹 Nitpick comments (5)
CHANGELOG.rst (2)
8-8: Tighten wording; show exact flag and choices.
Use "strong typing/strongly typed," include the CLI flag, and note the TRT rationale for clarity.

- - Deprecated ``quantize_mode`` argument in ``examples/onnx_ptq/evaluate.py`` to support strongly typing. Use ``engine_precision`` instead.
+ - Deprecated ``quantize_mode`` in ``examples/onnx_ptq/evaluate.py`` in favor of strong typing (to align with TensorRT deprecating weak typing). Use ``--engine_precision`` instead (choices: ``best``, ``fp16``, ``stronglyTyped``).
13-13: Fix grammar and surface behavior change.
Use "defaults to," a code-literal for fp16, and a concise note that output weights change.

- - ``high_precision_dtype`` default to fp16 in ONNX quantization, i.e. quantized output model weights are now FP16 by default.
+ - ``high_precision_dtype`` now defaults to ``fp16`` in ONNX quantization; quantized output model weights are FP16 by default.
+   (This changes prior behavior for users expecting FP32 weights.)

modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (3)
134-152: Avoid double-parsing ONNX and guard input alignment.
Parse once, validate tensor count, use integer division.

- input_names = get_onnx_input_names(onnx.load_from_string(onnx_bytes))
-
- batch_size = get_batch_size(onnx.load_from_string(onnx_bytes))
+ model = onnx.load_from_string(onnx_bytes)
+ input_names = get_onnx_input_names(model)
+ batch_size = get_batch_size(model)
  if not batch_size or batch_size <= 0:
      batch_size = 1
+ if len(input_tensors) != len(input_names):
+     raise ValueError(f"Expected {len(input_names)} input tensors, got {len(input_tensors)}.")
  # If input tensor batch % batch_size != 0, we don't use all input tensors for calibration.
- num_batches = int(input_tensors[0].shape[0] / batch_size)
+ num_batches = input_tensors[0].shape[0] // batch_size
154-168: Prefer explicit exception over assert; consider FP8 support when available.
Asserts can be stripped with -O. Also, optionally map FP8 when both TRT and torch support it.

  def convert_trt_dtype_to_torch(trt_dtype: trt.tensorrt.DataType) -> torch.dtype:
  @@
-     assert trt_dtype in trt_to_torch_dtype_map, f"Unsupported TensorRT data type: {trt_dtype}"
-     return trt_to_torch_dtype_map[trt_dtype]
+     if hasattr(trt.DataType, "FP8") and hasattr(torch, "float8_e4m3fn"):
+         trt_to_torch_dtype_map[trt.DataType.FP8] = torch.float8_e4m3fn
+     if trt_dtype not in trt_to_torch_dtype_map:
+         raise ValueError(f"Unsupported TensorRT data type: {trt_dtype}")
+     return trt_to_torch_dtype_map[trt_dtype]
182-196: Minor doc and determinism nits in convert_shape_to_string.
Fix example and sort keys for stable output.

  """Convert a shape dictionary to a string.
  For example, if the shape is: {
-     "input": [1, 3, 224, 224],
+     "input": [1, 3, 224, 224],
      "output": [1, 1000]
  }.
  The output string will be:
-     input:1x3x244x244,output:1x1000
+     input:1x3x224x224,output:1x1000
  """
  result = ""
- for key, value in shape.items():
-     result += f"{key}:{'x'.join(map(str, value))},"
+ for key in sorted(shape):
+     value = shape[key]
+     result += f"{key}:{'x'.join(map(str, value))},"
  return result[:-1]
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- CHANGELOG.rst (1 hunks)
- modelopt/onnx/quantization/qdq_utils.py (2 hunks)
- modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (1 hunks)
- tests/_test_utils/onnx_quantization/utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/_test_utils/onnx_quantization/utils.py
- modelopt/onnx/quantization/qdq_utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (1)
27-27: Resolved: no lingering precision helper references. Search confirms zero occurrences of validate_precision or INT8_IQ in the codebase; existing TRTMode uses are intentional.
Signed-off-by: Riyad Islam <[email protected]>
Signed-off-by: Jingyu Xin <[email protected]>
Signed-off-by: Riyad Islam <[email protected]>
What does this PR do?
Type of change: new example
Overview: TensorRT has deprecated weak typing in favor of explicitly typed ONNX models. ModelOpt's evaluation and deployment utilities should reflect that.
Usage
# Add a code snippet demonstrating how to use this
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit