Skip to content

Commit 25a7f8a

Browse files
author
Han Wang
committed
fix the bug of loading dp model to the default tf graph
1 parent 193570f commit 25a7f8a

File tree

6 files changed

+28
-18
lines changed

6 files changed

+28
-18
lines changed

source/train/DataModifier.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@ def __init__(self,
2222
# the dipole model is loaded with prefix 'dipole_charge'
2323
self.modifier_prefix = 'dipole_charge'
2424
# init dipole model
25-
DeepDipole.__init__(self, model_name, load_prefix = self.modifier_prefix)
25+
DeepDipole.__init__(self,
26+
model_name,
27+
load_prefix = self.modifier_prefix,
28+
default_tf_graph = True)
2629
self.model_name = model_name
2730
self.model_charge_map = model_charge_map
2831
self.sys_charge_map = sys_charge_map

source/train/DeepDipole.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
class DeepDipole (DeepTensor) :
66
def __init__(self,
77
model_file,
8-
load_prefix = 'load') :
9-
DeepTensor.__init__(self, model_file, 'dipole', 3, load_prefix = load_prefix)
8+
load_prefix = 'load',
9+
default_tf_graph = False) :
10+
DeepTensor.__init__(self, model_file, 'dipole', 3, load_prefix = load_prefix, default_tf_graph = default_tf_graph)
1011

source/train/DeepEval.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,18 @@ class DeepEval():
1313
"""
1414
def __init__(self,
1515
model_file,
16-
load_prefix = 'load') :
17-
self.graph = self._load_graph (model_file, prefix = load_prefix)
16+
load_prefix = 'load',
17+
default_tf_graph = False) :
18+
self.graph = self._load_graph (model_file, prefix = load_prefix, default_tf_graph = default_tf_graph)
1819
t_mt = self.graph.get_tensor_by_name(os.path.join(load_prefix, 'model_attr/model_type:0'))
1920
sess = tf.Session (graph = self.graph, config=default_tf_session_config)
2021
[mt] = sess.run([t_mt], feed_dict = {})
2122
self.model_type = mt.decode('utf-8')
2223

2324
def _load_graph(self,
24-
frozen_graph_filename,
25-
prefix = 'load',
26-
default_tf_graph = True):
25+
frozen_graph_filename,
26+
prefix = 'load',
27+
default_tf_graph = False):
2728
# We load the protobuf file from the disk and parse it to retrieve the
2829
# unserialized graph_def
2930
with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
@@ -102,8 +103,9 @@ def __init__(self,
102103
model_file,
103104
variable_name,
104105
variable_dof,
105-
load_prefix = 'load') :
106-
DeepEval.__init__(self, model_file, load_prefix = load_prefix)
106+
load_prefix = 'load',
107+
default_tf_graph = False) :
108+
DeepEval.__init__(self, model_file, load_prefix = load_prefix, default_tf_graph = default_tf_graph)
107109
# self.model_file = model_file
108110
# self.graph = self.load_graph (self.model_file)
109111
self.variable_name = variable_name

source/train/DeepPolar.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44

55
class DeepPolar (DeepTensor) :
66
def __init__(self,
7-
model_file) :
8-
DeepTensor.__init__(self, model_file, 'polar', 9)
7+
model_file,
8+
default_tf_graph = False) :
9+
DeepTensor.__init__(self, model_file, 'polar', 9, default_tf_graph = default_tf_graph)
910

1011

1112
class DeepGlobalPolar (DeepTensor) :
1213
def __init__(self,
13-
model_file) :
14-
DeepTensor.__init__(self, model_file, 'global_polar', 9)
14+
model_file,
15+
default_tf_graph = False) :
16+
DeepTensor.__init__(self, model_file, 'global_polar', 9, default_tf_graph = default_tf_graph)
1517

1618
def eval(self,
1719
coords,

source/train/DeepPot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
class DeepPot (DeepEval) :
1111
def __init__(self,
12-
model_file) :
13-
DeepEval.__init__(self, model_file)
12+
model_file,
13+
default_tf_graph = False) :
14+
DeepEval.__init__(self, model_file, default_tf_graph = default_tf_graph)
1415
# self.model_file = model_file
1516
# self.graph = self.load_graph (self.model_file)
1617
# checkout input/output tensors from graph

source/train/DeepWFC.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
class DeepWFC (DeepTensor) :
66
def __init__(self,
7-
model_file) :
8-
DeepTensor.__init__(self, model_file, 'wfc', 12)
7+
model_file,
8+
default_tf_graph = False) :
9+
DeepTensor.__init__(self, model_file, 'wfc', 12, default_tf_graph = default_tf_graph)
910

0 commit comments

Comments
 (0)