
Conversation

Contributor

@gcunhase gcunhase commented Sep 5, 2025

What does this PR do?

Type of change: Bug fix

Overview: Fixes an issue when quantizing constant inputs to custom ops. The function that removes the DQ on the input side and the Q on the output side of a custom op assumed there was always a node both before and after the Q/DQ pair when rewiring graph edges, which does not hold for constant inputs. This PR removes that assumption.
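
A minimal sketch of the fallback idea: when the tensor feeding the Q node has no producer (a constant initializer or graph input), the consumer is rewired to that tensor directly instead of to a producer's output. The helper name rewire_consumer is made up for illustration; tensor_producers mirrors the variable named in the review below. This is not the actual qdq_utils.py code:

# Illustrative sketch only; hypothetical helper, not the modelopt implementation.
from onnx import helper

def rewire_consumer(consumer, cons_idx, q_node, tensor_producers):
    """Point consumer.input[cons_idx] at the tensor upstream of the Q/DQ pair.

    Falls back to the Q node's own input tensor when that tensor has no
    producer (constant initializer or graph input) instead of raising KeyError.
    """
    q_node_prev = tensor_producers.get(q_node.input[0])  # None for constants/graph inputs
    consumer.input[cons_idx] = q_node_prev.output[0] if q_node_prev is not None else q_node.input[0]

# Toy usage: Q quantizes a constant "w", DQ feeds a custom op; the op ends up reading "w" directly.
q = helper.make_node("QuantizeLinear", ["w", "w_scale"], ["w_q"], name="q0")
dq = helper.make_node("DequantizeLinear", ["w_q", "w_scale"], ["w_dq"], name="dq0")
custom = helper.make_node("MyCustomOp", ["x", "w_dq"], ["y"], name="custom0")
rewire_consumer(custom, 1, q, tensor_producers={})  # "w" is a constant, so it has no producer entry
assert custom.input[1] == "w"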

Usage

Can be used with any model with custom ops.

Testing

  1. Follow the steps in https://github.com/NVIDIA/DL4AGX/tree/master/AV-Solutions/bevformer-int8-eq to export the ONNX model and compile the TensorRT plugin.
  2. Quantize model:
python -m modelopt.onnx.quantization \
  --onnx_path=bevformer_tiny_epoch_24_cp2_op13_simp.onnx \
  --calibration_eps trt cuda:0 cpu \
  --trt_plugins ./libs/libtensorrt_ops.so \
  --trt_plugins_precision MultiScaleDeformableAttnTRT2:[fp16,int64,fp16,fp16,fp16]:[fp16] \
  --high_precision_dtype fp16
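
For context, each --trt_plugins_precision entry lists one precision per plugin input and output; fp16/fp32 entries describe I/O to cast and int8/fp8 entries describe I/O to quantize. A rough, self-contained sketch of that mapping, based on the behavior exercised by the generated unit tests later in this thread (illustrative only — it handles just the explicit per-I/O form, not the library's interpret_trt_plugins_precision_flag):

# Hypothetical parser sketch, not the modelopt implementation.
def parse_precision_entry(entry: str):
    op_type, inp_str, out_str = entry.split(":")
    inps = [p.strip() for p in inp_str.strip("[]").split(",")]
    outs = [p.strip() for p in out_str.strip("[]").split(",")]
    cast = {"inp": [i for i, p in enumerate(inps) if p == "fp16"],
            "out": [i for i, p in enumerate(outs) if p in ("fp16", "fp32")]}
    quant = {"inp": [i for i, p in enumerate(inps) if p in ("int8", "fp8")],
             "out": [i for i, p in enumerate(outs) if p in ("int8", "fp8")]}
    return op_type, cast, quant

op, cast, quant = parse_precision_entry(
    "MultiScaleDeformableAttnTRT2:[fp16,int64,fp16,fp16,fp16]:[fp16]"
)
print(cast)   # {'inp': [0, 2, 3, 4], 'out': [0]} -> cast to fp16
print(quant)  # {'inp': [], 'out': []} -> nothing to quantize, so (after this PR) no entry is recorded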

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: N/A
  • Did you add or update any necessary documentation?: N/A
  • Did you update Changelog?: No

Additional Information

None

Summary by CodeRabbit

  • Bug Fixes

    • Improved robustness in quantization graph rewiring to handle missing upstream or downstream nodes without errors.
    • Added defensive fallbacks that preserve the original tensor connections when producers or consumers are absent, maintaining processing continuity.
    • Avoided creating empty precision entries for TensorRT plugin handling by only recording ops that actually have int8/fp8 input or output indices.
  • Chores

    • No changes to public APIs or interfaces.

@gcunhase gcunhase self-assigned this Sep 5, 2025
@gcunhase gcunhase requested a review from a team as a code owner September 5, 2025 18:28
@gcunhase gcunhase requested a review from ajrasane September 5, 2025 18:28

copy-pr-bot bot commented Sep 5, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Contributor

coderabbitai bot commented Sep 5, 2025

Walkthrough

Adds defensive lookups and branching when rewiring Q/DQ edges: .get is used on tensor_producers/tensor_consumers to avoid KeyError, falling back to the original tensors when producers or consumers are absent. Also registers TRT plugin op types for quantization only when int8/fp8 input or output indices exist.

Changes

  • Quantization Q/DQ rewiring robustness — modelopt/onnx/quantization/qdq_utils.py: Replace direct dict indexing with .get for tensor_producers / tensor_consumers; add conditional branches that update consumer inputs only if a previous producer or downstream DQ exists, otherwise preserve the original inputs or apply a local fallback; prevents KeyError when edges are missing.
  • TRT plugin precision filtering — modelopt/onnx/trt_utils.py: Only create custom_ops_to_quantize[op_type] when there is at least one int8 or fp8 input/output index; avoid adding entries with empty index lists (see the sketch below).
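
A small sketch of that guard, with hypothetical variable names (inp_precision_quant, out_precision_quant); not the actual trt_utils.py code:

# Hypothetical sketch of the "only record ops with something to quantize" guard.
custom_ops_to_quantize: dict[str, dict[str, list[int]]] = {}

inp_precision_quant = [1]  # example: input index 1 was requested as int8/fp8
out_precision_quant = []   # no quantized outputs requested

if inp_precision_quant or out_precision_quant:
    # Without this guard an empty {"inp": [], "out": []} entry could be created,
    # which downstream code would then have to special-case.
    custom_ops_to_quantize["MyCustomOp"] = {
        "inp": inp_precision_quant,
        "out": out_precision_quant,
    }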

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller
  participant Utils as remove_input_dq_and_output_q
  participant Prev as PrevProducer (optional)
  participant Q as Q Node
  participant Consumer as Consumer
  participant DQ as DQ Node
  participant DownDQ as Downstream DQ (optional)

  Caller->>Utils: invoke on Q node
  note right of Utils: Input-side handling
  alt Prev exists
    Utils->>Consumer: set Consumer.input[cons_idx] = Prev.output[0]
  else Prev missing
    Utils->>Consumer: keep Consumer.input[cons_idx] = Q.input[0]
  end

  note right of Utils: Output-side handling
  alt Downstream DQ exists
    Utils->>DownDQ: set DownDQ.input[...] = Producer.output[0]
  else Downstream DQ missing
    Utils->>DQ: set DQ.input[0] = Producer.output[0]
  end

  Utils-->>Caller: return

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Poem

I hop through graphs with careful paws,
I check each link and mend the flaws.
When producers vanish, I don't complain—
I patch the wire and keep the train.
A tidy hop, a quiet gain. 🐇✨


@gcunhase gcunhase force-pushed the dev/gcunhasergio/fix_quant_custom_op_5477976 branch from ef10a57 to e27a2b7 Compare September 5, 2025 18:29
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/onnx/quantization/qdq_utils.py (1)

793-795: Good null-safe fallback; consider pruning now-dead Q when deleting both Q/DQ.

Using tensor_producers.get(...) avoids KeyError and correctly falls back to the original tensor when there’s no upstream producer. When you hit the “delete both” path, you likely leave an orphan Q node if it only fed this DQ. You can opportunistically mark that Q for removal.

Apply within-range change plus supporting setup:

-                                q_node_prev = tensor_producers.get(q_node.input[0], None)
-                                consumer.input[cons_idx] = q_node_prev.output[0] if q_node_prev else q_node.input[0]
+                                q_node_prev = tensor_producers.get(q_node.input[0], None)
+                                consumer.input[cons_idx] = q_node_prev.output[0] if q_node_prev else q_node.input[0]
+                                # If Q had only this DQ consumer, schedule Q for deletion too
+                                if len(tensor_consumers.get(q_node.output[0], [])) == 1:
+                                    q_indices.append(q_index_map[q_node.name])

Outside this hunk, initialize once near the top of the function:

# Build fast lookup maps for later deletions
q_index_map = {n.name: idx for idx, n in q_nodes}
dq_index_map = {n.name: idx for idx, n in dq_nodes}
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between 1cf78b2 and e27a2b7.

📒 Files selected for processing (1)
  • modelopt/onnx/quantization/qdq_utils.py (2 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs

@gcunhase gcunhase force-pushed the dev/gcunhasergio/fix_quant_custom_op_5477976 branch 5 times, most recently from dfd8f6a to 505f715 Compare September 5, 2025 18:49
@gcunhase gcunhase changed the title from "Draft: [5477976] Fix: issue removing Q/DQ nodes around custom ops with constant inputs" to "[5477976] Fix: issue removing Q/DQ nodes around custom ops with constant inputs" Sep 5, 2025
Contributor

coderabbitai bot commented Sep 5, 2025

Note

Unit test generation is an Early Access feature. Expect some limitations and changes as we gather feedback and continue to improve it.


Generating unit tests... This may take up to 20 minutes.

Contributor

coderabbitai bot commented Sep 5, 2025

Here are the copyable unit test edits:

Copyable Edits

tests/unit/onnx/test_qdq_utils.py

@@ -16,6 +16,7 @@
 import numpy as np
 import pytest
 from onnx import TensorProto, helper, numpy_helper
+from onnx import checker
 
 from modelopt.onnx.quantization.qdq_utils import _cast_fp4, _cast_fp8, quantize_weights_to_int4
 
@@ -333,3 +334,129 @@
         assert result.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
         assert result.shape == expected_array.shape
         assert np.all(result == expected_array)
+
+# ---------------------------------------------------------------------------
+# Additional tests appended by CI to broaden coverage for qdq_utils
+# ---------------------------------------------------------------------------
+
+def create_minimal_model_dq_matmul_no_reshape_transpose():
+    """Create a model with DequantizeLinear feeding MatMul directly (no Reshape/Transpose)."""
+    # weight: int8, shape (K, N) = (8, 16)
+    w_data = np.random.randint(-8, 8, size=(8, 16), dtype=np.int8)
+    w = numpy_helper.from_array(w_data, "w_no_rt")
+    # scale per-row
+    s_data = np.random.uniform(0.1, 1.0, size=(8, 1)).astype(np.float32)
+    s = numpy_helper.from_array(s_data, "s_no_rt")
+
+    # input: (M, K) = (?, 8)
+    inp = helper.make_tensor_value_info("inp", TensorProto.FLOAT, [None, 8])
+
+    dq = helper.make_node("DequantizeLinear", ["w_no_rt", "s_no_rt"], ["dq_w"], name="dq_no_rt")
+    mm = helper.make_node("MatMul", ["inp", "dq_w"], ["out"], name="mm_no_rt")
+
+    graph = helper.make_graph(
+        [dq, mm],
+        "g_no_rt",
+        [inp],
+        [helper.make_tensor_value_info("out", TensorProto.FLOAT, [None, 16])],
+        initializer=[w, s],
+    )
+    return helper.make_model(graph)
+
+
+class TestQuantizeWeightsToInt4_Additional:
+    def test_model_checker_valid_after_quantization(self):
+        """Quantized model should pass ONNX checker validation."""
+        model = create_test_model_with_dq_reshape_transpose_matmul()
+        qmodel = quantize_weights_to_int4(model)
+        # Validate structural correctness
+        checker.check_model(qmodel)
+
+    def test_idempotency_of_quantization(self):
+        """Applying quantization twice should be stable (no additional structural changes)."""
+        model = create_test_model_with_dq_reshape_transpose_matmul()
+        q1 = quantize_weights_to_int4(model)
+        q2 = quantize_weights_to_int4(q1)
+
+        # Compare key invariants: node op types multiset and initializer (name, dtype) pairs
+        ops1 = sorted([n.op_type for n in q1.graph.node])
+        ops2 = sorted([n.op_type for n in q2.graph.node])
+        assert ops1 == ops2
+
+        inits1 = sorted([(i.name, i.data_type) for i in q1.graph.initializer])
+        inits2 = sorted([(i.name, i.data_type) for i in q2.graph.initializer])
+        assert inits1 == inits2
+
+        # MatMul should still consume DequantizeLinear output directly
+        mm1 = next(n for n in q1.graph.node if n.op_type == "MatMul")
+        dq1 = next(n for n in q1.graph.node if n.op_type == "DequantizeLinear")
+        mm2 = next(n for n in q2.graph.node if n.op_type == "MatMul")
+        dq2 = next(n for n in q2.graph.node if n.op_type == "DequantizeLinear")
+        assert mm1.input[1] == dq1.output[0]
+        assert mm2.input[1] == dq2.output[0]
+
+    def test_no_pattern_present_is_handled_gracefully(self):
+        """When Reshape/Transpose pattern is absent, quantization should still succeed and keep graph valid."""
+        model = create_minimal_model_dq_matmul_no_reshape_transpose()
+        qmodel = quantize_weights_to_int4(model)
+
+        # Still valid ONNX
+        checker.check_model(qmodel)
+
+        # Weight initializer should be INT4 after quantization
+        w_init = next(i for i in qmodel.graph.initializer if i.name == "w_no_rt")
+        assert w_init.data_type == TensorProto.INT4
+
+        # Graph should still contain DequantizeLinear and MatMul; no Reshape/Transpose should appear
+        node_types = [n.op_type for n in qmodel.graph.node]
+        assert "DequantizeLinear" in node_types
+        assert "MatMul" in node_types
+        assert "Reshape" not in node_types
+        assert "Transpose" not in node_types
+
+
+class TestCastFunctions_Additional:
+    def test_cast_fp8_empty_and_specials(self):
+        """_cast_fp8 should handle empty arrays and special values without error and with correct dtype/shape."""
+        # Empty
+        arr_empty = np.array([], dtype=np.float32)
+        out_empty = _cast_fp8(arr_empty)
+        assert out_empty.dtype == np.dtype((np.uint8, [("e4m3fn", "u1")]))
+        assert out_empty.shape == (0,)
+
+        # Specials: NaN and Inf should not crash
+        arr_specials = np.array([np.nan, np.inf, -np.inf, 0.0, -0.0], dtype=np.float32)
+        out_specials = _cast_fp8(arr_specials)
+        assert out_specials.dtype == np.dtype((np.uint8, [("e4m3fn", "u1")]))
+        assert out_specials.shape == arr_specials.shape
+
+    def test_cast_fp8_random_bulk_shape_and_bounds(self):
+        """_cast_fp8 should preserve shape and produce uint8 payloads."""
+        arr = np.random.randn(7, 3, 5).astype(np.float32) * 10.0
+        out = _cast_fp8(arr)
+        assert out.shape == arr.shape
+        assert out.dtype == np.dtype((np.uint8, [("e4m3fn", "u1")]))
+        # Values are uint8-coded; verify range implicitly via dtype and via max/min checks on the raw view
+        raw = out.view(np.uint8)
+        assert raw.min() >= 0 and raw.max() <= 255
+
+    def test_cast_fp4_empty_and_specials(self):
+        """_cast_fp4 should handle empty arrays and special values without error and with correct dtype/shape."""
+        arr_empty = np.array([], dtype=np.float32)
+        out_empty = _cast_fp4(arr_empty)
+        assert out_empty.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
+        assert out_empty.shape == (0,)
+
+        arr_specials = np.array([np.nan, np.inf, -np.inf, 0.0, -0.0], dtype=np.float32)
+        out_specials = _cast_fp4(arr_specials)
+        assert out_specials.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
+        assert out_specials.shape == arr_specials.shape
+
+    def test_cast_fp4_random_bulk_shape_and_bounds(self):
+        """_cast_fp4 should preserve shape and produce uint8 payloads."""
+        arr = (np.random.rand(4, 4, 4).astype(np.float32) - 0.5) * 6.0
+        out = _cast_fp4(arr)
+        assert out.shape == arr.shape
+        assert out.dtype == np.dtype((np.uint8, [("float4e2m1", "u1")]))
+        raw = out.view(np.uint8)
+        assert raw.min() >= 0 and raw.max() <= 255

tests/unit/onnx/test_trt_utils.py

This is a new file.

# Auto-generated unit tests for trt_utils

import sys
import types
import importlib
import builtins
import onnx
import onnx.helper as oh
import onnx_graphsurgeon as gs
import pytest

# Detect source module path
# Prefer modelopt.onnx.trt_utils; fallback to importing via tests path if needed.
def _import_trt_utils_with_stubbed_tensorrt(stub=None):
    if stub is None:
        # Minimal TensorRT stub with DataType enums required by infer_types_shapes
        stub = types.ModuleType("tensorrt")
        class _DT:
            pass
        # Use distinct singleton objects for mapping keys
        for name in ["float32","float16","bfloat16","int4","int8","uint8","int32","int64","bool","fp8","fp4"]:
            setattr(stub, name, object())
        stub.DataType = _DT
        # Minimal Logger enum placeholders to avoid accidental imports if functions call them
        class _Logger:
            WARNING = 2
        stub.Logger = _Logger
        class _NDCF:
            STRONGLY_TYPED = 0
        stub.NetworkDefinitionCreationFlag = _NDCF
    sys.modules["tensorrt"] = stub
    # Fresh import to bind the stubbed tensorrt
    try:
        mod = importlib.import_module("modelopt.onnx.trt_utils")
    except ModuleNotFoundError:
        # Fallback: try relative import if project layout differs
        mod = importlib.import_module("trt_utils")
    return mod

@pytest.fixture
def trt_utils_module(monkeypatch):
    # Ensure a clean tensorrt stub per test
    if "tensorrt" in sys.modules:
        del sys.modules["tensorrt"]
    return _import_trt_utils_with_stubbed_tensorrt()

def _simple_graph_with_custom_op(op_type="MyCustomOp", num_inps=2, num_outs=1):
    """
    Build a tiny ONNX model using GraphSurgeon where a node with op_type exists.
    Inputs/outputs are tensors without specific types/shapes; sufficient for traversal.
    """
    inputs = [gs.Variable(f"in{i}", dtype=None, shape=None) for i in range(num_inps)]
    outputs = [gs.Variable(f"out{i}", dtype=None, shape=None) for i in range(num_outs)]
    node = gs.Node(op=op_type, inputs=inputs, outputs=outputs, name=f"{op_type}_node")
    graph = gs.Graph(nodes=[node], inputs=inputs, outputs=outputs)
    model = gs.export_onnx(graph)
    return model, node

class TestInterpretTrtPluginsPrecisionFlag:
    def test_single_precision_fp16_casts_all_io(self, trt_utils_module):
        # Model has one custom op with 2 inputs and 3 outputs
        model, _ = _simple_graph_with_custom_op(op_type="MyOp", num_inps=2, num_outs=3)
        cast, quant = trt_utils_module.interpret_trt_plugins_precision_flag(
            model, ["MyOp:fp16"], quantize_mode="int8"
        )
        assert cast == {"MyOp": {"inp": [0,1], "out": [0,1,2]}}
        assert quant == {}

    def test_single_precision_int8_quantizes_all_io_respects_quantize_mode(self, trt_utils_module):
        # Two nodes of same type but different arities; function uses max across nodes
        model1, _ = _simple_graph_with_custom_op("MyQ", 1, 1)
        model2, _ = _simple_graph_with_custom_op("MyQ", 3, 2)
        # Merge graphs by concatenating nodes in a new graph
        g = gs.import_onnx(model1)
        g2 = gs.import_onnx(model2)
        graph = gs.Graph(nodes=g.nodes + g2.nodes, inputs=g.inputs + g2.inputs, outputs=g.outputs + g2.outputs)
        merged = gs.export_onnx(graph)

        cast, quant = trt_utils_module.interpret_trt_plugins_precision_flag(
            merged, ["MyQ:fp8"], quantize_mode="int8"
        )
        # Even though fp8 requested, it should align to quantize_mode=int8 and select max arity: 3 inps, 2 outs
        assert "MyQ" in quant and quant["MyQ"]["inp"] == [0,1,2] and quant["MyQ"]["out"] == [0,1]
        assert cast == {}

    def test_explicit_io_precisions_cast_and_quantize(self, trt_utils_module):
        # 2 inputs, 3 outputs
        model, _ = _simple_graph_with_custom_op("OpX", 2, 3)
        cast, quant = trt_utils_module.interpret_trt_plugins_precision_flag(
            model, ["OpX:[fp16,fp32]:[fp16,int8,fp32]"], quantize_mode="int8"
        )
        # Cast: inputs with fp16 -> [0]; outputs with fp16 or fp32 -> [0,2]
        assert cast["OpX"]["inp"] == [0]
        assert cast["OpX"]["out"] == [0,2]
        # Quantize: outputs with int8 -> [1]
        assert quant["OpX"]["inp"] == []
        assert quant["OpX"]["out"] == [1]

    def test_mismatched_counts_raise(self, trt_utils_module):
        model, _ = _simple_graph_with_custom_op("OpY", 2, 2)
        with pytest.raises(AssertionError):
            trt_utils_module.interpret_trt_plugins_precision_flag(
                model, ["OpY:[fp16]:[fp16,fp16]"], quantize_mode="int8"
            )
        with pytest.raises(AssertionError):
            trt_utils_module.interpret_trt_plugins_precision_flag(
                model, ["OpY:[fp16,fp16]:[fp16]"], quantize_mode="int8"
            )

    def test_unknown_op_type_is_skipped(self, trt_utils_module, caplog):
        model, _ = _simple_graph_with_custom_op("Existing", 1, 1)
        cast, quant = trt_utils_module.interpret_trt_plugins_precision_flag(
            model, ["MissingOp:fp16"], quantize_mode="int8"
        )
        assert cast == {}
        assert quant == {}

class TestSetTrtPluginDomain:
    def test_sets_domain_and_opset_once(self, trt_utils_module):
        # Graph with custom and vanilla nodes
        inputs = [gs.Variable("a"), gs.Variable("b")]
        out1 = gs.Variable("c")
        out2 = gs.Variable("d")
        custom = gs.Node(op="CustomTRT", inputs=[inputs[0]], outputs=[out1], name="custom1")
        add = gs.Node(op="Add", inputs=[out1, inputs[1]], outputs=[out2], name="add1")
        graph = gs.Graph(nodes=[custom, add], inputs=inputs, outputs=[out2])
        model = gs.export_onnx(graph)

        updated = trt_utils_module.set_trt_plugin_domain(model, custom_ops={"CustomTRT"})
        g = gs.import_onnx(updated)
        # Domain should be set on custom op only
        custom_nodes = [n for n in g.nodes if n.op == "CustomTRT"]
        assert len(custom_nodes) == 1 and custom_nodes[0].domain == "trt.plugins"
        add_nodes = [n for n in g.nodes if n.op == "Add"]
        assert add_nodes and (add_nodes[0].domain is None or add_nodes[0].domain == "")

        # TRT domain opset should be appended
        domains = {(opset.domain, opset.version) for opset in updated.opset_import}
        assert ("trt.plugins", 1) in domains

class TestInferTypesShapes:
    def test_sets_elem_type_and_shape_on_existing_value_info(self, trt_utils_module):
        # Create a simple model: X + Y -> Z; add Z to value_info without type/shape
        X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 3])
        Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 3])
        Z = oh.make_tensor_value_info("Z", 0, [])  # elem_type unset (0), shape empty
        node = oh.make_node("Add", inputs=["X", "Y"], outputs=["Z"], name="add_node")
        graph = oh.make_graph([node], "g", [X, Y], [oh.make_tensor_value_info("O", onnx.TensorProto.FLOAT, [1,3])])
        # Attach Z as value_info (not output)
        graph.value_info.extend([Z])
        model = oh.make_model(graph)

        # all_tensor_info provides dtype (TRT) and shape (with dynamic str dim)
        trt = sys.modules["tensorrt"]
        info = {"Z": {"dtype": getattr(trt, "float16"), "shape": [1, "unk", 3]}}

        updated = trt_utils_module.infer_types_shapes(model, info)

        # Fetch updated Z value_info
        def _get_vi(m, name):
            for vi in m.graph.value_info:
                if vi.name == name:
                    return vi
            return None

        z_vi = _get_vi(updated, "Z")
        assert z_vi is not None
        assert z_vi.type.tensor_type.elem_type == onnx.TensorProto.FLOAT16
        dims = z_vi.type.tensor_type.shape.dim
        assert [d.dim_value if d.HasField("dim_value") else d.dim_param for d in dims] == [1, "unk", 3]

    def test_appends_value_info_when_missing(self, trt_utils_module):
        # Node output W not present in value_info; expect function to append it.
        A = oh.make_tensor_value_info("A", onnx.TensorProto.FLOAT, [1, 2])
        B = oh.make_tensor_value_info("B", onnx.TensorProto.FLOAT, [1, 2])
        node = oh.make_node("Mul", inputs=["A", "B"], outputs=["W"], name="mul_node")
        graph = oh.make_graph([node], "g2", [A, B], [])
        model = oh.make_model(graph)

        trt = sys.modules["tensorrt"]
        info = {"W": {"dtype": getattr(trt, "int8"), "shape": ["unk", 2]}}

        updated = trt_utils_module.infer_types_shapes(model, info)

        names = [vi.name for vi in updated.graph.value_info]
        assert "W" in names
        # Ensure elem type is INT8
        vi = next(vi for vi in updated.graph.value_info if vi.name == "W")
        assert vi.type.tensor_type.elem_type == onnx.TensorProto.INT8
        dims = vi.type.tensor_type.shape.dim
        assert [d.dim_value if d.HasField("dim_value") else d.dim_param for d in dims] == ["unk", 2]

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/onnx/trt_utils.py (1)

294-297: Possible None dereference when appending to intermediate_generated_files

intermediate_generated_files is typed as Optional, but append is unconditional and will raise if None. Guard it.

-            intermediate_generated_files.append(static_shaped_onnx_path)  # type: ignore[union-attr]
+            if intermediate_generated_files is not None:
+                intermediate_generated_files.append(static_shaped_onnx_path)
🧹 Nitpick comments (1)
modelopt/onnx/trt_utils.py (1)

419-424: Mirror quantize guard for casts and clarify docstring

  • Only add custom_ops_to_cast[op_type] when inp_precision_cast or out_precision_cast is non-empty for symmetry with quantization.
  • Update interpret_trt_plugins_precision_flag docstring to state that op types with no quantizable indices are omitted from custom_ops_to_quantize.
  • Consumers already guard map usage via membership checks or iterate keys—no direct indexing found.
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between e27a2b7 and 6c5231a.

📒 Files selected for processing (2)
  • modelopt/onnx/quantization/qdq_utils.py (2 hunks)
  • modelopt/onnx/trt_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/quantization/qdq_utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/onnx/trt_utils.py (1)

419-424: Good guard: avoid no-op entries in custom_ops_to_quantize

Only adding the op when there is at least one int8/fp8 I/O index prevents empty keys and simplifies downstream handling. This aligns with the PR’s intent around edge cases.


codecov bot commented Sep 5, 2025

Codecov Report

❌ Patch coverage is 0% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.93%. Comparing base (0d279f1) to head (fabe5b8).
⚠️ Report is 2 commits behind head on main.

Files with missing lines                    Patch %   Lines
modelopt/onnx/quantization/qdq_utils.py     0.00%     6 Missing ⚠️
modelopt/onnx/trt_utils.py                  0.00%     2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #296      +/-   ##
==========================================
- Coverage   73.94%   73.93%   -0.01%     
==========================================
  Files         172      172              
  Lines       17405    17408       +3     
==========================================
+ Hits        12870    12871       +1     
- Misses       4535     4537       +2     

☔ View full report in Codecov by Sentry.

Contributor

@ajrasane ajrasane left a comment


LGTM. Please add/update any related tests.

@gcunhase gcunhase force-pushed the dev/gcunhasergio/fix_quant_custom_op_5477976 branch 2 times, most recently from d8089c5 to 032e470 Compare September 5, 2025 22:34
@gcunhase gcunhase force-pushed the dev/gcunhasergio/fix_quant_custom_op_5477976 branch from 032e470 to fabe5b8 Compare September 5, 2025 22:38
@gcunhase gcunhase enabled auto-merge (squash) September 5, 2025 22:48
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

🧹 Nitpick comments (5)
modelopt/onnx/trt_utils.py (5)

411-415: Mirror the guard for casts to avoid empty keys in custom_ops_to_cast.

Not functionally wrong, but staying consistent reduces downstream branching and noise.

Apply:

-            custom_ops_to_cast[op_type] = {"inp": inp_precision_cast, "out": out_precision_cast}
+            if inp_precision_cast or out_precision_cast:
+                custom_ops_to_cast[op_type] = {
+                    "inp": inp_precision_cast,
+                    "out": out_precision_cast,
+                }

376-383: Fix logging to show the originally requested precision.

precision is overwritten before logging, so the warning prints equal values.

Apply:

-            if precision in ["int8", "fp8"]:
-                if precision != quantize_mode:
-                    precision = quantize_mode
-                    logger.warning(
-                        f"Requested custom op precision ({precision}) is different than quantize mode: "
-                        f"{quantize_mode}. Mixed {precision}+{quantize_mode} precision is not yet supported. "
-                        f"Setting the custom op precision to be the same as quantize mode."
-                    )
+            if precision in ["int8", "fp8"]:
+                if precision != quantize_mode:
+                    requested_precision = precision
+                    precision = quantize_mode
+                    logger.warning(
+                        f"Requested custom op precision ({requested_precision}) is different than quantize mode: "
+                        f"{quantize_mode}. Mixed {requested_precision}+{quantize_mode} precision is not yet supported. "
+                        f"Setting the custom op precision to be the same as quantize mode."
+                    )

388-394: Trim whitespace when parsing per-IO precisions.

Avoids false “unsupported” warnings for inputs like [fp16, fp8 ].

Apply:

-            inp_precision, out_precision = precision.split(":")
-            inp_precision = inp_precision.strip("[]").split(",")
-            out_precision = out_precision.strip("[]").split(",")
+            raw_inp, raw_out = precision.split(":")
+            inp_precision = [p.strip() for p in raw_inp.strip("[]").split(",") if p.strip()]
+            out_precision = [p.strip() for p in raw_out.strip("[]").split(",") if p.strip()]

395-400: Prefer explicit exceptions over assert for user input validation.

Asserts can be stripped with optimization and yield generic errors.

Apply:

-            assert len(inp_precision) == num_inps, (
-                f"Number of inputs doesn't match expectation: {len(inp_precision)} vs {num_inps}."
-            )
-            assert len(out_precision) == num_outs, (
-                f"Number of outputs doesn't match expectation: {len(out_precision)} vs {num_outs}."
-            )
+            if len(inp_precision) != num_inps:
+                raise ValueError(
+                    f"Number of inputs doesn't match expectation: {len(inp_precision)} vs {num_inps} for op '{op_type}'."
+                )
+            if len(out_precision) != num_outs:
+                raise ValueError(
+                    f"Number of outputs doesn't match expectation: {len(out_precision)} vs {num_outs} for op '{op_type}'."
+                )

218-224: Avoid mutable default args.

trt_plugins=[] and all_tensor_info={} are mutable defaults.

Apply:

-def infer_types_shapes_tensorrt(
-    model: onnx.ModelProto,
-    trt_plugins: list[str] = [],
-    all_tensor_info: dict = {},
+def infer_types_shapes_tensorrt(
+    model: onnx.ModelProto,
+    trt_plugins: list[str] | None = None,
+    all_tensor_info: dict | None = None,
@@
-    if not all_tensor_info:
-        _, all_tensor_info = get_custom_layers(model, trt_plugins, strongly_typed)
+    if not all_tensor_info:
+        _, all_tensor_info = get_custom_layers(model, trt_plugins or [], strongly_typed)
📜 Review details

📥 Commits

Reviewing files that changed from the base of the PR and between d8089c5 and fabe5b8.

📒 Files selected for processing (2)
  • modelopt/onnx/quantization/qdq_utils.py (2 hunks)
  • modelopt/onnx/trt_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/onnx/quantization/qdq_utils.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: gpu-tests-pr
  • GitHub Check: partial-install (torch)
  • GitHub Check: multi-transformers (min)
  • GitHub Check: multi-torch (26)
  • GitHub Check: multi-torch (27)
  • GitHub Check: windows
🔇 Additional comments (2)
modelopt/onnx/trt_utils.py (2)

419-423: Good guard: only add quantize entries when non-empty.

Prevents empty {inp: [], out: []} entries in custom_ops_to_quantize and aligns behavior with intent. Nice fix.


419-423: The scripts above will print the interpreter function and examine how custom_ops_to_quantize is used downstream. Once you provide the output, I’ll confirm whether any direct indexing could raise KeyError or if all usages safely guard missing keys.

@gcunhase gcunhase merged commit d5c88e7 into NVIDIA:main Sep 5, 2025
22 checks passed