@@ -872,22 +872,32 @@ def remove_input_dq_and_output_q(
872872 )
873873
874874 # Only remove DQs from the inputs of custom ops
875- if consumers [0 ].op_type not in quantizable_custom_ops :
875+ has_cast = consumers [0 ].op_type == "Cast"
876+ consumers_2 = tensor_consumers [consumers [0 ].output [0 ]] if has_cast else consumers
877+ if consumers_2 [0 ].op_type not in quantizable_custom_ops :
876878 continue
877879
878- # Rewire graph to connect Q with the node after DQ (skip DQ)
879- for consumer in consumers :
880- for cons_idx , cons_inp in enumerate (consumer .input ):
881- if cons_inp == node .output [0 ]:
882- # If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
883- if cons_idx in quantizable_custom_ops [consumer .op_type ]["inp" ]:
884- consumer .input [cons_idx ] = q_node .output [0 ]
885- else :
886- q_node_prev = tensor_producers .get (q_node .input [0 ], None )
887- consumer .input [cons_idx ] = (
888- q_node_prev .output [0 ] if q_node_prev else q_node .input [0 ]
889- )
890- break
880+ if has_cast :
881+ # Assume that this input tensor is not meant to be quantized as there's a Cast node between DQ
882+ # and the custom op. Keep the Cast node and delete both Q/DQ nodes.
883+ q_node_prev = tensor_producers .get (q_node .input [0 ], None )
884+ consumers [0 ].input [0 ] = (
885+ q_node_prev .output [0 ] if q_node_prev else q_node .input [0 ]
886+ )
887+ else :
888+ # Rewire graph to connect Q with the node after DQ (skip DQ)
889+ for consumer in consumers :
890+ for cons_idx , cons_inp in enumerate (consumer .input ):
891+ if cons_inp == node .output [0 ]:
892+ # If the input tensor is meant to be quantized, delete DQ. Otherwise, delete both Q/DQ.
893+ if cons_idx in quantizable_custom_ops [consumer .op_type ]["inp" ]:
894+ consumer .input [cons_idx ] = q_node .output [0 ]
895+ else :
896+ q_node_prev = tensor_producers .get (q_node .input [0 ], None )
897+ consumer .input [cons_idx ] = (
898+ q_node_prev .output [0 ] if q_node_prev else q_node .input [0 ]
899+ )
900+ break
891901
892902 # Track DequantizeLinear node indices for cleanup
893903 dq_indices .append (node_idx )
@@ -944,6 +954,11 @@ def remove_input_dq_and_output_q(
944954 f" { len (dq_indices )} DQ node{ '' if len (dq_indices ) == 1 else 's' } "
945955 )
946956
957+ # Cleanup graph to remove any dangling Q/DQ nodes
958+ graph = gs .import_onnx (onnx_model )
959+ graph .cleanup ()
960+ onnx_model = gs .export_onnx (graph )
961+
947962 # TODO: remove manual ir_version change once ORT supports ir_version 11
948963 onnx_model .ir_version = 10
949964
0 commit comments