Skip to content

Commit 0953edf

Browse files
Copilotnjzjz
andcommitted
refactor(tf): consolidate out_bias/out_std application into shared method
Co-authored-by: njzjz <[email protected]>
1 parent 93cd873 commit 0953edf

File tree

4 files changed

+71
-119
lines changed

4 files changed

+71
-119
lines changed

deepmd/tf/model/dos.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -186,28 +186,7 @@ def build(
186186
)
187187

188188
# Apply out_bias and out_std directly to DOS output
189-
# atom_dos shape: [nframes * nloc * numb_dos] for DOS models
190-
# t_out_bias shape: [1, ntypes, numb_dos], t_out_std shape: [1, ntypes, numb_dos]
191-
if hasattr(self, "t_out_bias") and hasattr(self, "t_out_std"):
192-
nframes = tf.shape(coord)[0]
193-
nloc = natoms[0]
194-
# Reshape atom_dos to [nframes, nloc, numb_dos] for bias/std application
195-
atom_dos_reshaped = tf.reshape(atom_dos, [nframes, nloc, self.numb_dos])
196-
197-
# Get bias and std for each atom type: [nframes, nloc, numb_dos]
198-
atype_flat = tf.reshape(atype, [nframes, nloc])
199-
bias_per_atom = tf.gather(
200-
self.t_out_bias[0], atype_flat
201-
) # [nframes, nloc, numb_dos]
202-
std_per_atom = tf.gather(
203-
self.t_out_std[0], atype_flat
204-
) # [nframes, nloc, numb_dos]
205-
206-
# Apply bias and std: dos = dos * std + bias
207-
atom_dos_reshaped = atom_dos_reshaped * std_per_atom + bias_per_atom
208-
209-
# Reshape back to original shape
210-
atom_dos = tf.reshape(atom_dos_reshaped, tf.shape(atom_dos))
189+
atom_dos = self._apply_out_bias_std(atom_dos, atype, natoms, coord)
211190

212191
self.atom_dos = atom_dos
213192

deepmd/tf/model/ener.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -258,29 +258,7 @@ def build(
258258
)
259259

260260
# Apply out_bias and out_std directly to atom energy
261-
# atom_ener shape: [nframes * nloc] (for energy models, dim_out=1)
262-
# t_out_bias shape: [1, ntypes, 1], t_out_std shape: [1, ntypes, 1]
263-
# atype shape: [nframes, nloc]
264-
if hasattr(self, "t_out_bias") and hasattr(self, "t_out_std"):
265-
# Reshape atom_ener to [nframes, nloc, 1] to match bias/std application
266-
nframes = tf.shape(coord)[0]
267-
nloc = natoms[0]
268-
atom_ener_reshaped = tf.reshape(atom_ener, [nframes, nloc, 1])
269-
270-
# Get bias and std for each atom type: [nframes, nloc, 1]
271-
atype_flat = tf.reshape(atype, [nframes, nloc])
272-
bias_per_atom = tf.gather(
273-
self.t_out_bias[0], atype_flat
274-
) # [nframes, nloc, 1]
275-
std_per_atom = tf.gather(
276-
self.t_out_std[0], atype_flat
277-
) # [nframes, nloc, 1]
278-
279-
# Apply bias and std: energy = energy * std + bias
280-
atom_ener_reshaped = atom_ener_reshaped * std_per_atom + bias_per_atom
281-
282-
# Reshape back to original shape
283-
atom_ener = tf.reshape(atom_ener_reshaped, tf.shape(atom_ener))
261+
atom_ener = self._apply_out_bias_std(atom_ener, atype, natoms, coord)
284262

285263
self.atom_ener = atom_ener
286264

deepmd/tf/model/model.py

Lines changed: 65 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -823,33 +823,21 @@ def init_out_stat(self, suffix: str = "") -> None:
823823
"""Initialize the output bias and std variables."""
824824
ntypes = self.get_ntypes()
825825

826-
# Get output dimension from fitting serialization, with fallback
827-
try:
828-
dict_fit = self.fitting.serialize(suffix=suffix)
829-
dim_out = dict_fit.get("dim_out", 1)
830-
except (AttributeError, TypeError):
831-
# Fallback to default dimensions for different fitting types
832-
from deepmd.tf.fit.dipole import (
833-
DipoleFittingSeA,
834-
)
835-
from deepmd.tf.fit.dos import (
836-
DOSFitting,
837-
)
838-
from deepmd.tf.fit.ener import (
839-
EnerFitting,
840-
)
841-
from deepmd.tf.fit.polar import (
842-
PolarFittingSeA,
843-
)
844-
845-
if isinstance(self.fitting, EnerFitting):
846-
dim_out = 1
847-
elif isinstance(self.fitting, (DipoleFittingSeA, PolarFittingSeA)):
848-
dim_out = 3
849-
elif isinstance(self.fitting, DOSFitting):
850-
dim_out = getattr(self.fitting, "numb_dos", 1)
851-
else:
852-
dim_out = 1
826+
# Determine output dimension based on model type instead of fitting type
827+
if hasattr(self, "model_type"):
828+
model_type = self.model_type
829+
else:
830+
# Fallback to fitting type for compatibility
831+
model_type = getattr(self.fitting, "model_type", "ener")
832+
833+
if model_type == "ener":
834+
dim_out = 1
835+
elif model_type in ["dipole", "polar"]:
836+
dim_out = 3
837+
elif model_type == "dos":
838+
dim_out = getattr(self.fitting, "numb_dos", 1)
839+
else:
840+
dim_out = 1
853841

854842
# Initialize out_bias and out_std as numpy arrays, preserving existing values if set
855843
if hasattr(self, "out_bias") and self.out_bias is not None:
@@ -887,31 +875,57 @@ def init_out_stat(self, suffix: str = "") -> None:
887875
self.out_bias = out_bias_data
888876
self.out_std = out_std_data
889877

890-
def get_out_bias(self) -> np.ndarray:
891-
"""Get the output bias."""
892-
return self.out_bias
893-
894-
def get_out_std(self) -> np.ndarray:
895-
"""Get the output standard deviation."""
896-
return self.out_std
897-
898-
def set_out_bias(self, out_bias: np.ndarray) -> None:
899-
"""Set the output bias."""
900-
self.out_bias = out_bias
901-
if hasattr(self, "t_out_bias"):
902-
# Note: TensorFlow variable assignment would require a session context in TF 1.x
903-
# For TF 2.x, the variable assignment happens differently
904-
# Here we just update the numpy array, and TF variables are updated when rebuilt
905-
pass
878+
def _apply_out_bias_std(self, output, atype, natoms, coord, selected_atype=None):
879+
"""Apply output bias and standard deviation to the model output.
906880
907-
def set_out_std(self, out_std: np.ndarray) -> None:
908-
"""Set the output standard deviation."""
909-
self.out_std = out_std
910-
if hasattr(self, "t_out_std"):
911-
# Note: TensorFlow variable assignment would require a session context in TF 1.x
912-
# For TF 2.x, the variable assignment happens differently
913-
# Here we just update the numpy array, and TF variables are updated when rebuilt
914-
pass
881+
Parameters
882+
----------
883+
output : tf.Tensor
884+
The model output tensor
885+
atype : tf.Tensor
886+
Atom types with shape [nframes, nloc]
887+
natoms : list[int]
888+
Number of atoms [nloc, ntypes, ...]
889+
coord : tf.Tensor
890+
Coordinates for getting nframes
891+
selected_atype : tf.Tensor, optional
892+
Selected atom types for tensor models. If None, uses all atoms.
893+
894+
Returns
895+
-------
896+
tf.Tensor
897+
Output with bias and std applied
898+
"""
899+
nframes = tf.shape(coord)[0]
900+
901+
if selected_atype is not None:
902+
# For tensor models (dipole, polar) with selected atoms
903+
natomsel = tf.shape(selected_atype)[1]
904+
nout = self.get_out_size() # Use the model's output size method
905+
output_reshaped = tf.reshape(output, [nframes, natomsel, nout])
906+
atype_for_gather = selected_atype
907+
else:
908+
# For energy and DOS models with all atoms
909+
nloc = natoms[0]
910+
if hasattr(self, "numb_dos"):
911+
# DOS model: output shape [nframes * nloc * numb_dos]
912+
nout = self.numb_dos
913+
output_reshaped = tf.reshape(output, [nframes, nloc, nout])
914+
else:
915+
# Energy model: output shape [nframes * nloc]
916+
nout = 1
917+
output_reshaped = tf.reshape(output, [nframes, nloc, 1])
918+
atype_for_gather = tf.reshape(atype, [nframes, nloc])
919+
920+
# Get bias and std for each atom type
921+
bias_per_atom = tf.gather(self.t_out_bias[0], atype_for_gather)
922+
std_per_atom = tf.gather(self.t_out_std[0], atype_for_gather)
923+
924+
# Apply bias and std: output = output * std + bias
925+
output_reshaped = output_reshaped * std_per_atom + bias_per_atom
926+
927+
# Reshape back to original shape
928+
return tf.reshape(output_reshaped, tf.shape(output))
915929

916930
@classmethod
917931
def update_sel(

deepmd/tf/model/tensor.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -169,29 +169,10 @@ def build(
169169
)
170170

171171
# Apply out_bias and out_std directly to tensor output
172-
# output shape: [nframes * natomsel * nout] for tensor models
173-
# t_out_bias shape: [1, ntypes, nout], t_out_std shape: [1, ntypes, nout]
174-
if hasattr(self, "t_out_bias") and hasattr(self, "t_out_std"):
175-
nframes = tf.shape(coord)[0]
176-
# Reshape output to [nframes, natomsel, nout] for bias/std application
177-
output_reshaped = tf.reshape(output, [nframes, natomsel, nout])
178-
179-
# Get atom types for selected atoms only (matching natomsel)
180-
atype_selected = self._get_selected_atype(atype, natoms)
181-
182-
# Get bias and std for each selected atom type: [nframes, natomsel, nout]
183-
bias_per_atom = tf.gather(
184-
self.t_out_bias[0], atype_selected
185-
) # [nframes, natomsel, nout]
186-
std_per_atom = tf.gather(
187-
self.t_out_std[0], atype_selected
188-
) # [nframes, natomsel, nout]
189-
190-
# Apply bias and std: output = output * std + bias
191-
output_reshaped = output_reshaped * std_per_atom + bias_per_atom
192-
193-
# Reshape back to original shape
194-
output = tf.reshape(output_reshaped, tf.shape(output))
172+
atype_selected = self._get_selected_atype(atype, natoms)
173+
output = self._apply_out_bias_std(
174+
output, atype, natoms, coord, selected_atype=atype_selected
175+
)
195176
framesize = nout if "global" in self.model_type else natomsel * nout
196177
output = tf.reshape(
197178
output, [-1, framesize], name="o_" + self.model_type + suffix

0 commit comments

Comments
 (0)