diff --git a/scripts/float16.py b/scripts/float16.py index 815134b9e..0be170f33 100644 --- a/scripts/float16.py +++ b/scripts/float16.py @@ -29,6 +29,7 @@ import warnings from onnx import helper, numpy_helper from onnx import onnx_pb as onnx_proto +import onnxslim.third_party.onnx_graphsurgeon as gs FLOAT32 = 1 @@ -179,8 +180,12 @@ def make_value_info_from_tensor(tensor): "Max", "Upsample", # NEW: - "Cast", "RandomNormalLike", + # TODO: Ideally, "Cast" nodes should not be here, for the following reasons: + # - It breaks the semantics that the default list contains "ops that are not supported for float16 in ONNX Runtime". + # - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around it. + # However, without it, the graphs produced are invalid. Eventually, we will resolve this. + "Cast", ] @@ -277,9 +282,14 @@ def convert_float_to_float16( is_top_level = False # Going to process sub-graph graph_stack = next_level - sort_topology(model.graph) remove_unnecessary_cast_node(model.graph) + # Topologically sort the graph + # NOTE: We do not perform another round of optimization as the model is already optimized + graph = gs.import_onnx(model) + graph.toposort() + model = gs.export_onnx(graph) + return model @@ -311,21 +321,26 @@ def process_graph_input( graph, graph_input.name ) for d_node in downstream_nodes: - cast_node_name = graph_input.name + "_cast_to_" + d_node.name - cast_node_output_name = graph_input.name + "_cast_to_" + d_node.name - add_cast_node( - graph, - [graph_input.name], - [cast_node_output_name], - cast_node_name, - FLOAT16, - ) - add_new_value_info( - graph, - graph_input, - cast_node_output_name, - onnx_proto.TensorProto.FLOAT16, - ) + # More than one node may consume the model input, so we only create + # a single cast node, and then reuse this node when needed. + cast_exists = graph_input.name in global_input_name_dict + if cast_exists: + cast_node_output_name = global_input_name_dict[graph_input.name] + else: + cast_node_output_name = graph_input.name + "_fp16" + add_cast_node( + graph, + [graph_input.name], + [cast_node_output_name], + cast_node_output_name, # Set node name same as output name + FLOAT16, + ) + add_new_value_info( + graph, + graph_input, + cast_node_output_name, + onnx_proto.TensorProto.FLOAT16, + ) for i, input_name in enumerate(d_node.input): if input_name == graph_input.name: d_node.input[i] = ( @@ -414,8 +429,7 @@ def process_node_in_block_list( def insert_cast32_before_node( graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict ): - for i in range(len(node.input)): - input_name = node.input[i] + for i, input_name in enumerate(node.input): for value_info in itertools.chain(graph.value_info, graph.input): if input_name == value_info.name: if ( @@ -443,8 +457,7 @@ def insert_cast32_before_node( def insert_cast16_after_node( graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict ): - for i in range(len(node.output)): - output_name = node.output[i] + for i, output_name in enumerate(node.output): for value_info in itertools.chain(graph.value_info, graph.output): if output_name == value_info.name: if ( @@ -693,56 +706,6 @@ def convert_float_to_float16_model_path( ) -def sort_graph_node(graph_proto): - # find the "first" node in Nodes that its input is not any node's output - def find_first_node(output2node_dict): - for node in org_nodes: - is_not_first_node = any(item in output2node_dict for item in node.input) - if not is_not_first_node: - return node - return None - - # remove the node from output2node_dict using output as key - def remove_first_node_from_dict2(first_node): - for output in first_node.output: - if output in output2node_dict: - del output2node_dict[output] - - org_nodes = graph_proto.node - # create a dict to store output as key and node as value - output2node_dict = {} - for node in org_nodes: - for output in node.output: - output2node_dict[output] = node - - # save the final node after sorted - sorted_node = [] - # traverse the Nodes to find the first node - while len(output2node_dict) > 0: - first_node = find_first_node(output2node_dict) - sorted_node.append(first_node) - remove_first_node_from_dict2(first_node) - # del node from original nodes list to avoid duplicate traverse - org_nodes.remove(first_node) - - for new_node in sorted_node: - graph_proto.node.extend([new_node]) - - -# The input graph should be mode.graph -# Recursively sort the topology for each sub-graph -def sort_topology(graph_proto): - assert isinstance(graph_proto, onnx_proto.GraphProto) - sort_graph_node(graph_proto) # sort global graph - for node in graph_proto.node: - for attr in node.attribute: - if isinstance(attr.g, onnx_proto.GraphProto) and len(attr.g.node) > 0: - sort_topology(attr.g) # sort sub-graph - for g in attr.graphs: - if isinstance(g, onnx_proto.GraphProto): - sort_topology(g) # sort sub-graph - - def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto): # 1. find all cast nodes in the graph cast_node_list = [] @@ -837,8 +800,8 @@ def get_type(name: str) -> Optional[int]: else: if ( downstream_node.op_type == "Cast" - and cast_node.attribute[0].i == 10 - and downstream_node.attribute[0].i == 1 + and cast_node.attribute[0].i == FLOAT16 + and downstream_node.attribute[0].i == FLOAT32 and downstream_node in cast_node_list and cast_node in cast_node_list ): diff --git a/scripts/quantize.py b/scripts/quantize.py index 9d8aebb51..3f73f0916 100644 --- a/scripts/quantize.py +++ b/scripts/quantize.py @@ -14,7 +14,6 @@ from onnxruntime.quantization.registry import IntegerOpsRegistry from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer -import onnx_graphsurgeon as gs from . import float16 from .utils import check_and_save_model @@ -221,10 +220,6 @@ def quantize_fp16( disable_shape_infer=disable_shape_infer, op_block_list=blocked_ops, ) - - graph = gs.import_onnx(model_fp16) - graph.toposort() - model_fp16 = gs.export_onnx(graph) check_and_save_model(model_fp16, save_path) diff --git a/scripts/requirements.txt b/scripts/requirements.txt index 460275540..a848c225a 100644 --- a/scripts/requirements.txt +++ b/scripts/requirements.txt @@ -1,7 +1,6 @@ -transformers[torch]==4.48.3 +transformers[torch]==4.49.0 onnxruntime==1.20.1 -optimum@git+https://github.com/huggingface/optimum.git@ce533cf1a9e144d4040581947f301dc3f454b279 +optimum@git+https://github.com/huggingface/optimum.git@b04feaea78cda58d79b8da67dca3fd0c4ab33435 onnx==1.17.0 tqdm==4.67.1 onnxslim==0.1.48 -onnx-graphsurgeon==0.5.5