Skip to content

Commit 557939b

Browse files
authored
Merge pull request #858 from andhus/add_missing_import
Add missing import in tf<2
2 parents 48fdca5 + 35b8a73 commit 557939b

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

tf2onnx/tf_loader.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,33 @@
2121
# pylint: disable=unused-argument,unused-import,no-value-for-parameter,unexpected-keyword-arg,ungrouped-imports
2222
# pylint: disable=missing-function-docstring,import-outside-toplevel,useless-import-alias,missing-docstring
2323

24+
2425
def is_tf2():
2526
return tf.__version__.startswith("2.")
2627

2728

29+
def _not_implemented_tf_placeholder(name):
30+
"""Creates a placeholder function for missing Tensorflow imports"""
31+
32+
def not_implemented_tf_placeholder(*args, **kwargs):
33+
raise NotImplementedError(
34+
f'Tensorflow verison {tf.__version__} does not implement '
35+
f'`{name}`, try converting your model with a different version.'
36+
)
37+
return not_implemented_tf_placeholder
38+
39+
40+
try:
41+
from tensorflow.python.framework.function_def_to_graph import function_def_to_graph
42+
except ImportError:
43+
function_def_to_graph = _not_implemented_tf_placeholder('function_def_to_graph')
44+
2845
if is_tf2():
29-
from tensorflow.python.framework import convert_to_constants, func_graph, function_def_to_graph
46+
convert_variables_to_constants = tf.compat.v1.graph_util.convert_variables_to_constants
47+
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
3048
else:
3149
from tensorflow.python.framework.graph_util import convert_variables_to_constants
50+
convert_variables_to_constants_v2 = _not_implemented_tf_placeholder('convert_variables_to_constants_v2')
3251

3352

3453
if is_tf2():
@@ -63,7 +82,7 @@ def is_tf2():
6382

6483

6584
def from_function(func, input_names, output_names):
66-
frozen_func = convert_to_constants.convert_variables_to_constants_v2(func, lower_control_flow=False)
85+
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
6786
graph_def = frozen_func.graph.as_graph_def(add_shapes=True)
6887
# output_tensors = {i.name: i for i in frozen_func.outputs}
6988
tf_reset_default_graph()
@@ -87,10 +106,7 @@ def freeze_session(sess, input_names=None, output_names=None):
87106
graph_def = sess.graph.as_graph_def(add_shapes=True)
88107
for node in graph_def.node:
89108
node.device = ""
90-
if is_tf2():
91-
graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess, graph_def, output_node_names)
92-
else:
93-
graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
109+
graph_def = convert_variables_to_constants(sess, graph_def, output_node_names)
94110
return graph_def
95111

96112

@@ -366,7 +382,7 @@ def toposort(data):
366382
fdef = fdef.definition
367383
if input_shapes and len(fdef.signature.input_arg) < len(input_shapes):
368384
input_shapes = input_shapes[:len(fdef.signature.input_arg)]
369-
func = function_def_to_graph.function_def_to_graph(fdef, input_shapes=input_shapes)
385+
func = function_def_to_graph(fdef, input_shapes=input_shapes)
370386
_FUNCTIONS[k] = func
371387
_, _, _, _, _, tfunctions = tflist_to_onnx(func, {})
372388
functions.update(tfunctions)

0 commit comments

Comments
 (0)