Skip to content

Commit 306d40b

Browse files
authored
Merge pull request #180 from GeiduanLiu/devel
user specified activation function for embedding and fitting nets
2 parents fadabc3 + dbe4e97 commit 306d40b

File tree

4 files changed

+48
-22
lines changed

4 files changed

+48
-22
lines changed

source/train/DescrptSeA.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from deepmd.env import tf
3-
from deepmd.common import ClassArg
3+
from deepmd.common import ClassArg, get_activation_func
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
@@ -18,7 +18,8 @@ def __init__ (self, jdata):
1818
.add('trainable',bool, default = True) \
1919
.add('seed', int) \
2020
.add('exclude_types', list, default = []) \
21-
.add('set_davg_zero', bool, default = False)
21+
.add('set_davg_zero', bool, default = False) \
22+
.add('activation_function', str, default = 'tanh')
2223
class_data = args.parse(jdata)
2324
self.sel_a = class_data['sel']
2425
self.rcut_r = class_data['rcut']
@@ -28,6 +29,7 @@ def __init__ (self, jdata):
2829
self.filter_resnet_dt = class_data['resnet_dt']
2930
self.seed = class_data['seed']
3031
self.trainable = class_data['trainable']
32+
self.filter_activation_fn = get_activation_func(class_data['activation_function'])
3133
exclude_types = class_data['exclude_types']
3234
self.exclude_types = set()
3335
for tt in exclude_types:
@@ -245,7 +247,7 @@ def _pass_filter(self,
245247
[ 0, start_index* self.ndescrpt],
246248
[-1, natoms[2+type_i]* self.ndescrpt] )
247249
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
248-
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable)
250+
layer, qmat = self._filter(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
249251
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
250252
qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_rot_mat_1() * 3])
251253
output.append(layer)

source/train/DescrptSeR.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22
from deepmd.env import tf
3-
from deepmd.common import ClassArg
3+
from deepmd.common import ClassArg, get_activation_func
44
from deepmd.RunOptions import global_tf_float_precision
55
from deepmd.RunOptions import global_np_float_precision
66
from deepmd.env import op_module
@@ -17,7 +17,8 @@ def __init__ (self, jdata):
1717
.add('trainable',bool, default = True) \
1818
.add('seed', int) \
1919
.add('exclude_types', list, default = []) \
20-
.add('set_davg_zero', bool, default = False)
20+
.add('set_davg_zero', bool, default = False) \
21+
.add("activation_function", str, default = "tanh")
2122
class_data = args.parse(jdata)
2223
self.sel_r = class_data['sel']
2324
self.rcut = class_data['rcut']
@@ -26,6 +27,7 @@ def __init__ (self, jdata):
2627
self.filter_resnet_dt = class_data['resnet_dt']
2728
self.seed = class_data['seed']
2829
self.trainable = class_data['trainable']
30+
self.filter_activation_fn = get_activation_func(class_data["activation_function"])
2931
exclude_types = class_data['exclude_types']
3032
self.exclude_types = set()
3133
for tt in exclude_types:
@@ -204,7 +206,7 @@ def _pass_filter(self,
204206
[ 0, start_index* self.ndescrpt],
205207
[-1, natoms[2+type_i]* self.ndescrpt] )
206208
inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt])
207-
layer = self._filter_r(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable)
209+
layer = self._filter_r(inputs_i, type_i, name='filter_type_'+str(type_i)+suffix, natoms=natoms, reuse=reuse, seed = self.seed, trainable = trainable, activation_fn = self.filter_activation_fn)
208210
layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[2+type_i] * self.get_dim_out()])
209211
output.append(layer)
210212
start_index += natoms[2+type_i]

source/train/Fitting.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33

44
from deepmd.env import tf
5-
from deepmd.common import ClassArg, add_data_requirement
5+
from deepmd.common import ClassArg, add_data_requirement, get_activation_func
66
from deepmd.Network import one_layer
77
from deepmd.DescrptLocFrame import DescrptLocFrame
88
from deepmd.DescrptSeA import DescrptSeA
@@ -21,14 +21,16 @@ def __init__ (self, jdata, descrpt):
2121
.add('resnet_dt', bool, default = True)\
2222
.add('rcond', float, default = 1e-3) \
2323
.add('seed', int) \
24-
.add('atom_ener', list, default = [])
24+
.add('atom_ener', list, default = [])\
25+
.add("activation_function", str, default = "tanh")
2526
class_data = args.parse(jdata)
2627
self.numb_fparam = class_data['numb_fparam']
2728
self.numb_aparam = class_data['numb_aparam']
2829
self.n_neuron = class_data['neuron']
2930
self.resnet_dt = class_data['resnet_dt']
3031
self.rcond = class_data['rcond']
3132
self.seed = class_data['seed']
33+
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
3234
self.atom_ener = []
3335
for at, ae in enumerate(class_data['atom_ener']):
3436
if ae is not None:
@@ -201,7 +203,7 @@ def build (self,
201203

202204
for ii in range(0,len(self.n_neuron)) :
203205
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
204-
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt)
206+
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn)
205207
else :
206208
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
207209
final_layer = one_layer(layer, 1, activation_fn = None, bavg = type_bias_ae, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
@@ -243,13 +245,15 @@ def __init__ (self, jdata, descrpt) :
243245
.add('resnet_dt', bool, default = True)\
244246
.add('wfc_numb', int, must = True)\
245247
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'wfc_type')\
246-
.add('seed', int)
248+
.add('seed', int)\
249+
.add("activation_function", str, default = "tanh")
247250
class_data = args.parse(jdata)
248251
self.n_neuron = class_data['neuron']
249252
self.resnet_dt = class_data['resnet_dt']
250253
self.wfc_numb = class_data['wfc_numb']
251254
self.sel_type = class_data['sel_type']
252255
self.seed = class_data['seed']
256+
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
253257
self.useBN = False
254258

255259

@@ -289,9 +293,9 @@ def build (self,
289293
layer = inputs_i
290294
for ii in range(0,len(self.n_neuron)) :
291295
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
292-
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt)
296+
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn)
293297
else :
294-
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
298+
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn)
295299
# (nframes x natoms) x (nwfc x 3)
296300
final_layer = one_layer(layer, self.wfc_numb * 3, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
297301
# (nframes x natoms) x nwfc(wc) x 3(coord_local)
@@ -322,12 +326,14 @@ def __init__ (self, jdata, descrpt) :
322326
.add('neuron', list, default = [120,120,120], alias = 'n_neuron')\
323327
.add('resnet_dt', bool, default = True)\
324328
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
325-
.add('seed', int)
329+
.add('seed', int)\
330+
.add("activation_function", str, default = "tanh")
326331
class_data = args.parse(jdata)
327332
self.n_neuron = class_data['neuron']
328333
self.resnet_dt = class_data['resnet_dt']
329334
self.sel_type = class_data['sel_type']
330335
self.seed = class_data['seed']
336+
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
331337
self.useBN = False
332338

333339
def get_sel_type(self):
@@ -363,9 +369,9 @@ def build (self,
363369
layer = inputs_i
364370
for ii in range(0,len(self.n_neuron)) :
365371
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
366-
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt)
372+
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn)
367373
else :
368-
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
374+
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn)
369375
# (nframes x natoms) x 9
370376
final_layer = one_layer(layer, 9, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
371377
# (nframes x natoms) x 3 x 3
@@ -402,7 +408,8 @@ def __init__ (self, jdata, descrpt) :
402408
.add('diag_shift', [list,float], default = [0.0 for ii in range(self.ntypes)])\
403409
.add('scale', [list,float], default = [1.0 for ii in range(self.ntypes)])\
404410
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
405-
.add('seed', int)
411+
.add('seed', int)\
412+
.add("activation_function", str , default = "tanh")
406413
class_data = args.parse(jdata)
407414
self.n_neuron = class_data['neuron']
408415
self.resnet_dt = class_data['resnet_dt']
@@ -411,6 +418,7 @@ def __init__ (self, jdata, descrpt) :
411418
self.seed = class_data['seed']
412419
self.diag_shift = class_data['diag_shift']
413420
self.scale = class_data['scale']
421+
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
414422
if type(self.sel_type) is not list:
415423
self.sel_type = [self.sel_type]
416424
if type(self.diag_shift) is not list:
@@ -471,9 +479,9 @@ def build (self,
471479
layer = inputs_i
472480
for ii in range(0,len(self.n_neuron)) :
473481
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
474-
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt)
482+
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn)
475483
else :
476-
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
484+
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn)
477485
if self.fit_diag :
478486
bavg = np.zeros(self.dim_rot_mat_1)
479487
# bavg[0] = self.avgeig[0]
@@ -555,12 +563,14 @@ def __init__ (self, jdata, descrpt) :
555563
.add('neuron', list, default = [120,120,120], alias = 'n_neuron')\
556564
.add('resnet_dt', bool, default = True)\
557565
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'dipole_type')\
558-
.add('seed', int)
566+
.add('seed', int)\
567+
.add("activation_function", str, default = "tanh")
559568
class_data = args.parse(jdata)
560569
self.n_neuron = class_data['neuron']
561570
self.resnet_dt = class_data['resnet_dt']
562571
self.sel_type = class_data['sel_type']
563572
self.seed = class_data['seed']
573+
self.fitting_activation_fn = get_activation_func(class_data["activation_function"])
564574
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
565575
self.dim_rot_mat = self.dim_rot_mat_1 * 3
566576
self.useBN = False
@@ -598,9 +608,9 @@ def build (self,
598608
layer = inputs_i
599609
for ii in range(0,len(self.n_neuron)) :
600610
if ii >= 1 and self.n_neuron[ii] == self.n_neuron[ii-1] :
601-
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt)
611+
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt, activation_fn = self.fitting_activation_fn)
602612
else :
603-
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
613+
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, activation_fn = self.fitting_activation_fn)
604614
# (nframes x natoms) x naxis
605615
final_layer = one_layer(layer, self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
606616
# (nframes x natoms) x 1 * naxis

source/train/common.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
11
import warnings
22
import numpy as np
3+
from deepmd.env import tf
34

45
data_requirement = {}
5-
6+
# Registry of the activation functions that can be requested by name.
# get_activation_func() resolves the string given for the
# "activation_function" option against this table.
activation_fn_dict = dict(
    relu=tf.nn.relu,
    relu6=tf.nn.relu6,
    softplus=tf.nn.softplus,
    sigmoid=tf.sigmoid,
    tanh=tf.nn.tanh,
)
613
def add_data_requirement(key,
714
ndof,
815
atomic = False,
@@ -139,4 +146,9 @@ def j_must_have_d (jdata, key, deprecated_key) :
139146

140147
def j_have (jdata, key) :
    """Return True if `key` is present in `jdata`, False otherwise.

    `jdata` is expected to be a dict-like container; membership is tested
    directly on it rather than on a separate `keys()` view.
    """
    return key in jdata
149+
150+
def get_activation_func(activation_fn):
    """Return the activation function registered under `activation_fn`.

    The name is looked up in the module-level `activation_fn_dict`.

    Raises
    ------
    RuntimeError
        If `activation_fn` is not a key of `activation_fn_dict` (this
        includes values that cannot be used as a dictionary key at all).
    """
    try:
        return activation_fn_dict[activation_fn]
    except (KeyError, TypeError):
        # str() keeps the message identical for string names while making
        # sure a non-string argument (e.g. None) still yields the intended
        # RuntimeError instead of a TypeError from string concatenation.
        raise RuntimeError(str(activation_fn)+" is not a valid activation function")
142154

0 commit comments

Comments
 (0)