@@ -383,6 +383,8 @@ def __init__ (self, jdata, descrpt) :
383383 .add ('neuron' , list , default = [120 ,120 ,120 ], alias = 'n_neuron' )\
384384 .add ('resnet_dt' , bool , default = True )\
385385 .add ('fit_diag' , bool , default = True )\
386+ .add ('diag_shift' , [list ,float ], default = [0.0 for ii in range (self .ntypes )])\
387+ .add ('scale' , [list ,float ], default = [1.0 for ii in range (self .ntypes )])\
386388 .add ('sel_type' , [list ,int ], default = [ii for ii in range (self .ntypes )], alias = 'pol_type' )\
387389 .add ('seed' , int )
388390 class_data = args .parse (jdata )
@@ -391,6 +393,8 @@ def __init__ (self, jdata, descrpt) :
391393 self .sel_type = class_data ['sel_type' ]
392394 self .fit_diag = class_data ['fit_diag' ]
393395 self .seed = class_data ['seed' ]
396+ self .diag_shift = class_data ['diag_shift' ]
397+ self .scale = class_data ['scale' ]
394398 self .dim_rot_mat_1 = descrpt .get_dim_rot_mat_1 ()
395399 self .dim_rot_mat = self .dim_rot_mat_1 * 3
396400 self .useBN = False
@@ -477,6 +481,10 @@ def build (self,
477481 final_layer = tf .matmul (rot_mat_i , final_layer , transpose_a = True )
478482 # nframes x natoms x 3 x 3
479483 final_layer = tf .reshape (final_layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ], 3 , 3 ])
484+ # shift and scale
485+ sel_type_idx = self .sel_type .index (type_i )
486+ final_layer = final_layer * self .scale [sel_type_idx ]
487+ final_layer = final_layer + self .diag_shift [sel_type_idx ] * tf .eye (3 , batch_shape = [tf .shape (inputs )[0 ], natoms [2 + type_i ]], dtype = global_tf_float_precision )
480488
481489 # concat the results
482490 if count == 0 :
0 commit comments