Commit 9e7cd86

coderabbitai fixes

Signed-off-by: Riyad Islam <[email protected]>
1 parent 9123d26

4 files changed: +14 −15 lines

CHANGELOG.rst (2 additions, 0 deletions)

@@ -5,10 +5,12 @@ Model Optimizer Changelog (Linux)
 ^^^^^^^^^^^^^^^^^
 
 **Deprecations**
+- Deprecated the ``quantize_mode`` argument in ``examples/onnx_ptq/evaluate.py`` to support strong typing. Use ``engine_precision`` instead.
 
 **Bug Fixes**
 
 **New Features**
+- ``high_precision_dtype`` now defaults to fp16 in ONNX quantization, i.e. quantized output model weights are FP16 by default.
 
 0.35 (2025-09-04)
 ^^^^^^^^^^^^^^^^^
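For context, a minimal sketch of how the new default interacts with the Python API. The entry point and parameter names below are assumptions based on the public modelopt.onnx.quantization API, not part of this commit:

    from modelopt.onnx.quantization import quantize

    # With this change, high_precision_dtype defaults to fp16, so quantized
    # models keep FP16 weights unless the caller overrides it explicitly.
    quantize(
        onnx_path="model.onnx",
        quantize_mode="int8",
        high_precision_dtype="fp32",  # opt back into the old FP32 behavior
        output_path="model.quant.onnx",
    )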

modelopt/onnx/quantization/qdq_utils.py (5 additions, 3 deletions)

@@ -527,11 +527,13 @@ def _get_successive_consumers(
         raise ValueError(f"Invalid consumer for {node.name}")
 
     quantized_node = tensor_consumers.get(dq_node.output[0], [None])[0]
-    if quantized_node.op_type == "Cast":
-        quantized_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
-
     if not quantized_node:
         raise ValueError(f"No consumer found for {dq_node.name}")
+    if quantized_node.op_type == "Cast":
+        next_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
+        if not next_node:
+            raise ValueError(f"No consumer found after Cast for {quantized_node.name}")
+        quantized_node = next_node
 
     return dq_node, quantized_node
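The old code dereferenced quantized_node.op_type before the None check, so a DQ with no consumer crashed with an AttributeError instead of the intended ValueError; the fix reorders the null check ahead of the Cast hop and validates the post-Cast consumer too. A self-contained sketch of the traversal under the new ordering; the Node dataclass and tensor_consumers map are illustrative stand-ins, not the real ONNX graph structures:

    from dataclasses import dataclass, field

    @dataclass
    class Node:
        name: str
        op_type: str
        output: list[str] = field(default_factory=list)

    # Quantized pattern with an intervening Cast: DQ -> Cast -> Conv.
    dq = Node("dq", "DequantizeLinear", ["dq_out"])
    cast = Node("cast", "Cast", ["cast_out"])
    conv = Node("conv", "Conv", ["conv_out"])
    tensor_consumers = {"dq_out": [cast], "cast_out": [conv]}

    # Mirrors the patched logic: check for a missing consumer first, then
    # hop over the Cast only if something actually follows it.
    quantized_node = tensor_consumers.get(dq.output[0], [None])[0]
    if not quantized_node:
        raise ValueError(f"No consumer found for {dq.name}")
    if quantized_node.op_type == "Cast":
        next_node = tensor_consumers.get(quantized_node.output[0], [None])[0]
        if not next_node:
            raise ValueError(f"No consumer found after Cast for {quantized_node.name}")
        quantized_node = next_node
    assert quantized_node.op_type == "Conv"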

modelopt/torch/_deploy/_runtime/tensorrt/tensorrt_utils.py (1 addition, 6 deletions)

@@ -24,7 +24,7 @@
 from modelopt.onnx.utils import get_batch_size
 from modelopt.onnx.utils import get_input_names as get_onnx_input_names
 
-from .constants import TENSORRT_8_MAJOR_VERSION, TRTMode
+from .constants import TENSORRT_8_MAJOR_VERSION
 
 
 def is_trt8():
@@ -131,11 +131,6 @@ def get_output_shapes(
     return output_shapes
 
 
-def validate_precision(precision: str) -> bool:
-    """Returns whether an input precision is in supported set."""
-    return precision in [TRTMode.FLOAT32, TRTMode.FLOAT16, TRTMode.INT8]
-
-
 def calib_data_generator(onnx_bytes: bytes, input_tensors: list[np.ndarray]):
     """The calibration data generator that yields calibration feed_dict to tensorrt."""
     input_names = get_onnx_input_names(onnx.load_from_string(onnx_bytes))
tests/_test_utils/onnx_quantization/utils.py (6 additions, 6 deletions)

@@ -20,11 +20,11 @@ def _assert_nodes_are_quantized(nodes):
     for node in nodes:
         for inp_idx, inp in enumerate(node.inputs):
             if isinstance(inp, gs.Variable):
-                qnode = node
-                # After quantization, the quantized node can be casted
-                if qnode.i(inp_idx).op == "Cast":
-                    qnode = qnode.i(inp_idx)
-                assert qnode.i(inp_idx).op == "DequantizeLinear", (
-                    f"Input '{inp.name}' of node '{qnode.name}' is not quantized but should be!"
+                producer = node.i(inp_idx)
+                # Quantized path may include a Cast right after DQ
+                if producer and producer.op == "Cast":
+                    producer = producer.i(0)
+                assert producer and producer.op == "DequantizeLinear", (
+                    f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
                 )
     return True
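To see what the rewritten helper accepts, here is a hedged sketch that wires a DequantizeLinear -> Cast -> Conv chain with onnx_graphsurgeon and walks back from the consumer the same way the assertion does; the tensor names, shapes, and dtypes are made up for illustration:

    import numpy as np
    import onnx_graphsurgeon as gs

    # DQ -> Cast -> Conv: the "quantized input via Cast" pattern the helper accepts.
    x_q = gs.Constant("x_q", np.zeros((1, 3, 8, 8), dtype=np.int8))
    scale = gs.Constant("scale", np.array(1.0, dtype=np.float32))
    dq_out = gs.Variable("dq_out", dtype=np.float32)
    cast_out = gs.Variable("cast_out", dtype=np.float16)
    w = gs.Constant("w", np.zeros((4, 3, 3, 3), dtype=np.float16))
    y = gs.Variable("y", dtype=np.float16)

    dq = gs.Node("DequantizeLinear", name="dq", inputs=[x_q, scale], outputs=[dq_out])
    cast = gs.Node("Cast", name="cast", attrs={"to": 10}, inputs=[dq_out], outputs=[cast_out])
    conv = gs.Node("Conv", name="conv", inputs=[cast_out, w], outputs=[y])

    # Walk back from conv's activation input, skipping the optional Cast,
    # exactly as the updated assertion does.
    producer = conv.i(0)              # -> cast
    if producer and producer.op == "Cast":
        producer = producer.i(0)      # -> dq
    assert producer and producer.op == "DequantizeLinear"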
