109 changes: 36 additions & 73 deletions scripts/float16.py
@@ -29,6 +29,7 @@
import warnings
from onnx import helper, numpy_helper
from onnx import onnx_pb as onnx_proto
import onnxslim.third_party.onnx_graphsurgeon as gs


FLOAT32 = 1
@@ -179,8 +180,12 @@ def make_value_info_from_tensor(tensor):
"Max",
"Upsample",
# NEW:
"Cast",
"RandomNormalLike",
# TODO: Ideally, "Cast" nodes should not be here, for the following reasons:
# - It breaks the semantics of the default list, which should contain only "ops that are not supported for float16 in ONNX Runtime".
# - When fp32 casts already exist in the model (e.g., for rotary embeddings), this script will insert redundant casts around them.
# However, without "Cast" here, the graphs produced are invalid. Eventually, we will resolve this.
"Cast",
]
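For context, ops on this block list are kept in fp32: insert_cast32_before_node and insert_cast16_after_node (further down in this file) wrap each blocked node with a Cast back to fp32 on its inputs and a Cast to fp16 on its outputs. A minimal sketch of the resulting pattern, built with onnx.helper and illustrative tensor/node names (not the script's real naming scheme):

    from onnx import helper, TensorProto

    # Illustrative names only; the real script derives them from the graph.
    cast_in = helper.make_node(
        "Cast", inputs=["x_fp16"], outputs=["x_fp32"],
        name="x_fp16_cast_to_fp32", to=TensorProto.FLOAT,  # FLOAT == 1
    )
    blocked = helper.make_node(
        "RandomNormalLike", inputs=["x_fp32"], outputs=["y_fp32"],
        name="rand_like_0",
    )
    cast_out = helper.make_node(
        "Cast", inputs=["y_fp32"], outputs=["y_fp16"],
        name="y_fp32_cast_to_fp16", to=TensorProto.FLOAT16,  # FLOAT16 == 10
    )

This also shows why keeping "Cast" on the list is awkward: a pre-existing fp32 Cast gets wrapped in two more casts, which remove_unnecessary_cast_node then has to clean up.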


@@ -277,9 +282,14 @@ def convert_float_to_float16(
is_top_level = False # Going to process sub-graph
graph_stack = next_level

sort_topology(model.graph)
remove_unnecessary_cast_node(model.graph)

# Topologically sort the graph
# NOTE: We do not perform another round of optimization as the model is already optimized
graph = gs.import_onnx(model)
graph.toposort()
model = gs.export_onnx(graph)

return model
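This graphsurgeon round-trip replaces the hand-rolled sort_graph_node/sort_topology pass deleted further down. A minimal sketch of the idea, assuming the copy vendored in onnxslim exposes the standard onnx-graphsurgeon API:

    import onnx
    import onnxslim.third_party.onnx_graphsurgeon as gs

    def toposort_model(model: onnx.ModelProto) -> onnx.ModelProto:
        # Cast nodes inserted by this script are appended to the end of
        # graph.node, which can leave the graph out of topological order.
        graph = gs.import_onnx(model)
        graph.toposort()  # reorders graph.node in place
        return gs.export_onnx(graph)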


@@ -311,21 +321,26 @@ def process_graph_input(
graph, graph_input.name
)
for d_node in downstream_nodes:
cast_node_name = graph_input.name + "_cast_to_" + d_node.name
cast_node_output_name = graph_input.name + "_cast_to_" + d_node.name
add_cast_node(
graph,
[graph_input.name],
[cast_node_output_name],
cast_node_name,
FLOAT16,
)
add_new_value_info(
graph,
graph_input,
cast_node_output_name,
onnx_proto.TensorProto.FLOAT16,
)
# More than one node may consume the model input, so we create only a
# single cast node and reuse it for every consumer.
cast_exists = graph_input.name in global_input_name_dict
if cast_exists:
cast_node_output_name = global_input_name_dict[graph_input.name]
else:
cast_node_output_name = graph_input.name + "_fp16"
add_cast_node(
graph,
[graph_input.name],
[cast_node_output_name],
cast_node_output_name, # Set node name same as output name
FLOAT16,
)
add_new_value_info(
graph,
graph_input,
cast_node_output_name,
onnx_proto.TensorProto.FLOAT16,
)
for i, input_name in enumerate(d_node.input):
if input_name == graph_input.name:
d_node.input[i] = (
@@ -414,8 +429,7 @@ def process_node_in_block_list(
def insert_cast32_before_node(
graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
):
for i in range(len(node.input)):
input_name = node.input[i]
for i, input_name in enumerate(node.input):
for value_info in itertools.chain(graph.value_info, graph.input):
if input_name == value_info.name:
if (
@@ -443,8 +457,7 @@ def insert_cast16_after_node(
def insert_cast16_after_node(
graph: onnx_proto.GraphProto, node: onnx_proto.NodeProto, global_input_name_dict
):
for i in range(len(node.output)):
output_name = node.output[i]
for i, output_name in enumerate(node.output):
for value_info in itertools.chain(graph.value_info, graph.output):
if output_name == value_info.name:
if (
@@ -693,56 +706,6 @@ def convert_float_to_float16_model_path(
)


def sort_graph_node(graph_proto):
# find the "first" node in Nodes, i.e., one whose inputs are not any node's output
def find_first_node(output2node_dict):
for node in org_nodes:
is_not_first_node = any(item in output2node_dict for item in node.input)
if not is_not_first_node:
return node
return None

# remove the node from output2node_dict using output as key
def remove_first_node_from_dict2(first_node):
for output in first_node.output:
if output in output2node_dict:
del output2node_dict[output]

org_nodes = graph_proto.node
# create a dict to store output as key and node as value
output2node_dict = {}
for node in org_nodes:
for output in node.output:
output2node_dict[output] = node

# save the final node after sorted
sorted_node = []
# traverse the Nodes to find the first node
while len(output2node_dict) > 0:
first_node = find_first_node(output2node_dict)
sorted_node.append(first_node)
remove_first_node_from_dict2(first_node)
# del node from original nodes list to avoid duplicate traverse
org_nodes.remove(first_node)

for new_node in sorted_node:
graph_proto.node.extend([new_node])


# The input graph should be model.graph
# Recursively sort the topology for each sub-graph
def sort_topology(graph_proto):
assert isinstance(graph_proto, onnx_proto.GraphProto)
sort_graph_node(graph_proto) # sort global graph
for node in graph_proto.node:
for attr in node.attribute:
if isinstance(attr.g, onnx_proto.GraphProto) and len(attr.g.node) > 0:
sort_topology(attr.g) # sort sub-graph
for g in attr.graphs:
if isinstance(g, onnx_proto.GraphProto):
sort_topology(g) # sort sub-graph


def remove_unnecessary_cast_node(graph_proto: onnx_proto.GraphProto):
# 1. find all cast nodes in the graph
cast_node_list = []
@@ -837,8 +800,8 @@ def get_type(name: str) -> Optional[int]:
else:
if (
downstream_node.op_type == "Cast"
and cast_node.attribute[0].i == 10
and downstream_node.attribute[0].i == 1
and cast_node.attribute[0].i == FLOAT16
and downstream_node.attribute[0].i == FLOAT32
and downstream_node in cast_node_list
and cast_node in cast_node_list
):
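Regarding the last hunk of this file: in onnx.TensorProto, FLOAT is 1 and FLOAT16 is 10, so the old integer literals were correct but opaque. A minimal sketch of the back-to-back cast pair that remove_unnecessary_cast_node collapses, with a hypothetical helper name (it mirrors the script's attribute[0].i access, which assumes "to" is the node's only attribute):

    from onnx import TensorProto

    assert TensorProto.FLOAT == 1 and TensorProto.FLOAT16 == 10

    def is_removable_cast_pair(cast_node, downstream_node) -> bool:
        # A Cast(to=FLOAT16) feeding directly into a Cast(to=FLOAT) is
        # typically introduced by the conversion itself and can be
        # dropped as a pair.
        return (
            cast_node.op_type == "Cast"
            and downstream_node.op_type == "Cast"
            and cast_node.attribute[0].i == TensorProto.FLOAT16
            and downstream_node.attribute[0].i == TensorProto.FLOAT
        )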
5 changes: 0 additions & 5 deletions scripts/quantize.py
@@ -14,7 +14,6 @@
from onnxruntime.quantization.registry import IntegerOpsRegistry
from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
from onnxruntime.quantization.matmul_bnb4_quantizer import MatMulBnb4Quantizer
import onnx_graphsurgeon as gs

from . import float16
from .utils import check_and_save_model
@@ -221,10 +220,6 @@ def quantize_fp16(
disable_shape_infer=disable_shape_infer,
op_block_list=blocked_ops,
)

graph = gs.import_onnx(model_fp16)
graph.toposort()
model_fp16 = gs.export_onnx(graph)
check_and_save_model(model_fp16, save_path)


5 changes: 2 additions & 3 deletions scripts/requirements.txt
@@ -1,7 +1,6 @@
transformers[torch]==4.48.3
transformers[torch]==4.49.0
onnxruntime==1.20.1
optimum@git+https://github.com/huggingface/optimum.git@ce533cf1a9e144d4040581947f301dc3f454b279
optimum@git+https://github.com/huggingface/optimum.git@b04feaea78cda58d79b8da67dca3fd0c4ab33435
onnx==1.17.0
tqdm==4.67.1
onnxslim==0.1.48
onnx-graphsurgeon==0.5.5