|
29 | 29 | import warnings |
30 | 30 | from onnx import helper, numpy_helper |
31 | 31 | from onnx import onnx_pb as onnx_proto |
| 32 | +import onnxslim.third_party.onnx_graphsurgeon as gs |
32 | 33 |
|
33 | 34 |
|
34 | 35 | FLOAT32 = 1 |
@@ -277,9 +278,14 @@ def convert_float_to_float16( |
277 | 278 | is_top_level = False # Going to process sub-graph |
278 | 279 | graph_stack = next_level |
279 | 280 |
|
280 | | - sort_topology(model.graph) |
281 | 281 | remove_unnecessary_cast_node(model.graph) |
282 | 282 |
|
| 283 | + # Topologically sort the graph |
| 284 | + # NOTE: We do not perform another round of optimization as the model is already optimized |
| 285 | + graph = gs.import_onnx(model) |
| 286 | + graph.toposort() |
| 287 | + model = gs.export_onnx(graph) |
| 288 | + |
283 | 289 | return model |
284 | 290 |
|
285 | 291 |
|
@@ -693,56 +699,6 @@ def convert_float_to_float16_model_path( |
693 | 699 | ) |
694 | 700 |
|
695 | 701 |
|
696 | | -def sort_graph_node(graph_proto): |
697 | | - # find the "first" node in Nodes that its input is not any node's output |
698 | | - def find_first_node(output2node_dict): |
699 | | - for node in org_nodes: |
700 | | - is_not_first_node = any(item in output2node_dict for item in node.input) |
701 | | - if not is_not_first_node: |
702 | | - return node |
703 | | - return None |
704 | | - |
705 | | - # remove the node from output2node_dict using output as key |
706 | | - def remove_first_node_from_dict2(first_node): |
707 | | - for output in first_node.output: |
708 | | - if output in output2node_dict: |
709 | | - del output2node_dict[output] |
710 | | - |
711 | | - org_nodes = graph_proto.node |
712 | | - # create a dict to store output as key and node as value |
713 | | - output2node_dict = {} |
714 | | - for node in org_nodes: |
715 | | - for output in node.output: |
716 | | - output2node_dict[output] = node |
717 | | - |
718 | | - # save the final node after sorted |
719 | | - sorted_node = [] |
720 | | - # traverse the Nodes to find the first node |
721 | | - while len(output2node_dict) > 0: |
722 | | - first_node = find_first_node(output2node_dict) |
723 | | - sorted_node.append(first_node) |
724 | | - remove_first_node_from_dict2(first_node) |
725 | | - # del node from original nodes list to avoid duplicate traverse |
726 | | - org_nodes.remove(first_node) |
727 | | - |
728 | | - for new_node in sorted_node: |
729 | | - graph_proto.node.extend([new_node]) |
730 | | - |
731 | | - |
732 | | -# The input graph should be mode.graph |
733 | | -# Recursively sort the topology for each sub-graph |
734 | | -def sort_topology(graph_proto): |
735 | | - assert isinstance(graph_proto, onnx_proto.GraphProto) |
736 | | - sort_graph_node(graph_proto) # sort global graph |
737 | | - for node in graph_proto.node: |
738 | | - for attr in node.attribute: |
739 | | - if isinstance(attr.g, onnx_proto.GraphProto) and len(attr.g.node) > 0: |
740 | | - sort_topology(attr.g) # sort sub-graph |
741 | | - for g in attr.graphs: |
742 | | - if isinstance(g, onnx_proto.GraphProto): |
743 | | - sort_topology(g) # sort sub-graph |
744 | | - |
745 | | - |
746 | 702 | def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto): |
747 | 703 | # 1. find all cast nodes in the graph |
748 | 704 | cast_node_list = [] |
|
0 commit comments