Skip to content

Commit 7f8afdd

Browse files
authored
Merge pull request #115 from amcadmus/devel
polarizability: add option to only fit the diag part
2 parents ca61f69 + 9a61ff9 commit 7f8afdd

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

examples/water/train/polar_se_a.json

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
},
1818
"fitting_net": {
1919
"type": "polar",
20-
"pol_type": [0],
20+
"sel_type": [0],
21+
"fit_diag": true,
2122
"neuron": [100, 100, 100],
2223
"resnet_dt": true,
2324
"seed": 1,
@@ -28,7 +29,7 @@
2829

2930
"learning_rate" :{
3031
"type": "exp",
31-
"start_lr": 0.001,
32+
"start_lr": 0.01,
3233
"decay_steps": 5000,
3334
"decay_rate": 0.95,
3435
"_comment": "that's all"

source/tests/polar_se_a.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"fitting_net": {
1919
"type": "polar",
2020
"pol_type": [0],
21+
"fit_diag": false,
2122
"neuron": [100, 100, 100],
2223
"resnet_dt": true,
2324
"seed": 1,

source/train/Fitting.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -352,12 +352,14 @@ def __init__ (self, jdata, descrpt) :
352352
args = ClassArg()\
353353
.add('neuron', list, default = [120,120,120], alias = 'n_neuron')\
354354
.add('resnet_dt', bool, default = True)\
355+
.add('fit_diag', bool, default = True)\
355356
.add('sel_type', [list,int], default = [ii for ii in range(self.ntypes)], alias = 'pol_type')\
356357
.add('seed', int)
357358
class_data = args.parse(jdata)
358359
self.n_neuron = class_data['neuron']
359360
self.resnet_dt = class_data['resnet_dt']
360361
self.sel_type = class_data['sel_type']
362+
self.fit_diag = class_data['fit_diag']
361363
self.seed = class_data['seed']
362364
self.dim_rot_mat_1 = descrpt.get_dim_rot_mat_1()
363365
self.dim_rot_mat = self.dim_rot_mat_1 * 3
@@ -400,12 +402,20 @@ def build (self,
400402
layer+= one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed, use_timestep = self.resnet_dt)
401403
else :
402404
layer = one_layer(layer, self.n_neuron[ii], name='layer_'+str(ii)+'_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
403-
# (nframes x natoms) x (naxis x naxis)
404-
final_layer = one_layer(layer, self.dim_rot_mat_1*self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
405-
# (nframes x natoms) x naxis x naxis
406-
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], self.dim_rot_mat_1, self.dim_rot_mat_1])
407-
# (nframes x natoms) x naxis x naxis
408-
final_layer = final_layer + tf.transpose(final_layer, perm = [0,2,1])
405+
if self.fit_diag :
406+
# (nframes x natoms) x naxis
407+
final_layer = one_layer(layer, self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
408+
# (nframes x natoms) x naxis
409+
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], self.dim_rot_mat_1])
410+
# (nframes x natoms) x naxis x naxis
411+
final_layer = tf.matrix_diag(final_layer)
412+
else :
413+
# (nframes x natoms) x (naxis x naxis)
414+
final_layer = one_layer(layer, self.dim_rot_mat_1*self.dim_rot_mat_1, activation_fn = None, name='final_layer_type_'+str(type_i)+suffix, reuse=reuse, seed = self.seed)
415+
# (nframes x natoms) x naxis x naxis
416+
final_layer = tf.reshape(final_layer, [tf.shape(inputs)[0] * natoms[2+type_i], self.dim_rot_mat_1, self.dim_rot_mat_1])
417+
# (nframes x natoms) x naxis x naxis
418+
final_layer = final_layer + tf.transpose(final_layer, perm = [0,2,1])
409419
# (nframes x natoms) x naxis x 3(coord)
410420
final_layer = tf.matmul(final_layer, rot_mat_i)
411421
# (nframes x natoms) x 3(coord) x 3(coord)

0 commit comments

Comments
 (0)