21
21
# pylint: disable=unused-argument,unused-import,no-value-for-parameter,unexpected-keyword-arg,ungrouped-imports
22
22
# pylint: disable=missing-function-docstring,import-outside-toplevel,useless-import-alias,missing-docstring
23
23
24
+
24
25
def is_tf2 ():
25
26
return tf .__version__ .startswith ("2." )
26
27
27
28
29
+ def _not_implemented_tf_placeholder (name ):
30
+ """Creates a placeholder function for missing Tensorflow imports"""
31
+
32
+ def not_implemented_tf_placeholder (* args , ** kwargs ):
33
+ raise NotImplementedError (
34
+ f'Tensorflow verison { tf .__version__ } does not implement '
35
+ f'`{ name } `, try converting your model with a different version.'
36
+ )
37
+ return not_implemented_tf_placeholder
38
+
39
+
40
+ try :
41
+ from tensorflow .python .framework .function_def_to_graph import function_def_to_graph
42
+ except ImportError :
43
+ function_def_to_graph = _not_implemented_tf_placeholder ('function_def_to_graph' )
44
+
28
45
if is_tf2 ():
29
- from tensorflow .python .framework import convert_to_constants , func_graph , function_def_to_graph
46
+ convert_variables_to_constants = tf .compat .v1 .graph_util .convert_variables_to_constants
47
+ from tensorflow .python .framework .convert_to_constants import convert_variables_to_constants_v2
30
48
else :
31
49
from tensorflow .python .framework .graph_util import convert_variables_to_constants
50
+ convert_variables_to_constants_v2 = _not_implemented_tf_placeholder ('convert_variables_to_constants_v2' )
32
51
33
52
34
53
if is_tf2 ():
@@ -63,7 +82,7 @@ def is_tf2():
63
82
64
83
65
84
def from_function (func , input_names , output_names ):
66
- frozen_func = convert_to_constants . convert_variables_to_constants_v2 (func , lower_control_flow = False )
85
+ frozen_func = convert_variables_to_constants_v2 (func , lower_control_flow = False )
67
86
graph_def = frozen_func .graph .as_graph_def (add_shapes = True )
68
87
# output_tensors = {i.name: i for i in frozen_func.outputs}
69
88
tf_reset_default_graph ()
@@ -87,10 +106,7 @@ def freeze_session(sess, input_names=None, output_names=None):
87
106
graph_def = sess .graph .as_graph_def (add_shapes = True )
88
107
for node in graph_def .node :
89
108
node .device = ""
90
- if is_tf2 ():
91
- graph_def = tf .compat .v1 .graph_util .convert_variables_to_constants (sess , graph_def , output_node_names )
92
- else :
93
- graph_def = convert_variables_to_constants (sess , graph_def , output_node_names )
109
+ graph_def = convert_variables_to_constants (sess , graph_def , output_node_names )
94
110
return graph_def
95
111
96
112
@@ -366,7 +382,7 @@ def toposort(data):
366
382
fdef = fdef .definition
367
383
if input_shapes and len (fdef .signature .input_arg ) < len (input_shapes ):
368
384
input_shapes = input_shapes [:len (fdef .signature .input_arg )]
369
- func = function_def_to_graph . function_def_to_graph (fdef , input_shapes = input_shapes )
385
+ func = function_def_to_graph (fdef , input_shapes = input_shapes )
370
386
_FUNCTIONS [k ] = func
371
387
_ , _ , _ , _ , _ , tfunctions = tflist_to_onnx (func , {})
372
388
functions .update (tfunctions )
0 commit comments