Skip to content

Commit d5c88e7

Browse files
authored
[5477976] Fix: issue removing Q/DQ nodes around custom ops with constant inputs (#296)
Signed-off-by: gcunhase <[email protected]>
1 parent 67af656 commit d5c88e7

File tree

2 files changed

14 additions and 8 deletions

2 files changed

14 additions and 8 deletions

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -790,8 +790,10 @@ def remove_input_dq_and_output_q(
790790
if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
791791
consumer.input[cons_idx] = q_node.output[0]
792792
else:
793-
q_node_prev = tensor_producers[q_node.input[0]]
794-
consumer.input[cons_idx] = q_node_prev.output[0]
793+
q_node_prev = tensor_producers.get(q_node.input[0], None)
794+
consumer.input[cons_idx] = (
795+
q_node_prev.output[0] if q_node_prev else q_node.input[0]
796+
)
795797
break
796798

797799
# Track DequantizeLinear node indices for cleanup
@@ -828,8 +830,11 @@ def remove_input_dq_and_output_q(
828830
if quantizable_custom_ops[producer.op_type]["out"]:
829831
dq_node[0].input[0] = producer.output[0]
830832
else:
831-
dq_node_next = tensor_consumers[dq_node[0].output[0]]
832-
dq_node_next[0].input[0] = producer.output[0]
833+
dq_node_next = tensor_consumers.get(dq_node[0].output[0], None)
834+
if dq_node_next:
835+
dq_node_next[0].input[0] = producer.output[0]
836+
else:
837+
dq_node[0].input[0] = producer.output[0]
833838

834839
# Track QuantizeLinear node indices for cleanup
835840
q_indices.append(node_idx)

modelopt/onnx/trt_utils.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -416,9 +416,10 @@ def interpret_trt_plugins_precision_flag(
416416
# Will add Q/DQ nodes in the requested I/O indices
417417
inp_precision_quant = [i for i, p in enumerate(inp_precision) if p in ["int8", "fp8"]]
418418
out_precision_quant = [i for i, p in enumerate(out_precision) if p in ["int8", "fp8"]]
419-
custom_ops_to_quantize[op_type] = {
420-
"inp": inp_precision_quant,
421-
"out": out_precision_quant,
422-
}
419+
if inp_precision_quant or out_precision_quant:
420+
custom_ops_to_quantize[op_type] = {
421+
"inp": inp_precision_quant,
422+
"out": out_precision_quant,
423+
}
423424

424425
return custom_ops_to_cast, custom_ops_to_quantize

0 commit comments

Comments
 (0)