@@ -851,9 +851,20 @@ def init_out_stat(self, suffix: str = "") -> None:
         else:
             dim_out = 1
 
-        # Initialize out_bias and out_std as numpy arrays first
-        out_bias_data = np.zeros([1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION)
-        out_std_data = np.ones([1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION)
+        # Initialize out_bias and out_std as numpy arrays, preserving existing values if set
+        if hasattr(self, "out_bias") and self.out_bias is not None:
+            out_bias_data = self.out_bias.copy()
+        else:
+            out_bias_data = np.zeros(
+                [1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION
+            )
+
+        if hasattr(self, "out_std") and self.out_std is not None:
+            out_std_data = self.out_std.copy()
+        else:
+            out_std_data = np.ones(
+                [1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION
+            )
 
         # Create TensorFlow variables
         with tf.variable_scope("model_attr" + suffix, reuse=tf.AUTO_REUSE):
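For context, here is a standalone sketch of the preserve-or-initialize pattern this hunk introduces. `StatHolder` and its single-field `init_out_stat` are hypothetical stand-ins for illustration, not the actual model class or its signature:

import numpy as np

GLOBAL_NP_FLOAT_PRECISION = np.float64  # assumption: mirrors the constant used in the diff


class StatHolder:
    """Hypothetical stand-in for the model class, not the real DeePMD-kit API."""

    def init_out_stat(self, ntypes: int, dim_out: int) -> None:
        # Preserve an existing bias (e.g. one restored from a checkpoint);
        # otherwise fall back to zeros, mirroring the hunk above.
        if getattr(self, "out_bias", None) is not None:
            out_bias_data = self.out_bias.copy()
        else:
            out_bias_data = np.zeros(
                [1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION
            )
        self.out_bias = out_bias_data


holder = StatHolder()
holder.out_bias = np.full([1, 2, 1], 0.5)
holder.init_out_stat(ntypes=2, dim_out=1)
assert holder.out_bias[0, 0, 0] == 0.5  # the preset value survives re-initialization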
@@ -960,43 +971,7 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
         data = data.copy()
         check_version_compatibility(data.pop("@version", 2), 2, 1)
         descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix)
-        if data["fitting"].get("@variables", {}).get("bias_atom_e") is not None:
-            # careful: copy each level and don't modify the input array,
-            # otherwise it will affect the original data
-            # deepcopy is not used for performance reasons
-            data["fitting"] = data["fitting"].copy()
-            data["fitting"]["@variables"] = data["fitting"]["@variables"].copy()
-            if (
-                int(np.any(data["fitting"]["@variables"]["bias_atom_e"]))
-                + int(np.any(data["@variables"]["out_bias"]))
-                > 1
-            ):
-                raise ValueError(
-                    "fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero"
-                )
-            # Improved handling for different shapes (dipole/polar vs energy)
-            bias_atom_e_shape = data["fitting"]["@variables"]["bias_atom_e"].shape
-            out_bias_data = data["@variables"]["out_bias"]
-
-            # For dipole/polar models, out_bias has shape [1, ntypes, 3]
-            # but bias_atom_e has shape [ntypes] where embedding_width might != 3
-            if len(bias_atom_e_shape) == 1 and len(out_bias_data.shape) == 3:
-                # Convert out_bias to bias_atom_e shape safely
-                # We sum over the output dimensions for energy-like models
-                if out_bias_data.shape[2] == 1:
-                    # Energy case: out_bias [1, ntypes, 1] -> bias_atom_e [ntypes]
-                    bias_increment = out_bias_data[0, :, 0]
-                else:
-                    # Dipole/Polar case: take norm or sum for compatibility
-                    # This is still a workaround, but safer than reshape
-                    bias_increment = np.linalg.norm(out_bias_data[0], axis=-1)
-            else:
-                # Fallback to original reshape if shapes are compatible
-                bias_increment = out_bias_data.reshape(bias_atom_e_shape)
-
-            data["fitting"]["@variables"]["bias_atom_e"] = (
-                data["fitting"]["@variables"]["bias_atom_e"] + bias_increment
-            )
+        # bias_atom_e and out_bias are now completely independent - no conversion needed
         fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
         # pass descriptor type embedding to model
         if descriptor.explicit_ntypes:
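To illustrate the new pass-through behavior, a minimal sketch with a hypothetical payload dictionary rather than a real serialized model:

import numpy as np

# Hypothetical serialized payload with both bias sources present.
data = {
    "fitting": {"@variables": {"bias_atom_e": np.array([1.0, 2.0])}},
    "@variables": {"out_bias": np.zeros([1, 2, 1])},
}

# The removed branch would have folded out_bias into bias_atom_e (and raised
# if both were non-zero). With this change, the fitting payload reaches
# Fitting.deserialize untouched, so bias_atom_e keeps its serialized value.
bias_after = data["fitting"]["@variables"]["bias_atom_e"]
assert np.array_equal(bias_after, np.array([1.0, 2.0]))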
@@ -1029,57 +1004,6 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
         model.out_std = out_std
         return model
 
-    def apply_out_stat(
-        self,
-        ret: dict[str, np.ndarray],
-        atype: np.ndarray,
-    ) -> dict[str, np.ndarray]:
-        """Apply the bias and std to the atomic output.
-
-        Parameters
-        ----------
-        ret : dict[str, np.ndarray]
-            The returned dict by the forward_atomic method
-        atype : np.ndarray
-            The atom types. nf x nloc
-
-        Returns
-        -------
-        dict[str, np.ndarray]
-            The output with bias and std applied
-        """
-        if self.out_bias is None:
-            return ret
-
-        # Get the output keys that need bias/std applied
-        fitting_output_def = (
-            self.fitting.fitting_output_def()
-            if hasattr(self.fitting, "fitting_output_def")
-            else {}
-        )
-
-        # Apply bias for each output
-        for kk in ret.keys():
-            if kk in ["mask"]:  # Skip mask
-                continue
-
-            # Get the corresponding bias and std
-            # For now, we assume single output (idx=0), which works for most cases
-            bias_idx = 0
-            ntypes = self.get_ntypes()
-
-            if self.out_bias.shape[0] > bias_idx:
-                # Extract bias for this output: shape [ntypes, output_dim]
-                out_bias_kk = self.out_bias[bias_idx]  # [ntypes, output_dim]
-
-                # Apply bias: ret[kk] shape is [nframes, nloc, output_dim]
-                # atype shape is [nframes, nloc]
-                # We need to index out_bias_kk with atype to get [nframes, nloc, output_dim]
-                bias_for_atoms = out_bias_kk[atype]  # [nframes, nloc, output_dim]
-                ret[kk] = ret[kk] + bias_for_atoms
-
-        return ret
-
     def serialize(self, suffix: str = "") -> dict:
         """Serialize the model.
 
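The deleted override applied out_bias by fancy-indexing the per-type bias with atype, presumably leaving any bias application to the inherited implementation after this change. A standalone NumPy sketch of that indexing trick, with made-up shapes:

import numpy as np

nframes, nloc, ntypes, dim_out = 2, 3, 2, 1
out_bias = np.arange(ntypes * dim_out, dtype=float).reshape(1, ntypes, dim_out)
atype = np.array([[0, 1, 1], [1, 0, 0]])  # nframes x nloc
ret = np.zeros([nframes, nloc, dim_out])  # atomic output from forward_atomic

# out_bias[0] has shape [ntypes, dim_out]; indexing it with atype broadcasts
# to [nframes, nloc, dim_out], giving every atom the bias of its type.
ret = ret + out_bias[0][atype]
assert ret.shape == (nframes, nloc, dim_out)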