@@ -871,6 +871,78 @@ def remove_input_dq_and_output_q(
871
871
return onnx_model
872
872
873
873
874
def remove_graph_input_q(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
    """Remove Q nodes from the inputs of a quantized ONNX model.

    This supports generating quantized models with low-precision graph I/O.
    For every QuantizeLinear node whose data input (``input[0]``) is a graph
    input, the single DequantizeLinear consumer is rewired to read the graph
    input directly, the graph input's element type is set to the data type of
    the DQ zero-point initializer, and the QuantizeLinear node is deleted.

    The model is modified in place and also returned.

    Args:
        onnx_model: ONNX model protobuf to convert

    Returns:
        ONNX model protobuf with only DQ in the inputs whenever possible.

    Raises:
        ValueError: If ``onnx_model`` is not a ModelProto or its graph has no nodes
        RuntimeError: If a Q node fed by a graph input cannot be bypassed (its
            output does not have exactly one DequantizeLinear consumer, or that
            consumer has no zero-point initializer to take the precision from)
    """
    logger.info("Deleting Q nodes in the input of a quantized ONNX model.")
    if not isinstance(onnx_model, onnx.ModelProto):
        raise ValueError("Input must be an ONNX model protobuf")

    graph = onnx_model.graph
    if not graph.node:
        raise ValueError("Model graph is empty")

    initializers, _, tensor_consumers = _get_graph_metadata(graph)
    graph_inputs = {inp.name: inp for inp in graph.input}
    q_indices = []

    # Remove Q nodes in the graph inputs
    for node_idx, node in enumerate(graph.node):
        if node.op_type != "QuantizeLinear":
            continue

        # Only the *data* input matters: a Q node whose scale/zero-point happens
        # to be a graph input, but whose data is produced inside the graph, is
        # not an input Q node and must be left untouched.
        if not node.input or node.input[0] not in graph_inputs:
            continue
        inp_name = node.input[0]

        try:
            for out_name in node.output:
                logger.debug(f"Processing QDQ node for output {out_name}")

                # Each Q should only have one DQ consumer. Explicit raises are
                # used instead of `assert` so the checks survive `python -O`.
                consumers = tensor_consumers[out_name]
                if len(consumers) != 1:
                    raise ValueError(f"Expected single consumer for {node.name}")
                dq_node = consumers[0]
                if dq_node.op_type != "DequantizeLinear":
                    raise ValueError(f"Expected DequantizeLinear consumer for {node.name}")

                # Validate everything before mutating so that a failure does not
                # leave the graph half-rewired.
                if len(dq_node.input) < 3 or dq_node.input[2] not in initializers:
                    raise ValueError(
                        f"DequantizeLinear consumer of {node.name} has no zero-point "
                        "initializer to infer the input precision from"
                    )
                zero_point = initializers[dq_node.input[2]]

                # Rewire graph so the DQ node reads the graph input directly
                dq_node.input[0] = inp_name

                # Set the input precision to match the zero-point precision in the DQ node
                graph_inputs[inp_name].type.tensor_type.elem_type = zero_point.data_type
        except Exception as e:
            raise RuntimeError(f"Failed to convert node {node.name}: {e!s}") from e

        # Track QuantizeLinear node indices for cleanup
        q_indices.append(node_idx)

    # Remove processed nodes, highest index first so earlier indices stay valid
    for node_idx in sorted(q_indices, reverse=True):
        del graph.node[node_idx]

    logger.info(f"Removed {len(q_indices)} Q node{'' if len(q_indices) == 1 else 's'}")

    # TODO: remove manual ir_version change once ORT supports ir_version 11
    onnx_model.ir_version = 10

    return onnx_model
944
+
945
+
874
946
def _cast_initializer_to_dtype (
875
947
node : onnx .NodeProto , dtype : str , initializer_map : dict [str , onnx .TensorProto ]
876
948
):
0 commit comments