Commit f7a088e

Use non-buggy onnx-graphsurgeon via onnxslim for toposort
1 parent 3502ddb commit f7a088e

File tree: 3 files changed, +9 -59 lines


scripts/float16.py

Lines changed: 7 additions & 51 deletions
@@ -29,6 +29,7 @@
 import warnings
 from onnx import helper, numpy_helper
 from onnx import onnx_pb as onnx_proto
+import onnxslim.third_party.onnx_graphsurgeon as gs


 FLOAT32 = 1
@@ -277,9 +278,14 @@ def convert_float_to_float16(
         is_top_level = False  # Going to process sub-graph
         graph_stack = next_level

-    sort_topology(model.graph)
     remove_unnecessary_cast_node(model.graph)

+    # Topologically sort the graph
+    # NOTE: We do not perform another round of optimization as the model is already optimized
+    graph = gs.import_onnx(model)
+    graph.toposort()
+    model = gs.export_onnx(graph)
+
     return model


@@ -693,56 +699,6 @@ def convert_float_to_float16_model_path(
     )


-def sort_graph_node(graph_proto):
-    # find the "first" node in Nodes that its input is not any node's output
-    def find_first_node(output2node_dict):
-        for node in org_nodes:
-            is_not_first_node = any(item in output2node_dict for item in node.input)
-            if not is_not_first_node:
-                return node
-        return None
-
-    # remove the node from output2node_dict using output as key
-    def remove_first_node_from_dict2(first_node):
-        for output in first_node.output:
-            if output in output2node_dict:
-                del output2node_dict[output]
-
-    org_nodes = graph_proto.node
-    # create a dict to store output as key and node as value
-    output2node_dict = {}
-    for node in org_nodes:
-        for output in node.output:
-            output2node_dict[output] = node
-
-    # save the final node after sorted
-    sorted_node = []
-    # traverse the Nodes to find the first node
-    while len(output2node_dict) > 0:
-        first_node = find_first_node(output2node_dict)
-        sorted_node.append(first_node)
-        remove_first_node_from_dict2(first_node)
-        # del node from original nodes list to avoid duplicate traverse
-        org_nodes.remove(first_node)
-
-    for new_node in sorted_node:
-        graph_proto.node.extend([new_node])
-
-
-# The input graph should be mode.graph
-# Recursively sort the topology for each sub-graph
-def sort_topology(graph_proto):
-    assert isinstance(graph_proto, onnx_proto.GraphProto)
-    sort_graph_node(graph_proto)  # sort global graph
-    for node in graph_proto.node:
-        for attr in node.attribute:
-            if isinstance(attr.g, onnx_proto.GraphProto) and len(attr.g.node) > 0:
-                sort_topology(attr.g)  # sort sub-graph
-            for g in attr.graphs:
-                if isinstance(g, onnx_proto.GraphProto):
-                    sort_topology(g)  # sort sub-graph
-
-
 def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
     # 1. find all cast nodes in the graph
     cast_node_list = []
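The new code path delegates topological sorting to the onnx-graphsurgeon copy vendored inside onnxslim, replacing the hand-rolled sort_graph_node/sort_topology helpers deleted above. A minimal standalone sketch of that round-trip, not part of the commit; the file paths and the onnx.load/onnx.save calls are illustrative only:

import onnx
import onnxslim.third_party.onnx_graphsurgeon as gs

# Load an existing ONNX model (hypothetical path, for illustration).
model = onnx.load("model_fp16.onnx")

# Round-trip through graphsurgeon: import, reorder nodes so every producer
# precedes its consumers, then export back to an onnx ModelProto.
graph = gs.import_onnx(model)
graph.toposort()
model = gs.export_onnx(graph)

onnx.save(model, "model_fp16_sorted.onnx")

As the NOTE in the diff points out, no further optimization pass is run here; only the import/toposort/export round-trip is applied to the already-optimized model.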

scripts/quantize.py

Lines changed: 0 additions & 5 deletions
@@ -14,7 +14,6 @@
 from onnxruntime.quantization.registry import IntegerOpsRegistry
 from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
 from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
-import onnx_graphsurgeon as gs

 from . import float16
 from .utils import check_and_save_model
@@ -221,10 +220,6 @@ def quantize_fp16(
         disable_shape_infer=disable_shape_infer,
         op_block_list=blocked_ops,
     )
-
-    graph = gs.import_onnx(model_fp16)
-    graph.toposort()
-    model_fp16 = gs.export_onnx(graph)
     check_and_save_model(model_fp16, save_path)
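Because convert_float_to_float16 now returns an already-sorted model, the explicit graphsurgeon round-trip in quantize_fp16 becomes redundant and is removed. A rough sketch of the resulting flow; only the names visible in the diff are taken from the source, the input variable "model" and any omitted keyword arguments are assumptions:

# Sketch only: simplified quantize_fp16 body after this commit.
model_fp16 = float16.convert_float_to_float16(
    model,  # assumed name for the input ModelProto
    disable_shape_infer=disable_shape_infer,
    op_block_list=blocked_ops,
)
# The returned model is already topologically sorted, so it is saved directly.
check_and_save_model(model_fp16, save_path)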

scripts/requirements.txt

Lines changed: 2 additions & 3 deletions
@@ -1,7 +1,6 @@
-transformers[torch]==4.48.3
+transformers[torch]==4.49.0
 onnxruntime==1.20.1
-optimum@git+https://github.com/huggingface/optimum.git@ce533cf1a9e144d4040581947f301dc3f454b279
+optimum@git+https://github.com/huggingface/optimum.git@b04feaea78cda58d79b8da67dca3fd0c4ab33435
 onnx==1.17.0
 tqdm==4.67.1
 onnxslim==0.1.48
-onnx-graphsurgeon==0.5.5
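With the standalone onnx-graphsurgeon pin dropped, the toposort path must resolve through the copy vendored in onnxslim. A quick sanity check for the pinned environment, not part of the commit:

# Verify that onnxslim's vendored graphsurgeon exposes the functions used by
# scripts/float16.py, without the standalone onnx-graphsurgeon package installed.
import onnxslim.third_party.onnx_graphsurgeon as gs

assert callable(gs.import_onnx) and callable(gs.export_onnx)
print("onnxslim-vendored graphsurgeon is available")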
