@@ -823,33 +823,21 @@ def init_out_stat(self, suffix: str = "") -> None:
823823 """Initialize the output bias and std variables."""
824824 ntypes = self .get_ntypes ()
825825
826- # Get output dimension from fitting serialization, with fallback
827- try :
828- dict_fit = self .fitting .serialize (suffix = suffix )
829- dim_out = dict_fit .get ("dim_out" , 1 )
830- except (AttributeError , TypeError ):
831- # Fallback to default dimensions for different fitting types
832- from deepmd .tf .fit .dipole import (
833- DipoleFittingSeA ,
834- )
835- from deepmd .tf .fit .dos import (
836- DOSFitting ,
837- )
838- from deepmd .tf .fit .ener import (
839- EnerFitting ,
840- )
841- from deepmd .tf .fit .polar import (
842- PolarFittingSeA ,
843- )
844-
845- if isinstance (self .fitting , EnerFitting ):
846- dim_out = 1
847- elif isinstance (self .fitting , (DipoleFittingSeA , PolarFittingSeA )):
848- dim_out = 3
849- elif isinstance (self .fitting , DOSFitting ):
850- dim_out = getattr (self .fitting , "numb_dos" , 1 )
851- else :
852- dim_out = 1
826+ # Determine output dimension based on model type instead of fitting type
827+ if hasattr (self , "model_type" ):
828+ model_type = self .model_type
829+ else :
830+ # Fallback to fitting type for compatibility
831+ model_type = getattr (self .fitting , "model_type" , "ener" )
832+
833+ if model_type == "ener" :
834+ dim_out = 1
835+ elif model_type in ["dipole" , "polar" ]:
836+ dim_out = 3
837+ elif model_type == "dos" :
838+ dim_out = getattr (self .fitting , "numb_dos" , 1 )
839+ else :
840+ dim_out = 1
853841
854842 # Initialize out_bias and out_std as numpy arrays, preserving existing values if set
855843 if hasattr (self , "out_bias" ) and self .out_bias is not None :
@@ -887,31 +875,57 @@ def init_out_stat(self, suffix: str = "") -> None:
887875 self .out_bias = out_bias_data
888876 self .out_std = out_std_data
889877
890- def get_out_bias (self ) -> np .ndarray :
891- """Get the output bias."""
892- return self .out_bias
893-
894- def get_out_std (self ) -> np .ndarray :
895- """Get the output standard deviation."""
896- return self .out_std
897-
898- def set_out_bias (self , out_bias : np .ndarray ) -> None :
899- """Set the output bias."""
900- self .out_bias = out_bias
901- if hasattr (self , "t_out_bias" ):
902- # Note: TensorFlow variable assignment would require a session context in TF 1.x
903- # For TF 2.x, the variable assignment happens differently
904- # Here we just update the numpy array, and TF variables are updated when rebuilt
905- pass
878+ def _apply_out_bias_std (self , output , atype , natoms , coord , selected_atype = None ):
879+ """Apply output bias and standard deviation to the model output.
906880
907- def set_out_std (self , out_std : np .ndarray ) -> None :
908- """Set the output standard deviation."""
909- self .out_std = out_std
910- if hasattr (self , "t_out_std" ):
911- # Note: TensorFlow variable assignment would require a session context in TF 1.x
912- # For TF 2.x, the variable assignment happens differently
913- # Here we just update the numpy array, and TF variables are updated when rebuilt
914- pass
881+ Parameters
882+ ----------
883+ output : tf.Tensor
884+ The model output tensor
885+ atype : tf.Tensor
886+ Atom types with shape [nframes, nloc]
887+ natoms : list[int]
888+ Number of atoms [nloc, ntypes, ...]
889+ coord : tf.Tensor
890+ Coordinates for getting nframes
891+ selected_atype : tf.Tensor, optional
892+ Selected atom types for tensor models. If None, uses all atoms.
893+
894+ Returns
895+ -------
896+ tf.Tensor
897+ Output with bias and std applied
898+ """
899+ nframes = tf .shape (coord )[0 ]
900+
901+ if selected_atype is not None :
902+ # For tensor models (dipole, polar) with selected atoms
903+ natomsel = tf .shape (selected_atype )[1 ]
904+ nout = self .get_out_size () # Use the model's output size method
905+ output_reshaped = tf .reshape (output , [nframes , natomsel , nout ])
906+ atype_for_gather = selected_atype
907+ else :
908+ # For energy and DOS models with all atoms
909+ nloc = natoms [0 ]
910+ if hasattr (self , "numb_dos" ):
911+ # DOS model: output shape [nframes * nloc * numb_dos]
912+ nout = self .numb_dos
913+ output_reshaped = tf .reshape (output , [nframes , nloc , nout ])
914+ else :
915+ # Energy model: output shape [nframes * nloc]
916+ nout = 1
917+ output_reshaped = tf .reshape (output , [nframes , nloc , 1 ])
918+ atype_for_gather = tf .reshape (atype , [nframes , nloc ])
919+
920+ # Get bias and std for each atom type
921+ bias_per_atom = tf .gather (self .t_out_bias [0 ], atype_for_gather )
922+ std_per_atom = tf .gather (self .t_out_std [0 ], atype_for_gather )
923+
924+ # Apply bias and std: output = output * std + bias
925+ output_reshaped = output_reshaped * std_per_atom + bias_per_atom
926+
927+ # Reshape back to original shape
928+ return tf .reshape (output_reshaped , tf .shape (output ))
915929
916930 @classmethod
917931 def update_sel (
0 commit comments