
Conversation


@gcunhase gcunhase commented Sep 24, 2025

What does this PR do?

Type of change: Bug fix

Overview: Ensure that FP32 custom ops are supported in ModelOpt by blocking their conversion to FP16 in Autocast.

Usage

--trt_plugins_precision $CUSTOM_OP_NAME:fp32

or

--op_types_to_exclude_fp16 $CUSTOM_OP_NAME

Testing

$ python -m modelopt.onnx.quantization --onnx_path=${MODEL_NAME}.onnx \
    --trt_plugins=$PLUGIN_PATH \
    --trt_plugins_precision $CUSTOM_OP_NAME:fp32 \
    --high_precision_dtype fp16
$ trtexec --onnx=${MODEL_NAME}.quant.onnx --staticPlugins=$PLUGIN_PATH --stronglyTyped
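
For reference, a minimal Python-API sketch of the same flow. This is an illustration only: the import path and argument names follow this PR's quantize() signature, while the model path and op name are placeholders.

from modelopt.onnx.quantization import quantize

# Hedged sketch: exclude an FP32-only op type from FP16 conversion while the rest of the
# model is still reduced to FP16. "CustomOpName" and the model path are placeholders.
# For TRT plugins, the equivalent CLI-style request is --trt_plugins_precision CustomOpName:fp32.
quantize(
    onnx_path="model.onnx",
    high_precision_dtype="fp16",
    op_types_to_exclude_fp16=["CustomOpName"],
)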

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: N/A
  • Did you add or update any necessary documentation?: N/A
  • Did you update Changelog?: No

Summary by CodeRabbit

  • New Features

    • Add CLI/API option to exclude specific op types from FP16/BF16 conversion during ONNX quantization.
  • Improvements

    • Per‑precision (FP16/FP32) cast mappings for custom ops are tracked and respected.
    • FP16 exclusion rules are propagated throughout the quantization flow; a warning appears when exclusions are irrelevant under FP32.
    • Preprocessing and quantization flow updated to support per-op FP16 exclusions.
  • Documentation

    • Changelog updated with the new option and guidance for custom op precision.

@gcunhase gcunhase requested a review from a team as a code owner September 24, 2025 17:37
@gcunhase gcunhase requested a review from ajrasane September 24, 2025 17:37

coderabbitai bot commented Sep 24, 2025

Walkthrough

Adds an optional op_types_to_exclude_fp16 parameter across ONNX quantize APIs, threads it through preprocessing and quantization flows, changes TRT precision parsing to per-precision cast mappings, and uses the new parameter (combined with fp32 cast keys) to exclude specified op types from FP16/BF16 conversion.

Changes

Cohort / File(s), with summary of changes:

FP8 quantization API (modelopt/onnx/quantization/fp8.py):
  quantize(...) signature gains an op_types_to_exclude_fp16: list[str] parameter.

INT8 quantization API (modelopt/onnx/quantization/int8.py):
  quantize(...) signature gains an op_types_to_exclude_fp16: list[str] parameter.

Preprocessing & Orchestration (modelopt/onnx/quantization/quantize.py):
  _preprocess_onnx return tuple extended with a per-precision cast mappings dict; quantize(...) signature gains op_types_to_exclude_fp16; combines the provided argument with the keys of custom_ops_to_cast_fp32, may warn when high_precision_dtype == "fp32", and forwards the exclusion list into downstream quantization calls.

TRT Precision Parsing (modelopt/onnx/trt_utils.py):
  interpret_trt_plugins_precision_flag now returns per-precision mappings { "fp16": {...}, "fp32": {...} }, where each precision maps to { op_type: { "inp": [...], "out": [...] } } (nested per-precision cast dicts).

CLI / Entrypoint (modelopt/onnx/quantization/__main__.py):
  Adds CLI argument --op_types_to_exclude_fp16 (nargs="+"); removes default=[] from several list args; forwards op_types_to_exclude_fp16 to quantize().

Changelog (CHANGELOG.rst):
  Adds a 0.39 entry documenting the new op_types_to_exclude_fp16 flag and guidance to use --trt_plugins_precision <op_type>:fp32 for custom ops.
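
To make the per-precision mapping from the TRT Precision Parsing entry above concrete, a hedged example of the structure (op names and I/O index counts are hypothetical):

# Hypothetical result for a flag like "CustomOpA:fp32 CustomOpB:fp16"
custom_ops_to_cast = {
    "fp32": {"CustomOpA": {"inp": [0, 1], "out": [0]}},
    "fp16": {"CustomOpB": {"inp": [0], "out": [0]}},
}
# quantize.py applies cast_custom_ops(...) to the "fp16" entry and folds the keys of the
# "fp32" entry into op_types_to_exclude_fp16 for the later FP16/BF16 conversion.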

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant CLI as User/CLI
    participant Q as quantize.py
    participant P as _preprocess_onnx
    participant T as trt_utils
    participant CAL as calibration/quantizer
    participant I8 as int8.quantize
    participant F8 as fp8.quantize

    CLI->>Q: quantize(model, ..., trt_plugins_precision=flag, --op_types_to_exclude_fp16=list)
    Q->>P: _preprocess_onnx(model, trt_plugins_precision, ...)
    P->>T: interpret_trt_plugins_precision_flag(flag)
    T-->>P: { "fp16": {op: spec}, "fp32": {op: spec} }
    P-->>Q: (..., custom_ops_to_cast_fp32 = {op: spec}, ...)
    Q->>Q: combined_exclude = uniq(op_types_to_exclude_fp16 + keys(custom_ops_to_cast_fp32))
    alt high_precision_dtype == "fp32" and combined_exclude non-empty
        Q-->>CLI: warn("exclusion void when high_precision_dtype=='fp32'")
    end
    Q->>CAL: calibration/quantization steps...
    CAL-->>Q: calibrated model
    Q->>I8: int8.quantize(model, ..., op_types_to_exclude_fp16=combined_exclude)
    Q->>F8: fp8.quantize(model, ..., op_types_to_exclude_fp16=combined_exclude)
    note right of I8: convert_to_f16 called with op_block_list=combined_exclude
    note right of F8: convert_to_f16 called with op_block_list=combined_exclude

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

I twitch my whiskers at the new command,
I sort FP16 and FP32 with a careful hand.
Some ops stay warm, some learn to fly light,
I hop through casts into the night.
A rabbit nods — mappings look just right 🐇🥕

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled.
Title Check | ✅ Passed | The title succinctly states the primary change of inserting cast nodes for FP32-required operations, reflecting the main bug fix of preserving FP32-only custom ops during Autocast. It directly matches the introduction of the op_types_to_exclude_fp16 parameter and related casting logic across the ONNX quantization pipeline. It avoids generic terms and clearly communicates the main objective to any teammate scanning the history.
Docstring Coverage | ✅ Passed | Docstring coverage is 85.71%, which is sufficient. The required threshold is 80.00%.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 03ee3fd and 890b12e.

📒 Files selected for processing (6)
  • CHANGELOG.rst (1 hunks)
  • modelopt/onnx/quantization/__main__.py (2 hunks)
  • modelopt/onnx/quantization/fp8.py (2 hunks)
  • modelopt/onnx/quantization/int8.py (2 hunks)
  • modelopt/onnx/quantization/quantize.py (8 hunks)
  • modelopt/onnx/trt_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • modelopt/onnx/quantization/__main__.py
  • modelopt/onnx/quantization/int8.py
  • modelopt/onnx/quantization/fp8.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/trt_utils.py (1)
  • interpret_trt_plugins_precision_flag (327-431)
modelopt/onnx/quantization/graph_utils.py (1)
  • cast_custom_ops (1282-1369)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: wait-checks / wait
  • GitHub Check: linux
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
CHANGELOG.rst (1)

4-12: Changelog entry reads well

Thanks for documenting the new op_types_to_exclude_fp16 flag and the TRT plugin precision fallback. The note clearly tells users how to keep custom ops in FP32.

modelopt/onnx/quantization/quantize.py (1)

431-433: LGTM: deduping FP16 exclusions keeps intent clear.

Combining user-specified exclusions with FP32 custom ops via dict.fromkeys preserves order while removing duplicates. Looks good.
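
For readers unfamiliar with the idiom, a small generic illustration of the ordered de-duplication (not ModelOpt code):

user_excludes = ["LayerNorm", "CustomOpA"]
fp32_custom_ops = {"CustomOpA": {"inp": [0], "out": [0]}, "CustomOpB": {"inp": [0], "out": [0]}}

combined = list(dict.fromkeys(user_excludes + list(fp32_custom_ops.keys())))
print(combined)  # ['LayerNorm', 'CustomOpA', 'CustomOpB'] -- first occurrence wins, order preserved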

modelopt/onnx/trt_utils.py (2)

370-375: Don't overwrite prior FP16/FP32 casts.

custom_ops_to_cast[precision] = {...} replaces any entries already collected for that precision, so only the last custom op survives. With flags like --trt_plugins_precision Foo:fp32 Bar:fp32, the earlier op is dropped and will still be autocast to FP16. Merge into the existing precision map instead.

-            if precision in ["fp16", "fp32"]:
-                custom_ops_to_cast[precision] = {
-                    op_type: {
-                        "inp": list(range(num_inps)),
-                        "out": list(range(num_outs)),
-                    }
-                }
+            if precision in ["fp16", "fp32"]:
+                ops_map = custom_ops_to_cast.setdefault(precision, {})
+                ops_map[op_type] = {
+                    "inp": list(range(num_inps)),
+                    "out": list(range(num_outs)),
+                }

413-420: Handle output-only casts and keep accumulated entries.

This block still requires an input index to exist before recording the cast, so output-only requests (e.g., [fp32,fp32]:[fp32,fp32] where inputs stay FP16) are silently ignored. It also overwrites previously stored ops the same way as Line 370. Allow either inputs or outputs to trigger storage, and merge into the existing precision map.

-            for precision in ["fp16", "fp32"]:
+            for precision in ["fp16", "fp32"]:
                 inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
                 out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
-                if inp_precision_cast:
-                    custom_ops_to_cast[precision] = {
-                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
-                    }
+                if inp_precision_cast or out_precision_cast:
+                    ops_map = custom_ops_to_cast.setdefault(precision, {})
+                    ops_map[op_type] = {
+                        "inp": inp_precision_cast,
+                        "out": out_precision_cast,
+                    }


@gcunhase gcunhase force-pushed the dev/gcunhasergio/fp32_cast_custom_ops_5455919 branch from a2fb1be to b00431a on September 24, 2025 17:44

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/onnx/trt_utils.py (1)

413-413: Update comment to reflect bidirectional casting behavior.

The comment states "Will cast the inputs to FP16/FP32 and the outputs back to FP32" but the code now supports casting to both FP16 and FP32 based on the precision specification.

-            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
+            # Will cast the inputs and outputs based on the specified precision (FP16/FP32)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 26c203a and b00431a.

📒 Files selected for processing (4)
  • modelopt/onnx/quantization/fp8.py (2 hunks)
  • modelopt/onnx/quantization/int8.py (2 hunks)
  • modelopt/onnx/quantization/quantize.py (5 hunks)
  • modelopt/onnx/trt_utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/trt_utils.py (1)
  • interpret_trt_plugins_precision_flag (327-431)
modelopt/onnx/quantization/graph_utils.py (1)
  • cast_custom_ops (1282-1369)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (10)
modelopt/onnx/quantization/int8.py (2)

131-131: LGTM! Parameter addition for FP32 casting support.

The addition of custom_ops_to_cast_fp32: list[str] = [] parameter is correctly placed and follows the existing pattern for optional parameters.


283-283: convert_to_f16 honors op_block_list; custom_ops_to_cast_fp32 entries remain in FP32

modelopt/onnx/quantization/fp8.py (2)

181-181: LGTM! Consistent parameter addition for FP32 casting.

The custom_ops_to_cast_fp32 parameter is correctly added with a default empty list, maintaining consistency with the int8.py implementation.


322-322: Good improvement: Configurable FP32 casting replaces hardcoded list.

Replacing the hardcoded ["Resize"] with the configurable custom_ops_to_cast_fp32 parameter provides better flexibility. The Resize op can now be included in the list if needed, rather than being unconditionally blocked.

modelopt/onnx/quantization/quantize.py (4)

84-84: LGTM! Return type correctly extended for FP32 cast mappings.

The function signature properly adds the extra dict return value for FP32 cast mappings.


183-183: Good practice: Initialize custom_ops_to_cast as empty dict.

Initializing the variable before conditional usage prevents potential UnboundLocalError.


189-194: Drop incorrect dictionary access concern. The custom_ops_to_cast["fp16"] value is already a flat mapping of op types to their {inp, out} indices, matching cast_custom_ops's expected ops_to_cast format.

Likely an incorrect or invalid review comment.


203-203: FP32 cast dictionary structure is correct. custom_ops_to_cast.get("fp32") already returns a mapping of each op_type to its {"inp":[…],"out":[…]} structure.

modelopt/onnx/trt_utils.py (2)

413-420: Confirm casting map is consumed per-precision as intended
The custom_ops_to_cast dict is built with top‐level “fp16”/“fp32” keys, and in quantize.py only the “fp16” entry is ever retrieved for casting via

if custom_ops_to_cast.get("fp16"):
    onnx_model = cast_custom_ops(onnx_model, custom_ops_to_cast["fp16"])

No code paths consume custom_ops_to_cast["fp32"], so the nested structure matches its actual use.


370-376: Nested precision grouping is intentional and correct.

cast_custom_ops expects a mapping of op_type → {inp, out}, and quantize.py calls it with custom_ops_to_cast["fp16"], so grouping by precision first matches the consumer’s expectations and does not introduce a breaking change.

Likely an incorrect or invalid review comment.


codecov bot commented Sep 24, 2025

Codecov Report

❌ Patch coverage is 23.07692% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.52%. Comparing base (a041bbe) to head (890b12e).
⚠️ Report is 1 commits behind head on main.

Files with missing lines | Patch % | Lines
modelopt/onnx/trt_utils.py | 0.00% | 7 Missing ⚠️
modelopt/onnx/quantization/quantize.py | 50.00% | 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #363      +/-   ##
==========================================
- Coverage   73.53%   73.52%   -0.01%     
==========================================
  Files         172      172              
  Lines       17700    17706       +6     
==========================================
+ Hits        13016    13019       +3     
- Misses       4684     4687       +3     

☔ View full report in Codecov by Sentry.

@gcunhase gcunhase requested a review from galagam September 25, 2025 13:41
@gcunhase gcunhase force-pushed the dev/gcunhasergio/fp32_cast_custom_ops_5455919 branch 2 times, most recently from 62564db to 0853ce1 on September 25, 2025 15:14

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/trt_utils.py (1)

378-385: Logging uses mutated value; message becomes misleading.

You overwrite precision before logging, so the warning prints identical precisions.

-                if precision != quantize_mode:
-                    precision = quantize_mode
-                    logger.warning(
-                        f"Requested custom op precision ({precision}) is different than quantize mode: "
+                if precision != quantize_mode:
+                    requested_precision = precision
+                    precision = quantize_mode
+                    logger.warning(
+                        f"Requested custom op precision ({requested_precision}) is different than quantize mode: "
                         f"{quantize_mode}. Mixed {precision}+{quantize_mode} precision is not yet supported. "
                         f"Setting the custom op precision to be the same as quantize mode."
                     )
🧹 Nitpick comments (2)
modelopt/onnx/quantization/int8.py (1)

131-133: Avoid mutable default for parameter.

Use None as default to avoid shared list instances across calls.

-    custom_ops_to_cast_fp32: list[str] = [],
+    custom_ops_to_cast_fp32: list[str] | None = None,

And later where used:

-            op_block_list=custom_ops_to_cast_fp32,
+            op_block_list=(custom_ops_to_cast_fp32 or []),
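
For context, a tiny generic demonstration of why the mutable default is risky (illustrative only, not ModelOpt code):

def collect(op, ops=[]):      # the default list is created once and shared across calls
    ops.append(op)
    return ops

print(collect("Foo"))  # ['Foo']
print(collect("Bar"))  # ['Foo', 'Bar'] -- state leaked from the previous call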
modelopt/onnx/trt_utils.py (1)

339-342: Docstring mismatch with new return structure.

custom_ops_to_cast now has per-precision keys ("fp16"/"fp32"). Update the docstring accordingly.

-        Dictionary with custom ops to cast containing the I/O indices to cast.
+        Dictionary with per-precision cast maps:
+            {
+              "fp16": { <op_type>: {"inp": [idx...], "out": [idx...]} },
+              "fp32": { <op_type>: {"inp": [idx...], "out": [idx...]} },
+            }
         Dictionary with custom ops to quantize containing the I/O indices to quantize.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 62564db and 0853ce1.

📒 Files selected for processing (4)
  • modelopt/onnx/quantization/fp8.py (2 hunks)
  • modelopt/onnx/quantization/int8.py (2 hunks)
  • modelopt/onnx/quantization/quantize.py (5 hunks)
  • modelopt/onnx/trt_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/quantization/fp8.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/trt_utils.py (1)
  • interpret_trt_plugins_precision_flag (327-431)
modelopt/onnx/quantization/graph_utils.py (1)
  • cast_custom_ops (1282-1369)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (4)
modelopt/onnx/quantization/quantize.py (3)

189-195: LGTM: FP16-cast application gated and performed in preprocessing.

Conditionally applying cast_custom_ops for fp16-only custom ops is sound and localized.


476-481: Duplicate: Passing only keys drops per-IO FP32 mapping.

Prior comment already flagged that list(custom_ops_to_cast_fp32.keys()) loses "inp"/"out" indices, preventing selective FP32 casting.

Consider passing the full dict and adapting downstream to preserve mapping:

-            custom_ops_to_cast_fp32=list(custom_ops_to_cast_fp32.keys()),
+            custom_ops_to_cast_fp32=custom_ops_to_cast_fp32,

Then in int8/fp8 quantize implementations, derive op_block_list for convert_to_f16 as needed:

ops_block_list = (
    custom_ops_to_cast_fp32
    if isinstance(custom_ops_to_cast_fp32, list)
    else list(custom_ops_to_cast_fp32.keys())
)

84-85: _preprocess_onnx is a private helper, not part of the public API
No external callers beyond its own use in quantize.py; changing its return tuple does not break any public interfaces.

Likely an incorrect or invalid review comment.

modelopt/onnx/quantization/int8.py (1)

276-286: convert_to_f16 uses op types for op_block_list and supports bf16
op_block_list is matched against each node’s op_type and the assertion includes "bf16" as a valid low_precision_type.

Comment on lines +413 to 421
            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
            for precision in ["fp16", "fp32"]:
                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
                if inp_precision_cast:
                    custom_ops_to_cast[precision] = {
                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
                    }


⚠️ Potential issue | 🔴 Critical

Bug: Drops output-only casting and overwrites maps.

  • Mapping is added only if inp_precision_cast is non-empty, so output-only casts are ignored.
  • Overwrites the existing precision map for each op instead of merging.
-            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
-            for precision in ["fp16", "fp32"]:
-                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
-                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
-                if inp_precision_cast:
-                    custom_ops_to_cast[precision] = {
-                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
-                    }
+            # Will cast requested inputs to FP16/FP32 and outputs back to FP32
+            for precision in ["fp16", "fp32"]:
+                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
+                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
+                if inp_precision_cast or out_precision_cast:
+                    ops_map = custom_ops_to_cast.setdefault(precision, {})
+                    ops_map[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
            for precision in ["fp16", "fp32"]:
                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
                if inp_precision_cast:
                    custom_ops_to_cast[precision] = {
                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
                    }
            # Will cast requested inputs to FP16/FP32 and outputs back to FP32
            for precision in ["fp16", "fp32"]:
                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
                if inp_precision_cast or out_precision_cast:
                    ops_map = custom_ops_to_cast.setdefault(precision, {})
                    ops_map[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
🤖 Prompt for AI Agents
In modelopt/onnx/trt_utils.py around lines 413 to 421, the current loop only
adds a precision entry when inp_precision_cast is non-empty and assigns a new
dict for the op_type, which drops output-only casts and overwrites any existing
entries; change it to add the precision mapping when either inp_precision_cast
or out_precision_cast is non-empty, ensure custom_ops_to_cast[precision] is
created if missing, and merge into the existing op_type entry (creating 'inp'
and/or 'out' keys as needed) by extending or setting the index lists rather than
replacing the whole op_type dict so both input-only, output-only and combined
casts are preserved and multiple calls accumulate safely.


@galagam galagam left a comment


Might want to add a more generic solution to allow propagating op types to AutoCast's op_types_to_exclude, since this solution is plugin-specific. But overall LGTM, approved.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
modelopt/onnx/quantization/int8.py (1)

122-123: Avoid mutable default argument for op_types_to_exclude_fp16.

Use None and coalesce at call sites to prevent accidental cross-call state.

-    op_types_to_exclude_fp16: list[str] = [],
+    op_types_to_exclude_fp16: list[str] | None = None,
modelopt/onnx/quantization/__main__.py (1)

104-113: Clarify help text to include BF16.

The option applies to FP16 and BF16 reductions.

-            "A space-separated list of node types to exclude from FP16 conversion. "
-            "This is only relevant if '--high_precision_dtype != fp32'."
+            "A space-separated list of node types to exclude from FP16/BF16 conversion. "
+            "Relevant when --high_precision_dtype is 'fp16' or 'bf16'."
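
A hedged sketch of the resulting CLI option, combining the nargs noted in the walkthrough with the help text suggested above (the exact wording and default in __main__.py may differ):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--op_types_to_exclude_fp16",
    type=str,
    nargs="+",
    default=[],
    help=(
        "A space-separated list of node types to exclude from FP16/BF16 conversion. "
        "Relevant when --high_precision_dtype is 'fp16' or 'bf16'."
    ),
)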
modelopt/onnx/quantization/quantize.py (1)

429-431: Don’t mutate caller-provided list; de-duplicate when extending with custom-op FP32 keys.

Prevents side effects on the CLI args list and removes duplicates.

-    op_types_to_exclude_fp16 = op_types_to_exclude_fp16 or []
-    op_types_to_exclude_fp16.extend(list(custom_ops_to_cast_fp32.keys()))
+    op_types_to_exclude_fp16 = list(
+        dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
+    )
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0853ce1 and aec1a03.

📒 Files selected for processing (4)
  • modelopt/onnx/quantization/__main__.py (2 hunks)
  • modelopt/onnx/quantization/fp8.py (2 hunks)
  • modelopt/onnx/quantization/int8.py (2 hunks)
  • modelopt/onnx/quantization/quantize.py (8 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/quantization/fp8.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/trt_utils.py (1)
  • interpret_trt_plugins_precision_flag (327-431)
modelopt/onnx/quantization/graph_utils.py (1)
  • cast_custom_ops (1282-1369)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (1)
modelopt/onnx/quantization/quantize.py (1)

429-431: Passing only keys drops per‑IO mapping (repeat).

Only op types are propagated to op_block_list, losing any per‑input/output FP32 mapping from interpret_trt_plugins_precision_flag. If convert_to_f16 cannot honor per‑IO exclusions, document the limitation or thread the full mapping and handle casts accordingly.

Please confirm whether convert_to_f16 supports:

  • op-level block listing only, or
  • per‑input/output casting around FP32 custom ops.

If it’s op-level only, consider documenting that :<precision> with mixed per‑IO FP32 is not supported for FP16/BF16 exclusion.

#!/bin/bash
# Inspect convert_to_f16 to verify whether it accepts/uses per-IO mapping or only op-type lists.
rg -nP -C3 'def\s+convert_to_f16\s*\('
rg -nP -C3 'op_block_list'

@gcunhase gcunhase changed the title [5455919] Insert cast nodes for 'FP32 required' custom ops [5455919] Insert cast nodes for 'FP32 required' ops Sep 25, 2025

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/quantize.py (1)

182-205: FP32-only custom ops still receive FP16 tensors.

We surface the custom_ops_to_cast["fp32"] mapping from _preprocess_onnx, but never apply it to insert the required Cast nodes. Only the FP16 branch calls cast_custom_ops; FP32 metadata is dropped after extending op_types_to_exclude_fp16. When convert_to_f16() later runs (Lines 276-285 of this file, and the corresponding call sites in int8.py / fp8.py), upstream nodes are rewritten to emit FP16 tensors, so FP32-only plugins keep seeing FP16 inputs/outputs, defeating the whole “FP32 required” guarantee.

Please plumb the full FP32 mapping into the cast-insertion step (e.g., extend cast_custom_ops to accept the target precision or add a dedicated helper) and invoke it here before returning. Until we actually add those Cast nodes, the new flag does not protect FP32-only plugins.

🧹 Nitpick comments (3)
modelopt/onnx/quantization/__main__.py (1)

287-287: Wire-through looks correct; consider no-op warning when unused

The parameter is correctly forwarded. Optional: log a warning if --op_types_to_exclude_fp16 is provided while --high_precision_dtype=fp32, since it will have no effect.
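
A possible shape for that warning; the argparse namespace and attribute names below are assumptions, not taken from __main__.py:

import warnings

def maybe_warn_noop_exclusions(args):
    # args: parsed CLI namespace (assumed). Warn when the exclusion flag cannot take effect.
    if args.op_types_to_exclude_fp16 and args.high_precision_dtype == "fp32":
        warnings.warn(
            "--op_types_to_exclude_fp16 has no effect when --high_precision_dtype=fp32."
        )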

modelopt/onnx/quantization/int8.py (1)

122-122: Signature extension: add brief docstring note

Parameter addition is fine. Consider documenting op_types_to_exclude_fp16 in the function docstring to clarify it expects op types (not node names).

modelopt/onnx/quantization/fp8.py (1)

171-171: Signature extension: add brief docstring note

Consider documenting op_types_to_exclude_fp16 in the function docstring, mirroring INT8.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aec1a03 and 9174343.

📒 Files selected for processing (4)
  • modelopt/onnx/quantization/__main__.py (2 hunks)
  • modelopt/onnx/quantization/fp8.py (2 hunks)
  • modelopt/onnx/quantization/int8.py (2 hunks)
  • modelopt/onnx/quantization/quantize.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/trt_utils.py (1)
  • interpret_trt_plugins_precision_flag (327-431)
modelopt/onnx/quantization/graph_utils.py (1)
  • cast_custom_ops (1282-1369)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (5)
modelopt/onnx/quantization/__main__.py (1)

104-113: CLI flag LGTM

New --op_types_to_exclude_fp16 argument is clear and consistent with the rest of the CLI.

modelopt/onnx/quantization/int8.py (2)

280-286: None‑guard on op_block_list resolved

Passing op_block_list=op_types_to_exclude_fp16 or [] avoids None pitfalls. Good fix.


280-286: convert_to_f16 semantics verified
Supports low_precision_type="bf16", op_block_list matches node.op_type, and PrecisionConverter inserts both up- and down-casts around block-listed ops to preserve FP32 execution.
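
Since op_block_list matches on op type rather than node name, a small generic illustration of what that implies (node names and op types below are hypothetical):

nodes = [("conv1", "Conv"), ("plugin0", "CustomOpA"), ("norm0", "LayerNorm")]
op_block_list = ["CustomOpA"]

kept_in_fp32 = [name for name, op_type in nodes if op_type in op_block_list]
print(kept_in_fp32)  # ['plugin0'] -- every node of a block-listed op type stays in FP32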

modelopt/onnx/quantization/fp8.py (2)

318-325: None‑guard on op_block_list resolved

op_block_list=op_types_to_exclude_fp16 or [] is appropriate.


318-325: No action needed: ‘Resize’ is already blocked by default. The PrecisionConverter’s OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION includes “Resize”, so convert_to_f16 forces Resize to remain in high precision without any user‐specified op_block_list.

Likely an incorrect or invalid review comment.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (2)
modelopt/onnx/quantization/quantize.py (2)

84-84: Return typing could be more precise (and override_shapes typing mismatch).

Please tighten the return type to reflect the nested dict shapes, and accept None for override_shapes to avoid the type: ignore at the call site.

Apply:

-    override_shapes: str,
+    override_shapes: str | None,
@@
-) -> tuple[str, onnx.ModelProto, list[str], bool, bool, bool, dict, dict]:
+) -> tuple[
+    str,
+    onnx.ModelProto,
+    list[str],
+    bool,
+    bool,
+    bool,
+    dict[str, dict[str, list[int]]],  # custom_ops_to_cast["fp32"]
+    dict[str, dict[str, list[int]]],  # custom_ops_to_quantize
+]:

452-453: Validate op_types_to_exclude_fp16 spelling as well.

Typos in op_types_to_exclude_fp16 won’t be caught today.

Apply:

 validate_op_types_spelling(onnx_path, op_types_to_quantize, op_types_to_exclude)
+validate_op_types_spelling(onnx_path, None, op_types_to_exclude_fp16)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9174343 and ff0617e.

📒 Files selected for processing (1)
  • modelopt/onnx/quantization/quantize.py (8 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/trt_utils.py (1)
  • interpret_trt_plugins_precision_flag (327-431)
modelopt/onnx/quantization/graph_utils.py (1)
  • cast_custom_ops (1282-1369)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (8)
modelopt/onnx/quantization/quantize.py (8)

189-191: Casting FP16 for custom ops before quantization looks fine.

The early FP16 cast via cast_custom_ops for custom_ops_to_cast["fp16"] is consistent with the precision spec.


203-205: Threading FP32 mapping forward is good, but see selective-cast handling below.

Returning custom_ops_to_cast["fp32"] enables later steps to respect FP32-only custom ops. However, ensure you don’t lose the per‑I/O index info downstream (see comment on Lines 430‑433).


219-220: New parameter accepted.

Adding op_types_to_exclude_fp16 to quantize is consistent with the PR goal.


271-274: Docstring addition LGTM.

Clear description; relevant when high_precision_dtype != 'fp32'.


415-417: Good: capture FP32-cast mapping from preprocessing.

Variable naming (custom_ops_to_cast_fp32) matches intent and improves readability.


434-439: Helpful warning when high_precision_dtype == fp32.

Good user guidance; avoids silent no‑ops.


430-433: Don’t drop per‑I/O FP32 mapping by taking only keys.

Converting custom_ops_to_cast_fp32 to just op_type keys loses "inp"/"out" index information, preventing selective casting and forcing whole‑op exclusions from FP16. This restricts fidelity of the TRT precision spec (e.g., op_type:[fp32,fp16]:[...] cases).

Consider preserving the full dict and threading it to the conversion step that inserts casts (or augment convert_to_f16 to accept this mapping and selectively insert casts around only the specified I/O indices). If full plumb‑through is not immediately feasible, at least keep both:

  • a coarse op-type list for exclusion, and
  • the detailed mapping for future selective cast insertion.

I can help sketch the API changes if you want to support selective per‑I/O casting end‑to‑end.


481-495: op_types_to_exclude_fp16 is correctly accepted and forwarded
Both quantize functions in int8.py and fp8.py define the op_types_to_exclude_fp16 parameter and pass it as op_block_list to convert_to_f16.

@gcunhase gcunhase force-pushed the dev/gcunhasergio/fp32_cast_custom_ops_5455919 branch from 31d2879 to e75c1ca on September 25, 2025 20:03

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
modelopt/onnx/quantization/int8.py (2)

157-173: Fix None dereference when extending nodes_to_exclude.

nodes_to_exclude may be None (function arg default) before calling extend, causing a runtime error.

Apply this diff:

@@
     if enable_gemv_detection_for_trt:
@@
-        nodes_to_exclude.extend(matmul_nodes_to_exclude)  # type: ignore[union-attr]
+        nodes_to_exclude = nodes_to_exclude or []
+        nodes_to_exclude.extend(matmul_nodes_to_exclude)

69-73: Guard nodes_to_exclude in _find_nodes_to_quantize.

nodes_to_exclude has a default of None; adding it directly will crash on None.

Apply this diff:

-    partitioned_nodes = set(sum(non_quantizable_hard_coded_partitions, []) + nodes_to_exclude)  # noqa: RUF017
+    partitioned_nodes = set(
+        sum(non_quantizable_hard_coded_partitions, []) + (nodes_to_exclude or [])
+    )  # noqa: RUF017
modelopt/onnx/quantization/fp8.py (1)

215-231: Fix None dereference when extending nodes_to_exclude.

nodes_to_exclude can be None before extend, causing a crash.

Apply this diff:

@@
-        matmul_nodes_to_exclude = find_nodes_from_matmul_to_exclude(
+        matmul_nodes_to_exclude = find_nodes_from_matmul_to_exclude(
             onnx_path,
             use_external_data_format,
             intermediate_generated_files,
             calibration_data_reader,
             calibration_eps,
             calibration_shapes,
         )
-        nodes_to_exclude.extend(matmul_nodes_to_exclude)  # type: ignore[union-attr]
+        nodes_to_exclude = nodes_to_exclude or []
+        nodes_to_exclude.extend(matmul_nodes_to_exclude)
🧹 Nitpick comments (3)
CHANGELOG.rst (1)

4-12: Clarify CLI flags in changelog entry.

Consider explicitly naming the CLI flags for discoverability, e.g., mention --op_types_to_exclude_fp16 and --trt_plugins_precision op:fp32. Also note applicability to BF16 (since high_precision_dtype can be 'bf16') to avoid confusion.

modelopt/onnx/trt_utils.py (1)

368-370: Return early on unsupported precision to avoid ambiguous state.

Log and continue to next item after detecting unsupported precision.

Apply this diff:

-            if precision not in supported_precisions:
-                logger.warning(f"Precision {precision} is not supported. Skipping.")
+            if precision not in supported_precisions:
+                logger.warning(f"Precision {precision} is not supported. Skipping.")
+                continue
modelopt/onnx/quantization/quantize.py (1)

430-439: Log and validate FP16/BF16 exclusion list.

  • Add a log for final exclusions to aid debugging.
  • Optionally validate spelling like other op-type lists.

Apply this diff:

@@
-    op_types_to_exclude_fp16 = list(
+    op_types_to_exclude_fp16 = list(
         dict.fromkeys((op_types_to_exclude_fp16 or []) + list(custom_ops_to_cast_fp32.keys()))
     )
     if high_precision_dtype == "fp32" and op_types_to_exclude_fp16:
         logger.warning(
             "Nodes were detected for exclusion from FP16/BF16 conversion, but 'high_precision_dtype' is set to FP32. "
             "Since the model won't be converted to a lower precision, this flag is void."
         )
+    if op_types_to_exclude_fp16:
+        logger.info(f"Op types excluded from {high_precision_dtype.upper()} conversion: {op_types_to_exclude_fp16}")

Optionally also run validate_op_types_spelling for op_types_to_exclude_fp16 if supported by the helper.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between ff0617e and e75c1ca.

📒 Files selected for processing (6)
  • CHANGELOG.rst (1 hunks)
  • modelopt/onnx/quantization/__main__.py (2 hunks)
  • modelopt/onnx/quantization/fp8.py (2 hunks)
  • modelopt/onnx/quantization/int8.py (2 hunks)
  • modelopt/onnx/quantization/quantize.py (8 hunks)
  • modelopt/onnx/trt_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/quantization/__main__.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/trt_utils.py (1)
  • interpret_trt_plugins_precision_flag (327-431)
modelopt/onnx/quantization/graph_utils.py (1)
  • cast_custom_ops (1282-1369)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (11)
modelopt/onnx/quantization/int8.py (2)

122-122: Public API extension looks good.

Adding op_types_to_exclude_fp16 and threading it into the flow is consistent with the PR goal.


280-286: Correctly guards None and passes through to convert_to_f16.

Passing op_block_list=op_types_to_exclude_fp16 or [] prevents None issues. Low-risk change.

modelopt/onnx/quantization/quantize.py (4)

84-85: Updated _preprocess_onnx return signature looks consistent.

The added dicts for per-precision casts integrate with downstream usage.


189-195: Casts for fp16-only custom ops applied at preprocess.

Good gating on fp16 only; relies on trt_utils map to accumulate correctly (see fix there).

Ensure interpret_trt_plugins_precision_flag merges multiple ops per precision to avoid missing casts after the preprocess stage.


203-205: Returning fp32 cast mappings for exclusion is appropriate.

Feeding FP32-only ops into the FP16 exclusion logic aligns with the PR intent.


415-417: Threading custom_ops_to_cast_fp32 to quantize is correct.

Downstream usage derives op_types_to_exclude_fp16 from its keys.

modelopt/onnx/quantization/fp8.py (3)

171-171: Public API extension looks good.

Adding op_types_to_exclude_fp16 for FP16/BF16 conversion control is aligned with int8 path.


318-325: Correctly threads op_types_to_exclude_fp16 into convert_to_f16.

Guards None and passes low_precision_type appropriately.

Confirm whether 'Resize' previously required default exclusion here. If so, consider appending it when not explicitly provided.


310-317: Verify convert_fp16_io when targeting BF16.

convert_fp16_io runs even if high_precision_dtype == "bf16". Confirm it’s intended or switch to a BF16-aware variant to avoid mismatched IO dtypes.

modelopt/onnx/trt_utils.py (2)

370-376: Do not overwrite previous custom-op entries for the same precision.

custom_ops_to_cast[precision] = {op_type: ...} clobbers prior entries for that precision. Merge instead.

Apply this diff:

-            if precision in ["fp16", "fp32"]:
-                custom_ops_to_cast[precision] = {
-                    op_type: {
-                        "inp": list(range(num_inps)),
-                        "out": list(range(num_outs)),
-                    }
-                }
+            if precision in ["fp16", "fp32"]:
+                ops_map = custom_ops_to_cast.setdefault(precision, {})
+                ops_map[op_type] = {
+                    "inp": list(range(num_inps)),
+                    "out": list(range(num_outs)),
+                }

413-420: Preserve output-only casts and avoid map overwrite.

  • The current condition ignores output-only requests (when no inputs to cast).
  • Also overwrites the precision map each time.

Handle both inputs and outputs and merge into the existing map.

Apply this diff:

-            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
-            for precision in ["fp16", "fp32"]:
-                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
-                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
-                if inp_precision_cast:
-                    custom_ops_to_cast[precision] = {
-                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
-                    }
+            # Will cast requested inputs to FP16/FP32 and outputs back to FP32
+            for precision in ["fp16", "fp32"]:
+                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
+                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
+                if inp_precision_cast or out_precision_cast:
+                    ops_map = custom_ops_to_cast.setdefault(precision, {})
+                    ops_map[op_type] = {
+                        "inp": inp_precision_cast,
+                        "out": out_precision_cast,
+                    }

@gcunhase gcunhase force-pushed the dev/gcunhasergio/fp32_cast_custom_ops_5455919 branch from e75c1ca to fcaa32f on September 25, 2025 20:44

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/quantization/quantize.py (1)

219-274: Docstring for trt_plugins_precision is outdated vs implementation.

Parser supports per-IO syntax and int8/fp8, but the doc here only mentions fp16/fp32 and single-value format.

Update to match __main__.py help text, including:

  • Single value: <op_type>:<precision>, where precision ∈ {fp32, fp16, int8, fp8}
  • Per-IO: <op_type>:[p_in1,p_in2,...]:[p_out1,p_out2,...]
  • Note that int8/fp8 must match quantize_mode.
🧹 Nitpick comments (3)
CHANGELOG.rst (1)

4-12: Changelog entry LGTM; consider finalizing release date.

Content matches the feature. Replace 2025-10-xx with the release date before tagging.

modelopt/onnx/trt_utils.py (2)

390-403: Be robust to whitespace in bracketed I/O precisions.

Split values can contain spaces, causing false “unsupported precision” errors. Strip each token.

Apply:

-            inp_precision = inp_precision.strip("[]").split(",")
-            out_precision = out_precision.strip("[]").split(",")
+            inp_precision = [p.strip() for p in inp_precision.strip("[]").split(",")]
+            out_precision = [p.strip() for p in out_precision.strip("[]").split(",")]

404-412: Normalize int8/fp8 mismatches to quantize_mode in per-IO path.

You warn but don’t normalize inp/out precisions; indices may target the wrong mode. Normalize where requested.

Outside this hunk, update after the warning to map any “int8/fp8” tokens to quantize_mode in both inp_precision and out_precision so subsequent casts/quantize index extraction aligns with the chosen mode.
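
A minimal sketch of the suggested normalization; placement and surrounding variable names would follow the existing code in trt_utils.py:

# Map any int8/fp8 tokens to the active quantize_mode so the later index extraction
# lines up with the chosen quantization mode.
inp_precision = [quantize_mode if p in ("int8", "fp8") else p for p in inp_precision]
out_precision = [quantize_mode if p in ("int8", "fp8") else p for p in out_precision]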

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e75c1ca and fcaa32f.

📒 Files selected for processing (6)
  • CHANGELOG.rst (1 hunks)
  • modelopt/onnx/quantization/__main__.py (2 hunks)
  • modelopt/onnx/quantization/fp8.py (2 hunks)
  • modelopt/onnx/quantization/int8.py (2 hunks)
  • modelopt/onnx/quantization/quantize.py (8 hunks)
  • modelopt/onnx/trt_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/onnx/quantization/int8.py
  • modelopt/onnx/quantization/fp8.py
🧰 Additional context used
🧬 Code graph analysis (1)
modelopt/onnx/quantization/quantize.py (2)
modelopt/onnx/trt_utils.py (1)
  • interpret_trt_plugins_precision_flag (327-431)
modelopt/onnx/quantization/graph_utils.py (1)
  • cast_custom_ops (1282-1369)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (10)
modelopt/onnx/quantization/__main__.py (2)

104-113: Good addition: CLI flag to exclude FP16/BF16 conversion.

Flag name, nargs, default, and help text are appropriate and consistent with the intended behavior.


287-287: Correctly forwards new arg to quantize().

Wires the CLI flag into the workflow as expected.

modelopt/onnx/quantization/quantize.py (6)

183-195: Casts only when FP16 mapping present — OK.

Preprocess inserts FP16 casts and defers FP32 handling to downstream FP16/BF16 conversion via exclusion. Matches intent.

Please confirm convert_to_f16 in fp8/int8 paths respects op_types_to_exclude_fp16 by inserting FP32 boundary casts as needed.


203-205: Returning FP32 cast map for exclusion usage — OK.

Returning the fp32 map for exclusion downstream is consistent with later union logic.


415-417: Unpacking the new return values — OK.

Names reflect semantics and ordering from _preprocess_onnx.


430-439: Combining FP16-exclusion sources — OK.

Deduplicates and warns appropriately for fp32 high_precision_dtype.


481-481: Propagates op_types_to_exclude_fp16 to quantizers — OK.

Ensures FP16/BF16 conversion respects exclusions.


84-84: _preprocess_onnx call site updated. The call at modelopt/onnx/quantization/quantize.py:417 now unpacks all eight return values to match the updated signature.

modelopt/onnx/trt_utils.py (2)

370-376: Bug: Overwrites accumulated casts for a given precision.

Using custom_ops_to_cast[precision] = {...} clobbers previous ops for that precision. Merge instead.

Apply:

-            if precision in ["fp16", "fp32"]:
-                custom_ops_to_cast[precision] = {
-                    op_type: {
-                        "inp": list(range(num_inps)),
-                        "out": list(range(num_outs)),
-                    }
-                }
+            if precision in ["fp16", "fp32"]:
+                ops_map = custom_ops_to_cast.setdefault(precision, {})
+                ops_map[op_type] = {
+                    "inp": list(range(num_inps)),
+                    "out": list(range(num_outs)),
+                }

413-420: Two issues: drops output-only casts and still overwrites per-precision maps.

  • Only adds entries when inputs need casting; output-only requests are ignored.
  • Overwrites maps; multiple ops can’t accumulate under the same precision.

Fix both, and avoid reusing “precision” as loop var.

Apply:

-            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
-            for precision in ["fp16", "fp32"]:
-                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == precision]
-                out_precision_cast = [i for i, p in enumerate(out_precision) if p == precision]
-                if inp_precision_cast:
-                    custom_ops_to_cast[precision] = {
-                        op_type: {"inp": inp_precision_cast, "out": out_precision_cast}
-                    }
+            # Will cast the inputs to FP16/FP32 and the outputs back to FP32
+            for cast_precision in ["fp16", "fp32"]:
+                inp_precision_cast = [i for i, p in enumerate(inp_precision) if p == cast_precision]
+                out_precision_cast = [i for i, p in enumerate(out_precision) if p == cast_precision]
+                if inp_precision_cast or out_precision_cast:
+                    ops_map = custom_ops_to_cast.setdefault(cast_precision, {})
+                    ops_map[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}

@gcunhase gcunhase enabled auto-merge (squash) September 25, 2025 22:25
@gcunhase gcunhase force-pushed the dev/gcunhasergio/fp32_cast_custom_ops_5455919 branch from 03ee3fd to 890b12e on September 25, 2025 22:41
@gcunhase gcunhase merged commit 59a2675 into NVIDIA:main Sep 26, 2025
27 checks passed
kevalmorabia97 pushed a commit that referenced this pull request Sep 26, 2025
kevalmorabia97 pushed a commit that referenced this pull request Sep 26, 2025
yeyu-nvidia pushed a commit that referenced this pull request Oct 1, 2025