21
21
# pylint: disable=unused-argument,unused-import,no-value-for-parameter,unexpected-keyword-arg,ungrouped-imports
22
22
# pylint: disable=missing-function-docstring,import-outside-toplevel,useless-import-alias,missing-docstring
23
23
24
+
24
25
def is_tf2 ():
25
26
return tf .__version__ .startswith ("2." )
26
27
27
28
29
+ def _not_implemented_tf_placeholder (name ):
30
+ """Creates a placeholder function for missing Tensorflow imports"""
31
+
32
+ def not_implemented_tf_placeholder (* args , ** kwargs ):
33
+ raise NotImplementedError (
34
+ f'Tensorflow verison { tf .__version__ } does not implement '
35
+ f'`{ name } `, try converting your model with a different version.'
36
+ )
37
+ return not_implemented_tf_placeholder
38
+
39
+
40
+ try :
41
+ from tensorflow .python .framework .function_def_to_graph import function_def_to_graph
42
+ except ImportError :
43
+ function_def_to_graph = _not_implemented_tf_placeholder ('function_def_to_graph' )
44
+
28
45
if is_tf2 ():
29
- from tensorflow .python .framework import convert_to_constants , func_graph , function_def_to_graph
46
+ convert_variables_to_constants = tf .compat .v1 .graph_util .convert_variables_to_constants
47
+ from tensorflow .python .framework .convert_to_constants import convert_variables_to_constants_v2
30
48
else :
31
49
from tensorflow .python .framework .graph_util import convert_variables_to_constants
32
- from tensorflow . python . framework import function_def_to_graph
50
+ convert_variables_to_constants_v2 = _not_implemented_tf_placeholder ( 'convert_variables_to_constants_v2' )
33
51
34
52
35
53
if is_tf2 ():
@@ -64,7 +82,7 @@ def is_tf2():
64
82
65
83
66
84
def from_function (func , input_names , output_names ):
67
- frozen_func = convert_to_constants . convert_variables_to_constants_v2 (func , lower_control_flow = False )
85
+ frozen_func = convert_variables_to_constants_v2 (func , lower_control_flow = False )
68
86
graph_def = frozen_func .graph .as_graph_def (add_shapes = True )
69
87
# output_tensors = {i.name: i for i in frozen_func.outputs}
70
88
tf_reset_default_graph ()
@@ -88,10 +106,7 @@ def freeze_session(sess, input_names=None, output_names=None):
88
106
graph_def = sess .graph .as_graph_def (add_shapes = True )
89
107
for node in graph_def .node :
90
108
node .device = ""
91
- if is_tf2 ():
92
- graph_def = tf .compat .v1 .graph_util .convert_variables_to_constants (sess , graph_def , output_node_names )
93
- else :
94
- graph_def = convert_variables_to_constants (sess , graph_def , output_node_names )
109
+ graph_def = convert_variables_to_constants (sess , graph_def , output_node_names )
95
110
return graph_def
96
111
97
112
@@ -367,7 +382,7 @@ def toposort(data):
367
382
fdef = fdef .definition
368
383
if input_shapes and len (fdef .signature .input_arg ) < len (input_shapes ):
369
384
input_shapes = input_shapes [:len (fdef .signature .input_arg )]
370
- func = function_def_to_graph . function_def_to_graph (fdef , input_shapes = input_shapes )
385
+ func = function_def_to_graph (fdef , input_shapes = input_shapes )
371
386
_FUNCTIONS [k ] = func
372
387
_ , _ , _ , _ , _ , tfunctions = tflist_to_onnx (func , {})
373
388
functions .update (tfunctions )
0 commit comments