Skip to content

Commit 5f07ee7

Browse files
committed
clean up tf-version dependent imports
1 parent 8b4ca01 commit 5f07ee7

File tree

1 file changed

+23
-8
lines changed

1 file changed

+23
-8
lines changed

tf2onnx/tf_loader.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,33 @@
2121
# pylint: disable=unused-argument,unused-import,no-value-for-parameter,unexpected-keyword-arg,ungrouped-imports
2222
# pylint: disable=missing-function-docstring,import-outside-toplevel,useless-import-alias,missing-docstring
2323

24+
2425
def is_tf2():
2526
return tf.__version__.startswith("2.")
2627

2728

29+
def _not_implemented_tf_placeholder(name):
30+
"""Creates a placeholder function for missing Tensorflow imports"""
31+
32+
def not_implemented_tf_placeholder(*args, **kwargs):
33+
raise NotImplementedError(
34+
f'Tensorflow verison {tf.__version__} does not implement '
35+
f'`{name}`, try converting your model with a different version.'
36+
)
37+
return not_implemented_tf_placeholder
38+
39+
40+
try:
41+
from tensorflow.python.framework.function_def_to_graph import function_def_to_graph
42+
except ImportError:
43+
function_def_to_graph = _not_implemented_tf_placeholder('function_def_to_graph')
44+
2845
if is_tf2():
29-
from tensorflow.python.framework import convert_to_constants, func_graph, function_def_to_graph
46+
convert_variables_to_constants = tf.compat.v1.graph_util.convert_variables_to_constants
47+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
3048
else:
3149
from tensorflow.python.framework.graph_util import convert_variables_to_constants
32-
from tensorflow.python.framework import function_def_to_graph
50+
convert_variables_to_constants_v2 = _not_implemented_tf_placeholder('convert_variables_to_constants_v2')
3351

3452

3553
if is_tf2():
@@ -64,7 +82,7 @@ def is_tf2():
6482

6583

6684
def from_function(func, input_names, output_names):
67-
frozen_func = convert_to_constants.convert_variables_to_constants_v2(func, lower_control_flow=False)
85+
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
6886
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
6987
# output_tensors = {i.name: i for i in frozen_func.outputs}
7088
tf_reset_default_graph()
@@ -88,10 +106,7 @@ def freeze_session(sess, input_names=None, output_names=None):
88106
graph_def = sess.graph.as_graph_def(add_shapes=True)
89107
for node in graph_def.node:
90108
node.device = ""
91-
if is_tf2():
92-
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess, graph_def, output_node_names)
93-
else:
94-
graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
109+
graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
95110
return graph_def
96111

97112

@@ -367,7 +382,7 @@ def toposort(data):
367382
fdef = fdef.definition
368383
if input_shapes and len(fdef.signature.input_arg) < len(input_shapes):
369384
input_shapes = input_shapes[:len(fdef.signature.input_arg)]
370-
func = function_def_to_graph.function_def_to_graph(fdef, input_shapes=input_shapes)
385+
func = function_def_to_graph(fdef, input_shapes=input_shapes)
371386
_FUNCTIONS[k] = func
372387
_, _, _, _, _, tfunctions = tflist_to_onnx(func, {})
373388
functions.update(tfunctions)

0 commit comments

Comments
 (0)