@@ -871,6 +871,78 @@ def remove_input_dq_and_output_q(
871871 return onnx_model
872872
873873
def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove Q nodes from the inputs of a quantized ONNX model.

    This supports generating quantized models with low-precision graph I/O.

    For every QuantizeLinear node whose data input (``input[0]``) is a graph
    input, the single DequantizeLinear consumer of the Q output is rewired to
    read the graph input directly, the graph input's element type is set to the
    data type of that DQ node's zero-point initializer, and the Q node is
    deleted. Q nodes whose data input is not a graph input are left untouched,
    even if their scale or zero-point happens to be a graph input.

    Args:
        onnx_model: ONNX model protobuf to convert. It is modified in place.

    Returns:
        The same ONNX model protobuf with only DQ in the inputs whenever possible.

    Raises:
        ValueError: If the input is not a ModelProto or its graph has no nodes.
        RuntimeError: If a selected Q node cannot be converted (its output does
            not have exactly one DequantizeLinear consumer, or the input
            precision cannot be read from the DQ zero-point initializer).
    """
    logger.info("Deleting Q nodes in the input of a quantized ONNX model.")
    if not isinstance(onnx_model, onnx.ModelProto):
        raise ValueError("Input must be an ONNX model protobuf")

    graph = onnx_model.graph
    if not graph.node:
        raise ValueError("Model graph is empty")

    initializers, _, tensor_consumers = _get_graph_metadata(graph)
    graph_inputs = {inp.name: inp for inp in graph.input}
    q_indices = []

    # Remove Q nodes in the graph inputs
    for node_idx, node in enumerate(graph.node):
        if node.op_type != "QuantizeLinear":
            continue

        # Only the quantized *data* input decides whether this Q sits on a graph
        # input; the scale / zero-point inputs being graph inputs is irrelevant.
        if not node.input or node.input[0] not in graph_inputs:
            continue

        data_input = node.input[0]
        for out_name in node.output:
            logger.debug(f"Processing QDQ node for output {out_name}")

            # Resolve and validate everything first so a failure never leaves
            # the graph half-rewired. Explicit raises are used instead of
            # assert so the checks survive `python -O`.
            try:
                # Each Q should only have one DQ consumer
                consumers = tensor_consumers[out_name]
                if len(consumers) != 1:
                    raise ValueError(f"Expected single consumer for {node.name}")

                dq_node = consumers[0]
                if dq_node.op_type != "DequantizeLinear":
                    raise ValueError(f"Expected DequantizeLinear consumer for {node.name}")

                # The zero-point is an optional third input and may be omitted
                # or given as an empty name; without it the precision is unknown.
                if len(dq_node.input) < 3 or not dq_node.input[2]:
                    raise ValueError(
                        f"DequantizeLinear node {dq_node.name} has no zero-point input "
                        "to infer the graph input precision from"
                    )
                zero_point = initializers[dq_node.input[2]]
            except Exception as e:
                # Broad catch kept on purpose: callers rely on RuntimeError for
                # any conversion failure. The cause is chained for debugging.
                raise RuntimeError(f"Failed to convert node {node.name}: {e!s}") from e

            # Rewire graph to connect the graph input directly to the DQ node
            dq_node.input[0] = data_input

            # Set the input precision to match the zero-point precision in the DQ node
            graph_inputs[data_input].type.tensor_type.elem_type = zero_point.data_type

            # Track QuantizeLinear node indices for cleanup
            q_indices.append(node_idx)

    # Remove processed nodes, highest index first so earlier indices stay valid
    for node_idx in sorted(q_indices, reverse=True):
        del graph.node[node_idx]

    logger.info(f"Removed {len(q_indices)} Q node{'' if len(q_indices) == 1 else 's'}")

    # TODO: remove manual ir_version change once ORT supports ir_version 11
    onnx_model.ir_version = 10

    return onnx_model
944+
945+
874946def _cast_initializer_to_dtype (
875947 node : onnx .NodeProto , dtype : str , initializer_map : dict [str , onnx .TensorProto ]
876948):
0 commit comments