Skip to content

Commit c6aaf49

Browse files
author
Han Wang
committed
init session only once for data stats
1 parent 1c20a69 commit c6aaf49

File tree

3 files changed

+104
-74
lines changed

3 files changed

+104
-74
lines changed

source/train/DescrptLocFrame.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,33 @@ def __init__(self, jdata):
4545
self.davg = None
4646
self.dstd = None
4747

48+
self.place_holders = {}
49+
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
50+
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
51+
sub_graph = tf.Graph()
52+
with sub_graph.as_default():
53+
name_pfx = 'd_lf_'
54+
for ii in ['coord', 'box']:
55+
self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx+'t_'+ii)
56+
self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name=name_pfx+'t_type')
57+
self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name=name_pfx+'t_natoms')
58+
self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name=name_pfx+'t_mesh')
59+
self.stat_descrpt, descrpt_deriv, rij, nlist, axis, rot_mat \
60+
= op_module.descrpt (self.place_holders['coord'],
61+
self.place_holders['type'],
62+
self.place_holders['natoms_vec'],
63+
self.place_holders['box'],
64+
self.place_holders['default_mesh'],
65+
tf.constant(avg_zero),
66+
tf.constant(std_ones),
67+
rcut_a = self.rcut_a,
68+
rcut_r = self.rcut_r,
69+
sel_a = self.sel_a,
70+
sel_r = self.sel_r,
71+
axis_rule = self.axis_rule)
72+
self.sub_sess = tf.Session(graph = sub_graph)
73+
74+
4875
def get_rcut (self) :
4976
return self.rcut_r
5077

@@ -174,31 +201,15 @@ def _compute_dstats_sys_nonsmth (self,
174201
data_atype,
175202
natoms_vec,
176203
mesh) :
177-
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
178-
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
179-
sub_graph = tf.Graph()
180-
with sub_graph.as_default():
181-
descrpt, descrpt_deriv, rij, nlist, axis, rot_mat \
182-
= op_module.descrpt (tf.constant(data_coord),
183-
tf.constant(data_atype),
184-
tf.constant(natoms_vec, dtype = tf.int32),
185-
tf.constant(data_box),
186-
tf.constant(mesh),
187-
tf.constant(avg_zero),
188-
tf.constant(std_ones),
189-
rcut_a = self.rcut_a,
190-
rcut_r = self.rcut_r,
191-
sel_a = self.sel_a,
192-
sel_r = self.sel_r,
193-
axis_rule = self.axis_rule)
194-
# self.sess.run(tf.global_variables_initializer())
195-
# sub_sess = tf.Session(graph = sub_graph,
196-
# config=tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads,
197-
# inter_op_parallelism_threads=self.run_opt.num_inter_threads
198-
# ))
199-
sub_sess = tf.Session(graph = sub_graph)
200-
dd_all = sub_sess.run(descrpt)
201-
sub_sess.close()
204+
dd_all \
205+
= self.sub_sess.run(self.stat_descrpt,
206+
feed_dict = {
207+
self.place_holders['coord']: data_coord,
208+
self.place_holders['type']: data_atype,
209+
self.place_holders['natoms_vec']: natoms_vec,
210+
self.place_holders['box']: data_box,
211+
self.place_holders['default_mesh']: mesh,
212+
})
202213
natoms = natoms_vec
203214
dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
204215
start_index = 0

source/train/DescrptSeA.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,32 @@ def __init__ (self, jdata):
5656
self.dstd = None
5757
self.davg = None
5858

59+
self.place_holders = {}
60+
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
61+
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
62+
sub_graph = tf.Graph()
63+
with sub_graph.as_default():
64+
name_pfx = 'd_sea_'
65+
for ii in ['coord', 'box']:
66+
self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx+'t_'+ii)
67+
self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name=name_pfx+'t_type')
68+
self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name=name_pfx+'t_natoms')
69+
self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name=name_pfx+'t_mesh')
70+
self.stat_descrpt, descrpt_deriv, rij, nlist \
71+
= op_module.descrpt_se_a(self.place_holders['coord'],
72+
self.place_holders['type'],
73+
self.place_holders['natoms_vec'],
74+
self.place_holders['box'],
75+
self.place_holders['default_mesh'],
76+
tf.constant(avg_zero),
77+
tf.constant(std_ones),
78+
rcut_a = self.rcut_a,
79+
rcut_r = self.rcut_r,
80+
rcut_r_smth = self.rcut_r_smth,
81+
sel_a = self.sel_a,
82+
sel_r = self.sel_r)
83+
self.sub_sess = tf.Session(graph = sub_graph)
84+
5985

6086
def get_rcut (self) :
6187
return self.rcut_r
@@ -240,32 +266,15 @@ def _compute_dstats_sys_smth (self,
240266
data_atype,
241267
natoms_vec,
242268
mesh) :
243-
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
244-
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
245-
sub_graph = tf.Graph()
246-
with sub_graph.as_default():
247-
descrpt, descrpt_deriv, rij, nlist \
248-
= op_module.descrpt_se_a (tf.constant(data_coord),
249-
tf.constant(data_atype),
250-
tf.constant(natoms_vec, dtype = tf.int32),
251-
tf.constant(data_box),
252-
tf.constant(mesh),
253-
tf.constant(avg_zero),
254-
tf.constant(std_ones),
255-
rcut_a = self.rcut_a,
256-
rcut_r = self.rcut_r,
257-
rcut_r_smth = self.rcut_r_smth,
258-
sel_a = self.sel_a,
259-
sel_r = self.sel_r)
260-
# self.sess.run(tf.global_variables_initializer())
261-
# sub_sess = tf.Session(graph = sub_graph,
262-
# config=tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads,
263-
# inter_op_parallelism_threads=self.run_opt.num_inter_threads
264-
265-
# ))
266-
sub_sess = tf.Session(graph = sub_graph)
267-
dd_all = sub_sess.run(descrpt)
268-
sub_sess.close()
269+
dd_all \
270+
= self.sub_sess.run(self.stat_descrpt,
271+
feed_dict = {
272+
self.place_holders['coord']: data_coord,
273+
self.place_holders['type']: data_atype,
274+
self.place_holders['natoms_vec']: natoms_vec,
275+
self.place_holders['box']: data_box,
276+
self.place_holders['default_mesh']: mesh,
277+
})
269278
natoms = natoms_vec
270279
dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
271280
start_index = 0

source/train/DescrptSeR.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,30 @@ def __init__ (self, jdata):
5252
self.davg = None
5353
self.dstd = None
5454

55+
self.place_holders = {}
56+
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
57+
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
58+
sub_graph = tf.Graph()
59+
with sub_graph.as_default():
60+
name_pfx = 'd_ser_'
61+
for ii in ['coord', 'box']:
62+
self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx+'t_'+ii)
63+
self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name=name_pfx+'t_type')
64+
self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name=name_pfx+'t_natoms')
65+
self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name=name_pfx+'t_mesh')
66+
self.stat_descrpt, descrpt_deriv, rij, nlist \
67+
= op_module.descrpt_se_r(self.place_holders['coord'],
68+
self.place_holders['type'],
69+
self.place_holders['natoms_vec'],
70+
self.place_holders['box'],
71+
self.place_holders['default_mesh'],
72+
tf.constant(avg_zero),
73+
tf.constant(std_ones),
74+
rcut = self.rcut,
75+
rcut_smth = self.rcut_smth,
76+
sel = self.sel_r)
77+
self.sub_sess = tf.Session(graph = sub_graph)
78+
5579

5680
def get_rcut (self) :
5781
return self.rcut
@@ -197,29 +221,15 @@ def _compute_dstats_sys_se_r (self,
197221
data_atype,
198222
natoms_vec,
199223
mesh) :
200-
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
201-
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
202-
sub_graph = tf.Graph()
203-
with sub_graph.as_default():
204-
descrpt, descrpt_deriv, rij, nlist \
205-
= op_module.descrpt_se_r (tf.constant(data_coord),
206-
tf.constant(data_atype),
207-
tf.constant(natoms_vec, dtype = tf.int32),
208-
tf.constant(data_box),
209-
tf.constant(mesh),
210-
tf.constant(avg_zero),
211-
tf.constant(std_ones),
212-
rcut = self.rcut,
213-
rcut_smth = self.rcut_smth,
214-
sel = self.sel_r)
215-
# sub_sess = tf.Session(graph = sub_graph,
216-
# config=tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads,
217-
# inter_op_parallelism_threads=self.run_opt.num_inter_threads
218-
219-
# ))
220-
sub_sess = tf.Session(graph = sub_graph)
221-
dd_all = sub_sess.run(descrpt)
222-
sub_sess.close()
224+
dd_all \
225+
= self.sub_sess.run(self.stat_descrpt,
226+
feed_dict = {
227+
self.place_holders['coord']: data_coord,
228+
self.place_holders['type']: data_atype,
229+
self.place_holders['natoms_vec']: natoms_vec,
230+
self.place_holders['box']: data_box,
231+
self.place_holders['default_mesh']: mesh,
232+
})
223233
natoms = natoms_vec
224234
dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
225235
start_index = 0

0 commit comments

Comments
 (0)