Skip to content

Commit f88ee74

Browse files
author
Han Wang
committed
add fitting global polarizability
1 parent d17c857 commit f88ee74

File tree

4 files changed

+61
-8
lines changed

4 files changed

+61
-8
lines changed

source/train/Fitting.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,34 @@ def build (self,
423423
return tf.reshape(outs, [-1])
424424

425425

426+
class GlobalPolarFittingSeA () :
427+
def __init__ (self, jdata, descrpt) :
428+
if not isinstance(descrpt, DescrptSeA) :
429+
raise RuntimeError('GlobalPolarFittingSeA only supports DescrptSeA')
430+
self.ntypes = descrpt.get_ntypes()
431+
self.dim_descrpt = descrpt.get_dim_out()
432+
self.polar_fitting = PolarFittingSeA(jdata, descrpt)
433+
434+
def get_sel_type(self):
435+
return self.polar_fitting.get_sel_type()
436+
437+
def get_out_size(self):
438+
return self.polar_fitting.get_out_size()
439+
440+
def build (self,
441+
input_d,
442+
rot_mat,
443+
natoms,
444+
reuse = None,
445+
suffix = '') :
446+
inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
447+
outs = self.polar_fitting.build(input_d, rot_mat, natoms, reuse, suffix)
448+
# nframes x natoms x 9
449+
outs = tf.reshape(outs, [tf.shape(inputs)[0], -1, 9])
450+
outs = tf.reduce_sum(outs, axis = 1)
451+
return tf.reshape(outs, [-1])
452+
453+
426454
class DipoleFittingSeA () :
427455
def __init__ (self, jdata, descrpt) :
428456
if not isinstance(descrpt, DescrptSeA) :

source/train/Loss.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,11 @@ def __init__ (self, jdata, **kwarg) :
189189
self.tensor_name = kwarg['tensor_name']
190190
self.tensor_size = kwarg['tensor_size']
191191
self.label_name = kwarg['label_name']
192+
self.atomic = kwarg.get('atomic', True)
192193
# data required
193194
add_data_requirement(self.label_name,
194195
self.tensor_size,
195-
atomic=True,
196+
atomic=self.atomic,
196197
must=True,
197198
high_prec=False,
198199
type_sel = type_sel)
@@ -206,6 +207,9 @@ def build (self,
206207
polar_hat = label_dict[self.label_name]
207208
polar = model_dict[self.tensor_name]
208209
l2_loss = tf.reduce_mean( tf.square(polar - polar_hat), name='l2_'+suffix)
210+
if not self.atomic :
211+
atom_norm = 1./ global_cvt_2_tf_float(natoms[0])
212+
l2_loss = l2_loss * atom_norm
209213
self.l2_l = l2_loss
210214
more_loss = {}
211215

source/train/Model.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,3 +333,10 @@ def __init__(self, jdata, descrpt, fitting) :
333333
class PolarModel(TensorModel):
334334
def __init__(self, jdata, descrpt, fitting) :
335335
TensorModel.__init__(self, jdata, descrpt, fitting, 'polar')
336+
337+
338+
class GlobalPolarModel(TensorModel):
339+
def __init__(self, jdata, descrpt, fitting) :
340+
TensorModel.__init__(self, jdata, descrpt, fitting, 'global_polar')
341+
342+

source/train/Trainer.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from deepmd.RunOptions import global_ener_float_precision
1313
from deepmd.RunOptions import global_cvt_2_tf_float
1414
from deepmd.RunOptions import global_cvt_2_ener_float
15-
from deepmd.Fitting import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, DipoleFittingSeA
15+
from deepmd.Fitting import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, GlobalPolarFittingSeA, DipoleFittingSeA
1616
from deepmd.DescrptLocFrame import DescrptLocFrame
1717
from deepmd.DescrptSeA import DescrptSeA
1818
from deepmd.DescrptSeR import DescrptSeR
1919
from deepmd.DescrptSeAR import DescrptSeAR
20-
from deepmd.Model import Model, WFCModel, DipoleModel, PolarModel
20+
from deepmd.Model import Model, WFCModel, DipoleModel, PolarModel, GlobalPolarModel
2121
from deepmd.Loss import EnerStdLoss, TensorLoss
2222
from deepmd.LearningRate import LearningRateExp
2323

@@ -94,18 +94,23 @@ def _init_param(self, jdata):
9494
self.fitting = EnerFitting(fitting_param, self.descrpt)
9595
elif fitting_type == 'wfc':
9696
self.fitting = WFCFitting(fitting_param, self.descrpt)
97+
elif fitting_type == 'dipole':
98+
if descrpt_type == 'se_a':
99+
self.fitting = DipoleFittingSeA(fitting_param, self.descrpt)
100+
else :
101+
raise RuntimeError('fitting dipole only supports descrptors: se_a')
97102
elif fitting_type == 'polar':
98103
if descrpt_type == 'loc_frame':
99104
self.fitting = PolarFittingLocFrame(fitting_param, self.descrpt)
100105
elif descrpt_type == 'se_a':
101106
self.fitting = PolarFittingSeA(fitting_param, self.descrpt)
102107
else :
103108
raise RuntimeError('fitting polar only supports descrptors: loc_frame and se_a')
104-
elif fitting_type == 'dipole':
109+
elif fitting_type == 'global_polar':
105110
if descrpt_type == 'se_a':
106-
self.fitting = DipoleFittingSeA(fitting_param, self.descrpt)
111+
self.fitting = GlobalPolarFittingSeA(fitting_param, self.descrpt)
107112
else :
108-
raise RuntimeError('fitting dipole only supports descrptors: se_a')
113+
raise RuntimeError('fitting global_polar only supports descrptors: loc_frame and se_a')
109114
else :
110115
raise RuntimeError('unknow fitting type ' + fitting_type)
111116

@@ -115,10 +120,12 @@ def _init_param(self, jdata):
115120
self.model = Model(model_param, self.descrpt, self.fitting)
116121
elif fitting_type == 'wfc':
117122
self.model = WFCModel(model_param, self.descrpt, self.fitting)
118-
elif fitting_type == 'polar':
119-
self.model = PolarModel(model_param, self.descrpt, self.fitting)
120123
elif fitting_type == 'dipole':
121124
self.model = DipoleModel(model_param, self.descrpt, self.fitting)
125+
elif fitting_type == 'polar':
126+
self.model = PolarModel(model_param, self.descrpt, self.fitting)
127+
elif fitting_type == 'global_polar':
128+
self.model = GlobalPolarModel(model_param, self.descrpt, self.fitting)
122129
else :
123130
raise RuntimeError('get unknown fitting type when building model')
124131

@@ -159,6 +166,13 @@ def _init_param(self, jdata):
159166
tensor_name = 'polar',
160167
tensor_size = 9,
161168
label_name = 'polarizability')
169+
elif fitting_type == 'global_polar':
170+
self.loss = TensorLoss(loss_param,
171+
model = self.model,
172+
tensor_name = 'global_polar',
173+
tensor_size = 9,
174+
atomic = False,
175+
label_name = 'polarizability')
162176
else :
163177
raise RuntimeError('get unknown fitting type when building loss function')
164178

0 commit comments

Comments
 (0)