44from distutils .version import LooseVersion
55
66try :
7- import onnx
87 from tf2onnx .tfonnx import process_tf_graph , tf_optimize
98 from tf2onnx import optimizer
109
@@ -126,16 +125,6 @@ def convert_frozen_to_onnx(
126125) -> Any :
127126 # This is basically https://github.com/onnx/tensorflow-onnx/blob/master/tf2onnx/convert.py
128127
129- # Some constants in the graph need to be read by the inference system.
130- # These aren't used by the model anywhere, so trying to make sure they propagate
131- # through conversion and import is a losing battle. Instead, save them now,
132- # so that we can add them back later.
133- constant_values = {}
134- for n in frozen_graph_def .node :
135- if n .name in MODEL_CONSTANTS :
136- val = n .attr ["value" ].tensor .int_val [0 ]
137- constant_values [n .name ] = val
138-
139128 inputs = _get_input_node_names (frozen_graph_def )
140129 outputs = _get_output_node_names (frozen_graph_def )
141130 logger .info (f"onnx export - inputs:{ inputs } outputs:{ outputs } " )
@@ -157,26 +146,9 @@ def convert_frozen_to_onnx(
157146 onnx_graph = optimizer .optimize_graph (g )
158147 model_proto = onnx_graph .make_model (settings .brain_name )
159148
160- # Save the constant values back the graph initializer.
161- # This will ensure the importer gets them as global constants.
162- constant_nodes = []
163- for k , v in constant_values .items ():
164- constant_node = _make_onnx_node_for_constant (k , v )
165- constant_nodes .append (constant_node )
166- model_proto .graph .initializer .extend (constant_nodes )
167149 return model_proto
168150
169151
170- def _make_onnx_node_for_constant (name : str , value : int ) -> Any :
171- tensor_value = onnx .TensorProto (
172- data_type = onnx .TensorProto .INT32 ,
173- name = name ,
174- int32_data = [value ],
175- dims = [1 , 1 , 1 , 1 ],
176- )
177- return tensor_value
178-
179-
180152def _get_input_node_names (frozen_graph_def : Any ) -> List [str ]:
181153 """
182154 Get the list of input node names from the graph.
@@ -201,10 +173,12 @@ def _get_input_node_names(frozen_graph_def: Any) -> List[str]:
201173def _get_output_node_names (frozen_graph_def : Any ) -> List [str ]:
202174 """
203175 Get the list of output node names from the graph.
176+ Also include constants, so that they will be readable by the
177+ onnx importer.
204178 Names are suffixed with ":0"
205179 """
206180 node_names = _get_frozen_graph_node_names (frozen_graph_def )
207- output_names = node_names & POSSIBLE_OUTPUT_NODES
181+ output_names = node_names & ( POSSIBLE_OUTPUT_NODES | MODEL_CONSTANTS )
208182 # Append the port
209183 return [f"{ n } :0" for n in output_names ]
210184
0 commit comments