Skip to content

Commit 6cceef0

Browse files
authored
Merge pull request #143 from amcadmus/devel
Devel
2 parents f5835ef + 30482b3 commit 6cceef0

File tree

7 files changed

+219
-77
lines changed

7 files changed

+219
-77
lines changed

source/train/DescrptLocFrame.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,33 @@ def __init__(self, jdata):
4545
self.davg = None
4646
self.dstd = None
4747

48+
self.place_holders = {}
49+
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
50+
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
51+
sub_graph = tf.Graph()
52+
with sub_graph.as_default():
53+
name_pfx = 'd_lf_'
54+
for ii in ['coord', 'box']:
55+
self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx+'t_'+ii)
56+
self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name=name_pfx+'t_type')
57+
self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name=name_pfx+'t_natoms')
58+
self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name=name_pfx+'t_mesh')
59+
self.stat_descrpt, descrpt_deriv, rij, nlist, axis, rot_mat \
60+
= op_module.descrpt (self.place_holders['coord'],
61+
self.place_holders['type'],
62+
self.place_holders['natoms_vec'],
63+
self.place_holders['box'],
64+
self.place_holders['default_mesh'],
65+
tf.constant(avg_zero),
66+
tf.constant(std_ones),
67+
rcut_a = self.rcut_a,
68+
rcut_r = self.rcut_r,
69+
sel_a = self.sel_a,
70+
sel_r = self.sel_r,
71+
axis_rule = self.axis_rule)
72+
self.sub_sess = tf.Session(graph = sub_graph)
73+
74+
4875
def get_rcut (self) :
4976
return self.rcut_r
5077

@@ -174,31 +201,15 @@ def _compute_dstats_sys_nonsmth (self,
174201
data_atype,
175202
natoms_vec,
176203
mesh) :
177-
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
178-
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
179-
sub_graph = tf.Graph()
180-
with sub_graph.as_default():
181-
descrpt, descrpt_deriv, rij, nlist, axis, rot_mat \
182-
= op_module.descrpt (tf.constant(data_coord),
183-
tf.constant(data_atype),
184-
tf.constant(natoms_vec, dtype = tf.int32),
185-
tf.constant(data_box),
186-
tf.constant(mesh),
187-
tf.constant(avg_zero),
188-
tf.constant(std_ones),
189-
rcut_a = self.rcut_a,
190-
rcut_r = self.rcut_r,
191-
sel_a = self.sel_a,
192-
sel_r = self.sel_r,
193-
axis_rule = self.axis_rule)
194-
# self.sess.run(tf.global_variables_initializer())
195-
# sub_sess = tf.Session(graph = sub_graph,
196-
# config=tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads,
197-
# inter_op_parallelism_threads=self.run_opt.num_inter_threads
198-
# ))
199-
sub_sess = tf.Session(graph = sub_graph)
200-
dd_all = sub_sess.run(descrpt)
201-
sub_sess.close()
204+
dd_all \
205+
= self.sub_sess.run(self.stat_descrpt,
206+
feed_dict = {
207+
self.place_holders['coord']: data_coord,
208+
self.place_holders['type']: data_atype,
209+
self.place_holders['natoms_vec']: natoms_vec,
210+
self.place_holders['box']: data_box,
211+
self.place_holders['default_mesh']: mesh,
212+
})
202213
natoms = natoms_vec
203214
dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
204215
start_index = 0

source/train/DescrptSeA.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,32 @@ def __init__ (self, jdata):
5656
self.dstd = None
5757
self.davg = None
5858

59+
self.place_holders = {}
60+
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
61+
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
62+
sub_graph = tf.Graph()
63+
with sub_graph.as_default():
64+
name_pfx = 'd_sea_'
65+
for ii in ['coord', 'box']:
66+
self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx+'t_'+ii)
67+
self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name=name_pfx+'t_type')
68+
self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name=name_pfx+'t_natoms')
69+
self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name=name_pfx+'t_mesh')
70+
self.stat_descrpt, descrpt_deriv, rij, nlist \
71+
= op_module.descrpt_se_a(self.place_holders['coord'],
72+
self.place_holders['type'],
73+
self.place_holders['natoms_vec'],
74+
self.place_holders['box'],
75+
self.place_holders['default_mesh'],
76+
tf.constant(avg_zero),
77+
tf.constant(std_ones),
78+
rcut_a = self.rcut_a,
79+
rcut_r = self.rcut_r,
80+
rcut_r_smth = self.rcut_r_smth,
81+
sel_a = self.sel_a,
82+
sel_r = self.sel_r)
83+
self.sub_sess = tf.Session(graph = sub_graph)
84+
5985

6086
def get_rcut (self) :
6187
return self.rcut_r
@@ -240,32 +266,15 @@ def _compute_dstats_sys_smth (self,
240266
data_atype,
241267
natoms_vec,
242268
mesh) :
243-
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
244-
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
245-
sub_graph = tf.Graph()
246-
with sub_graph.as_default():
247-
descrpt, descrpt_deriv, rij, nlist \
248-
= op_module.descrpt_se_a (tf.constant(data_coord),
249-
tf.constant(data_atype),
250-
tf.constant(natoms_vec, dtype = tf.int32),
251-
tf.constant(data_box),
252-
tf.constant(mesh),
253-
tf.constant(avg_zero),
254-
tf.constant(std_ones),
255-
rcut_a = self.rcut_a,
256-
rcut_r = self.rcut_r,
257-
rcut_r_smth = self.rcut_r_smth,
258-
sel_a = self.sel_a,
259-
sel_r = self.sel_r)
260-
# self.sess.run(tf.global_variables_initializer())
261-
# sub_sess = tf.Session(graph = sub_graph,
262-
# config=tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads,
263-
# inter_op_parallelism_threads=self.run_opt.num_inter_threads
264-
265-
# ))
266-
sub_sess = tf.Session(graph = sub_graph)
267-
dd_all = sub_sess.run(descrpt)
268-
sub_sess.close()
269+
dd_all \
270+
= self.sub_sess.run(self.stat_descrpt,
271+
feed_dict = {
272+
self.place_holders['coord']: data_coord,
273+
self.place_holders['type']: data_atype,
274+
self.place_holders['natoms_vec']: natoms_vec,
275+
self.place_holders['box']: data_box,
276+
self.place_holders['default_mesh']: mesh,
277+
})
269278
natoms = natoms_vec
270279
dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
271280
start_index = 0

source/train/DescrptSeR.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,30 @@ def __init__ (self, jdata):
5252
self.davg = None
5353
self.dstd = None
5454

55+
self.place_holders = {}
56+
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
57+
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
58+
sub_graph = tf.Graph()
59+
with sub_graph.as_default():
60+
name_pfx = 'd_ser_'
61+
for ii in ['coord', 'box']:
62+
self.place_holders[ii] = tf.placeholder(global_np_float_precision, [None, None], name = name_pfx+'t_'+ii)
63+
self.place_holders['type'] = tf.placeholder(tf.int32, [None, None], name=name_pfx+'t_type')
64+
self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name=name_pfx+'t_natoms')
65+
self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name=name_pfx+'t_mesh')
66+
self.stat_descrpt, descrpt_deriv, rij, nlist \
67+
= op_module.descrpt_se_r(self.place_holders['coord'],
68+
self.place_holders['type'],
69+
self.place_holders['natoms_vec'],
70+
self.place_holders['box'],
71+
self.place_holders['default_mesh'],
72+
tf.constant(avg_zero),
73+
tf.constant(std_ones),
74+
rcut = self.rcut,
75+
rcut_smth = self.rcut_smth,
76+
sel = self.sel_r)
77+
self.sub_sess = tf.Session(graph = sub_graph)
78+
5579

5680
def get_rcut (self) :
5781
return self.rcut
@@ -197,29 +221,15 @@ def _compute_dstats_sys_se_r (self,
197221
data_atype,
198222
natoms_vec,
199223
mesh) :
200-
avg_zero = np.zeros([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
201-
std_ones = np.ones ([self.ntypes,self.ndescrpt]).astype(global_np_float_precision)
202-
sub_graph = tf.Graph()
203-
with sub_graph.as_default():
204-
descrpt, descrpt_deriv, rij, nlist \
205-
= op_module.descrpt_se_r (tf.constant(data_coord),
206-
tf.constant(data_atype),
207-
tf.constant(natoms_vec, dtype = tf.int32),
208-
tf.constant(data_box),
209-
tf.constant(mesh),
210-
tf.constant(avg_zero),
211-
tf.constant(std_ones),
212-
rcut = self.rcut,
213-
rcut_smth = self.rcut_smth,
214-
sel = self.sel_r)
215-
# sub_sess = tf.Session(graph = sub_graph,
216-
# config=tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads,
217-
# inter_op_parallelism_threads=self.run_opt.num_inter_threads
218-
219-
# ))
220-
sub_sess = tf.Session(graph = sub_graph)
221-
dd_all = sub_sess.run(descrpt)
222-
sub_sess.close()
224+
dd_all \
225+
= self.sub_sess.run(self.stat_descrpt,
226+
feed_dict = {
227+
self.place_holders['coord']: data_coord,
228+
self.place_holders['type']: data_atype,
229+
self.place_holders['natoms_vec']: natoms_vec,
230+
self.place_holders['box']: data_box,
231+
self.place_holders['default_mesh']: mesh,
232+
})
223233
natoms = natoms_vec
224234
dd_all = np.reshape(dd_all, [-1, self.ndescrpt * natoms[0]])
225235
start_index = 0

source/train/Loss.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,100 @@ def print_on_training(self,
178178
return print_str
179179

180180

181+
class EnerDipoleLoss () :
182+
def __init__ (self, jdata, **kwarg) :
183+
self.starter_learning_rate = kwarg['starter_learning_rate']
184+
args = ClassArg()\
185+
.add('start_pref_e', float, must = True, default = 0.1) \
186+
.add('limit_pref_e', float, must = True, default = 1.00)\
187+
.add('start_pref_ed', float, must = True, default = 1.00)\
188+
.add('limit_pref_ed', float, must = True, default = 1.00)
189+
class_data = args.parse(jdata)
190+
self.start_pref_e = class_data['start_pref_e']
191+
self.limit_pref_e = class_data['limit_pref_e']
192+
self.start_pref_ed = class_data['start_pref_ed']
193+
self.limit_pref_ed = class_data['limit_pref_ed']
194+
# data required
195+
add_data_requirement('energy', 1, atomic=False, must=True, high_prec=True)
196+
add_data_requirement('energy_dipole', 3, atomic=False, must=True, high_prec=False)
197+
198+
def build (self,
199+
learning_rate,
200+
natoms,
201+
model_dict,
202+
label_dict,
203+
suffix):
204+
coord = model_dict['coord']
205+
energy = model_dict['energy']
206+
atom_ener = model_dict['atom_ener']
207+
nframes = tf.shape(atom_ener)[0]
208+
natoms = tf.shape(atom_ener)[1]
209+
# build energy dipole
210+
atom_ener0 = atom_ener - tf.reshape(tf.tile(tf.reshape(energy/global_cvt_2_ener_float(natoms), [-1, 1]), [1, natoms]), [nframes, natoms])
211+
coord = tf.reshape(coord, [nframes, natoms, 3])
212+
atom_ener0 = tf.reshape(atom_ener0, [nframes, 1, natoms])
213+
ener_dipole = tf.matmul(atom_ener0, coord)
214+
ener_dipole = tf.reshape(ener_dipole, [nframes, 3])
215+
216+
energy_hat = label_dict['energy']
217+
ener_dipole_hat = label_dict['energy_dipole']
218+
find_energy = label_dict['find_energy']
219+
find_ener_dipole = label_dict['find_energy_dipole']
220+
221+
l2_ener_loss = tf.reduce_mean( tf.square(energy - energy_hat), name='l2_'+suffix)
222+
223+
ener_dipole_reshape = tf.reshape(ener_dipole, [-1])
224+
ener_dipole_hat_reshape = tf.reshape(ener_dipole_hat, [-1])
225+
l2_ener_dipole_loss = tf.reduce_mean( tf.square(ener_dipole_reshape - ener_dipole_hat_reshape), name='l2_'+suffix)
226+
227+
# atom_norm_ener = 1./ global_cvt_2_ener_float(natoms[0])
228+
atom_norm_ener = 1./ global_cvt_2_ener_float(natoms)
229+
pref_e = global_cvt_2_ener_float(find_energy * (self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * learning_rate / self.starter_learning_rate) )
230+
pref_ed = global_cvt_2_tf_float(find_ener_dipole * (self.limit_pref_ed + (self.start_pref_ed - self.limit_pref_ed) * learning_rate / self.starter_learning_rate) )
231+
232+
l2_loss = 0
233+
more_loss = {}
234+
l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
235+
l2_loss += global_cvt_2_ener_float(pref_ed * l2_ener_dipole_loss)
236+
more_loss['l2_ener_loss'] = l2_ener_loss
237+
more_loss['l2_ener_dipole_loss'] = l2_ener_dipole_loss
238+
239+
self.l2_l = l2_loss
240+
self.l2_more = more_loss
241+
return l2_loss, more_loss
242+
243+
244+
def print_header(self) :
245+
prop_fmt = ' %9s %9s'
246+
print_str = ''
247+
print_str += prop_fmt % ('l2_tst', 'l2_trn')
248+
print_str += prop_fmt % ('l2_e_tst', 'l2_e_trn')
249+
print_str += prop_fmt % ('l2_ed_tst', 'l2_ed_trn')
250+
return print_str
251+
252+
253+
def print_on_training(self,
254+
sess,
255+
natoms,
256+
feed_dict_test,
257+
feed_dict_batch) :
258+
error_test, error_e_test, error_ed_test\
259+
= sess.run([self.l2_l, \
260+
self.l2_more['l2_ener_loss'], \
261+
self.l2_more['l2_ener_dipole_loss']],
262+
feed_dict=feed_dict_test)
263+
error_train, error_e_train, error_ed_train\
264+
= sess.run([self.l2_l, \
265+
self.l2_more['l2_ener_loss'], \
266+
self.l2_more['l2_ener_dipole_loss']],
267+
feed_dict=feed_dict_batch)
268+
print_str = ""
269+
prop_fmt = " %9.2e %9.2e"
270+
print_str += prop_fmt % (np.sqrt(error_test), np.sqrt(error_train))
271+
print_str += prop_fmt % (np.sqrt(error_e_test) / natoms[0], np.sqrt(error_e_train) / natoms[0])
272+
print_str += prop_fmt % (np.sqrt(error_ed_test), np.sqrt(error_ed_train))
273+
return print_str
274+
181275

182276
class TensorLoss () :
183277
def __init__ (self, jdata, **kwarg) :

source/train/Model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,8 @@ def build (self,
270270
model_dict['virial'] = virial
271271
model_dict['atom_ener'] = energy_raw
272272
model_dict['atom_virial'] = atom_virial
273+
model_dict['coord'] = coord
274+
model_dict['atype'] = atype
273275

274276
return model_dict
275277

source/train/Trainer.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from deepmd.DescrptSeR import DescrptSeR
1919
from deepmd.DescrptSeAR import DescrptSeAR
2020
from deepmd.Model import Model, WFCModel, DipoleModel, PolarModel, GlobalPolarModel
21-
from deepmd.Loss import EnerStdLoss, TensorLoss
21+
from deepmd.Loss import EnerStdLoss, EnerDipoleLoss, TensorLoss
2222
from deepmd.LearningRate import LearningRateExp
2323

2424
from tensorflow.python.framework import ops
@@ -144,10 +144,18 @@ def _init_param(self, jdata):
144144
# infer loss type by fitting_type
145145
try :
146146
loss_param = jdata['loss']
147+
loss_type = loss_param.get('type', 'std')
147148
except:
148149
loss_param = None
150+
loss_type = 'std'
151+
149152
if fitting_type == 'ener':
150-
self.loss = EnerStdLoss(loss_param, starter_learning_rate = self.lr.start_lr())
153+
if loss_type == 'std':
154+
self.loss = EnerStdLoss(loss_param, starter_learning_rate = self.lr.start_lr())
155+
elif loss_type == 'ener_dipole':
156+
self.loss = EnerDipoleLoss(loss_param, starter_learning_rate = self.lr.start_lr())
157+
else:
158+
raise RuntimeError('unknow loss type')
151159
elif fitting_type == 'wfc':
152160
self.loss = TensorLoss(loss_param,
153161
model = self.model,
@@ -262,7 +270,6 @@ def _build_network(self, data):
262270
self.place_holders['natoms_vec'] = tf.placeholder(tf.int32, [self.ntypes+2], name='t_natoms')
263271
self.place_holders['default_mesh'] = tf.placeholder(tf.int32, [None], name='t_mesh')
264272
self.place_holders['is_training'] = tf.placeholder(tf.bool)
265-
266273
self.model_pred\
267274
= self.model.build (self.place_holders['coord'],
268275
self.place_holders['type'],

source/train/train.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def train (args) :
8585
# serial training
8686
_do_work(jdata, run_opt)
8787

88+
def expand_sys_str(root_dir):
89+
all_sys = []
90+
from pathlib import Path
91+
for filename in Path(root_dir).rglob('type.raw'):
92+
all_sys.append(os.path.dirname(filename))
93+
return all_sys
94+
8895
def _do_work(jdata, run_opt):
8996
# init the model
9097
model = NNPTrainer (jdata, run_opt = run_opt)
@@ -93,6 +100,8 @@ def _do_work(jdata, run_opt):
93100
# init params and run options
94101
assert('training' in jdata)
95102
systems = j_must_have(jdata['training'], 'systems')
103+
if type(systems) == str:
104+
systems = expand_sys_str(systems)
96105
set_pfx = j_must_have(jdata['training'], 'set_prefix')
97106
numb_sys = len(systems)
98107
seed = None

0 commit comments

Comments
 (0)