Skip to content

Commit 51845f4

Browse files
gcunhase committed
Fix: remove unnecessary Q/DQ nodes before cast_to_fp32 connecting to custom op
Signed-off-by: gcunhase <[email protected]>
1 parent f1ad1d7 commit 51845f4

File tree

1 file changed

+29
-14
lines changed

1 file changed

+29
-14
lines changed

modelopt/onnx/quantization/qdq_utils.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -872,22 +872,32 @@ def remove_input_dq_and_output_q(
872872
)
873873

874874
# Only remove DQs from the inputs of custom ops
875-
if consumers[0].op_type not in quantizable_custom_ops:
875+
has_cast = consumers[0].op_type == "Cast"
876+
consumers_2 = tensor_consumers[consumers[0].output[0]] if has_cast else consumers
877+
if consumers_2[0].op_type not in quantizable_custom_ops:
876878
continue
877879

878-
# Rewire graph to connect Q with the node after DQ (skip DQ)
879-
for consumer in consumers:
880-
for cons_idx, cons_inp in enumerate(consumer.input):
881-
if cons_inp == node.output[0]:
882-
# If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
883-
if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
884-
consumer.input[cons_idx] = q_node.output[0]
885-
else:
886-
q_node_prev = tensor_producers.get(q_node.input[0], None)
887-
consumer.input[cons_idx] = (
888-
q_node_prev.output[0] if q_node_prev else q_node.input[0]
889-
)
890-
break
880+
if has_cast:
881+
# Assume that this input tensor is not meant to be quantized as there's a Cast node between DQ
882+
# and the custom op. Keep the Cast node and delete both Q/DQ nodes.
883+
q_node_prev = tensor_producers.get(q_node.input[0], None)
884+
consumers[0].input[0] = (
885+
q_node_prev.output[0] if q_node_prev else q_node.input[0]
886+
)
887+
else:
888+
# Rewire graph to connect Q with the node after DQ (skip DQ)
889+
for consumer in consumers:
890+
for cons_idx, cons_inp in enumerate(consumer.input):
891+
if cons_inp == node.output[0]:
892+
# If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
893+
if cons_idx in quantizable_custom_ops[consumer.op_type]["inp"]:
894+
consumer.input[cons_idx] = q_node.output[0]
895+
else:
896+
q_node_prev = tensor_producers.get(q_node.input[0], None)
897+
consumer.input[cons_idx] = (
898+
q_node_prev.output[0] if q_node_prev else q_node.input[0]
899+
)
900+
break
891901

892902
# Track DequantizeLinear node indices for cleanup
893903
dq_indices.append(node_idx)
@@ -944,6 +954,11 @@ def remove_input_dq_and_output_q(
944954
f" {len(dq_indices)} DQ node{'' if len(dq_indices) == 1 else 's'}"
945955
)
946956

957+
# Cleanup graph to remove any dangling Q/DQ nodes
958+
graph = gs.import_onnx(onnx_model)
959+
graph.cleanup()
960+
onnx_model = gs.export_onnx(graph)
961+
947962
# TODO: remove manual ir_version change once ORT supports ir_version 11
948963
onnx_model.ir_version = 10
949964

0 commit comments

Comments (0)