Skip to content

Commit 5bc8493

Browse files
authored
modify tf import for tf2 (#27)
1 parent 02cd953 commit 5bc8493

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

dlclive/graph.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77

88

99
import tensorflow as tf
10+
vers = (tf.__version__).split('.')
11+
if int(vers[0])==2 or int(vers[0])==1 and int(vers[1])>12:
12+
tf=tf.compat.v1
13+
else:
14+
tf=tf
1015

1116

1217
def read_graph(file):
@@ -25,7 +30,7 @@ def read_graph(file):
2530
"""
2631

2732
with tf.io.gfile.GFile(file, "rb") as f:
28-
graph_def = tf.compat.v1.GraphDef()
33+
graph_def = tf.GraphDef()
2934
graph_def.ParseFromString(f.read())
3035
return graph_def
3136

@@ -125,7 +130,7 @@ def extract_graph(graph, tf_config=None):
125130

126131
input_tensor = get_input_tensor(graph)
127132
output_tensor = get_output_tensors(graph)
128-
sess = tf.compat.v1.Session(graph=graph, config=tf_config)
133+
sess = tf.Session(graph=graph, config=tf_config)
129134
inputs = graph.get_tensor_by_name(input_tensor)
130135
outputs = [graph.get_tensor_by_name(out) for out in output_tensor]
131136

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
long_description = fh.read()
1818

1919
install_requires = [
20-
"numpy",
20+
"numpy<1.19.0",
2121
"ruamel.yaml",
2222
"colorcet",
2323
"pillow",
@@ -31,7 +31,7 @@
3131
)
3232
else:
3333
install_requires.append("opencv-python")
34-
install_requires.append("tensorflow==1.13.1")
34+
install_requires.append("tensorflow")
3535
install_requires.append("pandas")
3636
install_requires.append("tables")
3737

0 commit comments

Comments
 (0)