@@ -383,6 +383,8 @@ def __init__ (self, jdata, descrpt) :
383383 .add ('neuron' , list , default = [120 ,120 ,120 ], alias = 'n_neuron' )\
384384 .add ('resnet_dt' , bool , default = True )\
385385 .add ('fit_diag' , bool , default = True )\
386+ .add ('diag_shift' , [list ,float ], default = [0.0 for ii in range (self .ntypes )])\
387+ .add ('scale' , [list ,float ], default = [1.0 for ii in range (self .ntypes )])\
386388 .add ('sel_type' , [list ,int ], default = [ii for ii in range (self .ntypes )], alias = 'pol_type' )\
387389 .add ('seed' , int )
388390 class_data = args .parse (jdata )
@@ -391,6 +393,14 @@ def __init__ (self, jdata, descrpt) :
391393 self .sel_type = class_data ['sel_type' ]
392394 self .fit_diag = class_data ['fit_diag' ]
393395 self .seed = class_data ['seed' ]
396+ self .diag_shift = class_data ['diag_shift' ]
397+ self .scale = class_data ['scale' ]
398+ if type (self .sel_type ) is not list :
399+ self .sel_type = [self .sel_type ]
400+ if type (self .diag_shift ) is not list :
401+ self .diag_shift = [self .diag_shift ]
402+ if type (self .scale ) is not list :
403+ self .scale = [self .scale ]
394404 self .dim_rot_mat_1 = descrpt .get_dim_rot_mat_1 ()
395405 self .dim_rot_mat = self .dim_rot_mat_1 * 3
396406 self .useBN = False
@@ -477,6 +487,10 @@ def build (self,
477487 final_layer = tf .matmul (rot_mat_i , final_layer , transpose_a = True )
478488 # nframes x natoms x 3 x 3
479489 final_layer = tf .reshape (final_layer , [tf .shape (inputs )[0 ], natoms [2 + type_i ], 3 , 3 ])
490+ # shift and scale
491+ sel_type_idx = self .sel_type .index (type_i )
492+ final_layer = final_layer * self .scale [sel_type_idx ]
493+ final_layer = final_layer + self .diag_shift [sel_type_idx ] * tf .eye (3 , batch_shape = [tf .shape (inputs )[0 ], natoms [2 + type_i ]], dtype = global_tf_float_precision )
480494
481495 # concat the results
482496 if count == 0 :
0 commit comments