@@ -352,12 +352,14 @@ def __init__ (self, jdata, descrpt) :
352352 args = ClassArg ()\
353353 .add ('neuron' , list , default = [120 ,120 ,120 ], alias = 'n_neuron' )\
354354 .add ('resnet_dt' , bool , default = True )\
355+ .add ('fit_diag' , bool , default = True )\
355356 .add ('sel_type' , [list ,int ], default = [ii for ii in range (self .ntypes )], alias = 'pol_type' )\
356357 .add ('seed' , int )
357358 class_data = args .parse (jdata )
358359 self .n_neuron = class_data ['neuron' ]
359360 self .resnet_dt = class_data ['resnet_dt' ]
360361 self .sel_type = class_data ['sel_type' ]
362+ self .fit_diag = class_data ['fit_diag' ]
361363 self .seed = class_data ['seed' ]
362364 self .dim_rot_mat_1 = descrpt .get_dim_rot_mat_1 ()
363365 self .dim_rot_mat = self .dim_rot_mat_1 * 3
@@ -400,12 +402,20 @@ def build (self,
400402 layer += one_layer (layer , self .n_neuron [ii ], name = 'layer_' + str (ii )+ '_type_' + str (type_i )+ suffix , reuse = reuse , seed = self .seed , use_timestep = self .resnet_dt )
401403 else :
402404 layer = one_layer (layer , self .n_neuron [ii ], name = 'layer_' + str (ii )+ '_type_' + str (type_i )+ suffix , reuse = reuse , seed = self .seed )
403- # (nframes x natoms) x (naxis x naxis)
404- final_layer = one_layer (layer , self .dim_rot_mat_1 * self .dim_rot_mat_1 , activation_fn = None , name = 'final_layer_type_' + str (type_i )+ suffix , reuse = reuse , seed = self .seed )
405- # (nframes x natoms) x naxis x naxis
406- final_layer = tf .reshape (final_layer , [tf .shape (inputs )[0 ] * natoms [2 + type_i ], self .dim_rot_mat_1 , self .dim_rot_mat_1 ])
407- # (nframes x natoms) x naxis x naxis
408- final_layer = final_layer + tf .transpose (final_layer , perm = [0 ,2 ,1 ])
405+ if self .fit_diag :
406+ # (nframes x natoms) x naxis
407+ final_layer = one_layer (layer , self .dim_rot_mat_1 , activation_fn = None , name = 'final_layer_type_' + str (type_i )+ suffix , reuse = reuse , seed = self .seed )
408+ # (nframes x natoms) x naxis
409+ final_layer = tf .reshape (final_layer , [tf .shape (inputs )[0 ] * natoms [2 + type_i ], self .dim_rot_mat_1 ])
410+ # (nframes x natoms) x naxis x naxis
411+ final_layer = tf .matrix_diag (final_layer )
412+ else :
413+ # (nframes x natoms) x (naxis x naxis)
414+ final_layer = one_layer (layer , self .dim_rot_mat_1 * self .dim_rot_mat_1 , activation_fn = None , name = 'final_layer_type_' + str (type_i )+ suffix , reuse = reuse , seed = self .seed )
415+ # (nframes x natoms) x naxis x naxis
416+ final_layer = tf .reshape (final_layer , [tf .shape (inputs )[0 ] * natoms [2 + type_i ], self .dim_rot_mat_1 , self .dim_rot_mat_1 ])
417+ # (nframes x natoms) x naxis x naxis
418+ final_layer = final_layer + tf .transpose (final_layer , perm = [0 ,2 ,1 ])
409419 # (nframes x natoms) x naxis x 3(coord)
410420 final_layer = tf .matmul (final_layer , rot_mat_i )
411421 # (nframes x natoms) x 3(coord) x 3(coord)
0 commit comments