Skip to content

Commit 0e793e6

Browse files
author
Han Wang
committed
implement dipole fitting by Linfeng
1 parent 7e6594a commit 0e793e6

File tree

3 files changed

+93
-2
lines changed

3 files changed

+93
-2
lines changed

source/train/Fitting.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,76 @@ def build (self,
421421
count += 1
422422

423423
return tf.reshape(outs, [-1])
424+
425+
426+
class DipoleFittingSeA () :
427+
def __init__ (self, jdata, descrpt) :
428+
if not isinstance(descrpt, DescrptSeA) :
429+
raise RuntimeError('DipoleFittingSeA only supports DescrptSeA')
430+
self.ntypes = descrpt.get_ntypes()
431+
self.dim_descrpt = descrpt.get_dim_out()
432+
args = ClassArg()\
433+
.add('neuron', list, default = [120,120,120], alias = 'n_neuron')\
434+
.add('resnet_dt', bool, default = True)\
435+
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'dipole_type')\
436+
.add('seed', int)
437+
class_data = args.parse(jdata)
438+
self.n_neuron = class_data['neuron']
439+
self.resnet_dt = class_data['resnet_dt']
440+
self.sel_type = class_data['sel_type']
441+
self.seed = class_data['seed']
442+
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
443+
self.dim_rot_mat = self.dim_rot_mat_1 * 3
444+
self.useBN = False
445+
446+
def get_sel_type(self):
447+
return self.sel_type
448+
449+
def build (self,
450+
input_d,
451+
rot_mat,
452+
natoms,
453+
reuse = None,
454+
suffix = '') :
455+
start_index = 0
456+
inputs = tf.reshape(input_d, [-1, self.dim_descrpt * natoms[0]])
457+
rot_mat = tf.reshape(rot_mat, [-1, self.dim_rot_mat * natoms[0]])
458+
shape = inputs.get_shape().as_list()
459+
460+
count = 0
461+
for type_i in range(self.ntypes):
462+
# cut-out inputs
463+
inputs_i = tf.slice (inputs,
464+
[ 0, start_index* self.dim_descrpt],
465+
[-1, natoms[2+type_i]* self.dim_descrpt] )
466+
inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt])
467+
rot_mat_i = tf.slice (rot_mat,
468+
[ 0, start_index* self.dim_rot_mat],
469+
[-1, natoms[2+type_i]* self.dim_rot_mat] )
470+
rot_mat_i = tf.reshape(rot_mat_i, [-1, self.dim_rot_mat_1, 3])
471+
start_index += natoms[2+type_i]
472+
if not type_i in self.sel_type :
473+
continue
474+
layer = inputs_i
475+
for ii in range(0,len(self.n_neuron)) :
476+
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
477+
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt)
478+
else :
479+
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
480+
# (nframes x natoms) x naxis
481+
final_layer = one_layer(layer, self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
482+
# (nframes x natoms) x 1 * naxis
483+
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], 1, self.dim_rot_mat_1])
484+
# (nframes x natoms) x 1 x 3(coord)
485+
final_layer = tf.matmul(final_layer, rot_mat_i)
486+
# nframes x natoms x 3
487+
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[2+type_i], 3])
488+
489+
# concat the results
490+
if count == 0:
491+
outs = final_layer
492+
else:
493+
outs = tf.concat([outs, final_layer], axis = 1)
494+
count += 1
495+
496+
return tf.reshape(outs, [-1])

source/train/Model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,11 @@ def __init__(self, jdata, descrpt, fitting) :
325325
TensorModel.__init__(self, jdata, descrpt, fitting, 'wfc')
326326

327327

328+
class DipoleModel(TensorModel):
329+
def __init__(self, jdata, descrpt, fitting) :
330+
TensorModel.__init__(self, jdata, descrpt, fitting, 'dipole')
331+
332+
328333
class PolarModel(TensorModel):
329334
def __init__(self, jdata, descrpt, fitting) :
330335
TensorModel.__init__(self, jdata, descrpt, fitting, 'polar')

source/train/Trainer.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
from deepmd.RunOptions import global_ener_float_precision
1313
from deepmd.RunOptions import global_cvt_2_tf_float
1414
from deepmd.RunOptions import global_cvt_2_ener_float
15-
from deepmd.Fitting import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA
15+
from deepmd.Fitting import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, DipoleFittingSeA
1616
from deepmd.DescrptLocFrame import DescrptLocFrame
1717
from deepmd.DescrptSeA import DescrptSeA
1818
from deepmd.DescrptSeR import DescrptSeR
1919
from deepmd.DescrptSeAR import DescrptSeAR
20-
from deepmd.Model import Model, WFCModel, PolarModel
20+
from deepmd.Model import Model, WFCModel, DipoleModel, PolarModel
2121
from deepmd.Loss import EnerStdLoss, TensorLoss
2222
from deepmd.LearningRate import LearningRateExp
2323

@@ -101,6 +101,11 @@ def _init_param(self, jdata):
101101
self.fitting = PolarFittingSeA(fitting_param, self.descrpt)
102102
else :
103103
raise RuntimeError('fitting polar only supports descrptors: loc_frame and se_a')
104+
elif fitting_type == 'dipole':
105+
if descrpt_type == 'se_a':
106+
self.fitting = DipoleFittingSeA(fitting_param, self.descrpt)
107+
else :
108+
raise RuntimeError('fitting dipole only supports descrptors: se_a')
104109
else :
105110
raise RuntimeError('unknow fitting type ' + fitting_type)
106111

@@ -112,6 +117,8 @@ def _init_param(self, jdata):
112117
self.model = WFCModel(model_param, self.descrpt, self.fitting)
113118
elif fitting_type == 'polar':
114119
self.model = PolarModel(model_param, self.descrpt, self.fitting)
120+
elif fitting_type == 'dipole':
121+
self.model = DipoleModel(model_param, self.descrpt, self.fitting)
115122
else :
116123
raise RuntimeError('get unknown fitting type when building model')
117124

@@ -140,6 +147,12 @@ def _init_param(self, jdata):
140147
tensor_name = 'wfc',
141148
tensor_size = self.model.get_out_size(),
142149
label_name = 'wfc')
150+
elif fitting_type == 'dipole':
151+
self.loss = TensorLoss(loss_param,
152+
model = self.model,
153+
tensor_name = 'dipole',
154+
tensor_size = 3,
155+
label_name = 'dipole')
143156
elif fitting_type == 'polar':
144157
self.loss = TensorLoss(loss_param,
145158
model = self.model,

0 commit comments

Comments
 (0)