
Commit 93cd873

Copilot and njzjz committed
feat(tf): implement decoupled out_bias and out_std in TensorFlow backend with Model-level application
Co-authored-by: njzjz <[email protected]>
1 parent 06a69f5 commit 93cd873

5 files changed (+198, −160 lines)

deepmd/tf/model/dos.py

Lines changed: 28 additions & 0 deletions

@@ -149,6 +149,9 @@ def build(
         t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)
         t_od = tf.constant(self.numb_dos, name="output_dim", dtype=tf.int32)
 
+        # Initialize out_bias and out_std for DOS models
+        self.init_out_stat(suffix=suffix)
+
         coord = tf.reshape(coord_, [-1, natoms[1] * 3])
         atype = tf.reshape(atype_, [-1, natoms[1]])
         input_dict["nframes"] = tf.shape(coord)[0]
@@ -181,6 +184,31 @@ def build(
         atom_dos = self.fitting.build(
             dout, natoms, input_dict, reuse=reuse, suffix=suffix
         )
+
+        # Apply out_bias and out_std directly to DOS output
+        # atom_dos shape: [nframes * nloc * numb_dos] for DOS models
+        # t_out_bias shape: [1, ntypes, numb_dos], t_out_std shape: [1, ntypes, numb_dos]
+        if hasattr(self, "t_out_bias") and hasattr(self, "t_out_std"):
+            nframes = tf.shape(coord)[0]
+            nloc = natoms[0]
+            # Reshape atom_dos to [nframes, nloc, numb_dos] for bias/std application
+            atom_dos_reshaped = tf.reshape(atom_dos, [nframes, nloc, self.numb_dos])
+
+            # Get bias and std for each atom type: [nframes, nloc, numb_dos]
+            atype_flat = tf.reshape(atype, [nframes, nloc])
+            bias_per_atom = tf.gather(
+                self.t_out_bias[0], atype_flat
+            )  # [nframes, nloc, numb_dos]
+            std_per_atom = tf.gather(
+                self.t_out_std[0], atype_flat
+            )  # [nframes, nloc, numb_dos]
+
+            # Apply bias and std: dos = dos * std + bias
+            atom_dos_reshaped = atom_dos_reshaped * std_per_atom + bias_per_atom
+
+            # Reshape back to original shape
+            atom_dos = tf.reshape(atom_dos_reshaped, tf.shape(atom_dos))
+
         self.atom_dos = atom_dos
 
         dos_raw = atom_dos
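
All three model types apply the same per-type affine transform to the raw fitting output. As a minimal illustration of what the tf.gather-based block above computes (not part of the commit; shapes and values are made up), the numpy equivalent is:

import numpy as np

# Toy dimensions: 2 frames, 3 local atoms, 2 atom types, numb_dos = 4.
nframes, nloc, ntypes, numb_dos = 2, 3, 2, 4
rng = np.random.default_rng(0)

atom_dos_raw = rng.normal(size=(nframes, nloc, numb_dos))   # fitting output
atype = np.array([[0, 1, 1], [1, 0, 0]])                    # [nframes, nloc]
out_bias = rng.normal(size=(1, ntypes, numb_dos))           # plays the role of t_out_bias
out_std = np.abs(rng.normal(size=(1, ntypes, numb_dos)))    # plays the role of t_out_std

# numpy fancy indexing is the eager counterpart of tf.gather(t_out_bias[0], atype):
bias_per_atom = out_bias[0][atype]   # [nframes, nloc, numb_dos]
std_per_atom = out_std[0][atype]     # [nframes, nloc, numb_dos]

atom_dos = atom_dos_raw * std_per_atom + bias_per_atom
assert atom_dos.shape == (nframes, nloc, numb_dos)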

deepmd/tf/model/ener.py

Lines changed: 29 additions & 0 deletions

@@ -193,6 +193,9 @@ def build(
         t_mt = tf.constant(self.model_type, name="model_type", dtype=tf.string)
         t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)
 
+        # Initialize out_bias and out_std for energy models
+        self.init_out_stat(suffix=suffix)
+
         if self.srtab is not None:
             tab_info, tab_data = self.srtab.get()
             self.tab_info = tf.get_variable(
@@ -253,6 +256,32 @@ def build(
         atom_ener = self.fitting.build(
             dout, natoms, input_dict, reuse=reuse, suffix=suffix
         )
+
+        # Apply out_bias and out_std directly to atom energy
+        # atom_ener shape: [nframes * nloc] (for energy models, dim_out=1)
+        # t_out_bias shape: [1, ntypes, 1], t_out_std shape: [1, ntypes, 1]
+        # atype shape: [nframes, nloc]
+        if hasattr(self, "t_out_bias") and hasattr(self, "t_out_std"):
+            # Reshape atom_ener to [nframes, nloc, 1] to match bias/std application
+            nframes = tf.shape(coord)[0]
+            nloc = natoms[0]
+            atom_ener_reshaped = tf.reshape(atom_ener, [nframes, nloc, 1])
+
+            # Get bias and std for each atom type: [nframes, nloc, 1]
+            atype_flat = tf.reshape(atype, [nframes, nloc])
+            bias_per_atom = tf.gather(
+                self.t_out_bias[0], atype_flat
+            )  # [nframes, nloc, 1]
+            std_per_atom = tf.gather(
+                self.t_out_std[0], atype_flat
+            )  # [nframes, nloc, 1]
+
+            # Apply bias and std: energy = energy * std + bias
+            atom_ener_reshaped = atom_ener_reshaped * std_per_atom + bias_per_atom
+
+            # Reshape back to original shape
+            atom_ener = tf.reshape(atom_ener_reshaped, tf.shape(atom_ener))
+
         self.atom_ener = atom_ener
 
         if self.srtab is not None:
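
For the energy model the per-type transform has a simple effect on the frame energy: each atom's raw energy is scaled by its type's out_std, and the total is shifted by N_t × bias_t for each type t. A small numpy check of that decomposition (toy values, not from the commit):

import numpy as np

# One frame with 4 atoms of 2 types; dim_out = 1 for energy.
atype = np.array([0, 1, 1, 0])
e_raw = np.array([1.0, 2.0, 3.0, 4.0])         # raw per-atom energies from the fitting net
bias = np.array([0.5, -1.0])                   # out_bias per type, shape [ntypes]
std = np.array([2.0, 1.0])                     # out_std per type, shape [ntypes]

e_atom = e_raw * std[atype] + bias[atype]      # same formula as the TF code above
e_total = e_atom.sum()

# Equivalent decomposition: scaled raw energy plus N_t * bias_t per type.
counts = np.bincount(atype, minlength=2)
assert np.isclose(e_total, (e_raw * std[atype]).sum() + (counts * bias).sum())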

deepmd/tf/model/model.py

Lines changed: 15 additions & 91 deletions

@@ -851,9 +851,20 @@ def init_out_stat(self, suffix: str = "") -> None:
         else:
             dim_out = 1
 
-        # Initialize out_bias and out_std as numpy arrays first
-        out_bias_data = np.zeros([1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION)
-        out_std_data = np.ones([1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION)
+        # Initialize out_bias and out_std as numpy arrays, preserving existing values if set
+        if hasattr(self, "out_bias") and self.out_bias is not None:
+            out_bias_data = self.out_bias.copy()
+        else:
+            out_bias_data = np.zeros(
+                [1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION
+            )
+
+        if hasattr(self, "out_std") and self.out_std is not None:
+            out_std_data = self.out_std.copy()
+        else:
+            out_std_data = np.ones(
+                [1, ntypes, dim_out], dtype=GLOBAL_NP_FLOAT_PRECISION
+            )
 
         # Create TensorFlow variables
         with tf.variable_scope("model_attr" + suffix, reuse=tf.AUTO_REUSE):
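
The point of the new branches is that init_out_stat() no longer clobbers statistics already attached to the model, e.g. by deserialize() later in this file, which sets model.out_bias / model.out_std before build() runs. A standalone sketch of just that branch logic, using a hypothetical helper name (init_out_stat_data) and the default numpy dtype in place of GLOBAL_NP_FLOAT_PRECISION:

import numpy as np

def init_out_stat_data(existing_bias, existing_std, ntypes, dim_out):
    # Keep statistics that were already set; otherwise fall back to zeros/ones.
    if existing_bias is not None:
        out_bias_data = existing_bias.copy()
    else:
        out_bias_data = np.zeros([1, ntypes, dim_out])
    if existing_std is not None:
        out_std_data = existing_std.copy()
    else:
        out_std_data = np.ones([1, ntypes, dim_out])
    return out_bias_data, out_std_data

# Fresh model: zero bias, unit std.
b, s = init_out_stat_data(None, None, ntypes=2, dim_out=1)
assert np.all(b == 0.0) and np.all(s == 1.0)

# Deserialized model: the bias restored from the checkpoint survives build().
loaded_bias = np.array([[[1.5], [-0.3]]])
b, s = init_out_stat_data(loaded_bias, None, ntypes=2, dim_out=1)
assert np.array_equal(b, loaded_bias)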
@@ -960,43 +971,7 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
         data = data.copy()
         check_version_compatibility(data.pop("@version", 2), 2, 1)
         descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix)
-        if data["fitting"].get("@variables", {}).get("bias_atom_e") is not None:
-            # careful: copy each level and don't modify the input array,
-            # otherwise it will affect the original data
-            # deepcopy is not used for performance reasons
-            data["fitting"] = data["fitting"].copy()
-            data["fitting"]["@variables"] = data["fitting"]["@variables"].copy()
-            if (
-                int(np.any(data["fitting"]["@variables"]["bias_atom_e"]))
-                + int(np.any(data["@variables"]["out_bias"]))
-                > 1
-            ):
-                raise ValueError(
-                    "fitting/@variables/bias_atom_e and @variables/out_bias should not be both non-zero"
-                )
-            # Improved handling for different shapes (dipole/polar vs energy)
-            bias_atom_e_shape = data["fitting"]["@variables"]["bias_atom_e"].shape
-            out_bias_data = data["@variables"]["out_bias"]
-
-            # For dipole/polar models, out_bias has shape [1, ntypes, 3]
-            # but bias_atom_e has shape [ntypes] where embedding_width might != 3
-            if len(bias_atom_e_shape) == 1 and len(out_bias_data.shape) == 3:
-                # Convert out_bias to bias_atom_e shape safely
-                # We sum over the output dimensions for energy-like models
-                if out_bias_data.shape[2] == 1:
-                    # Energy case: out_bias [1, ntypes, 1] -> bias_atom_e [ntypes]
-                    bias_increment = out_bias_data[0, :, 0]
-                else:
-                    # Dipole/Polar case: take norm or sum for compatibility
-                    # This is still a workaround, but safer than reshape
-                    bias_increment = np.linalg.norm(out_bias_data[0], axis=-1)
-            else:
-                # Fallback to original reshape if shapes are compatible
-                bias_increment = out_bias_data.reshape(bias_atom_e_shape)
-
-            data["fitting"]["@variables"]["bias_atom_e"] = (
-                data["fitting"]["@variables"]["bias_atom_e"] + bias_increment
-            )
+        # bias_atom_e and out_bias are now completely independent - no conversion needed
         fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
         # pass descriptor type embedding to model
         if descriptor.explicit_ntypes:
@@ -1029,57 +1004,6 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
         model.out_std = out_std
         return model
 
-    def apply_out_stat(
-        self,
-        ret: dict[str, np.ndarray],
-        atype: np.ndarray,
-    ) -> dict[str, np.ndarray]:
-        """Apply the bias and std to the atomic output.
-
-        Parameters
-        ----------
-        ret : dict[str, np.ndarray]
-            The returned dict by the forward_atomic method
-        atype : np.ndarray
-            The atom types. nf x nloc
-
-        Returns
-        -------
-        dict[str, np.ndarray]
-            The output with bias and std applied
-        """
-        if self.out_bias is None:
-            return ret
-
-        # Get the output keys that need bias/std applied
-        fitting_output_def = (
-            self.fitting.fitting_output_def()
-            if hasattr(self.fitting, "fitting_output_def")
-            else {}
-        )
-
-        # Apply bias for each output
-        for kk in ret.keys():
-            if kk in ["mask"]:  # Skip mask
-                continue
-
-            # Get the corresponding bias and std
-            # For now, we assume single output (idx=0), which works for most cases
-            bias_idx = 0
-            ntypes = self.get_ntypes()
-
-            if self.out_bias.shape[0] > bias_idx:
-                # Extract bias for this output: shape [ntypes, output_dim]
-                out_bias_kk = self.out_bias[bias_idx]  # [ntypes, output_dim]
-
-                # Apply bias: ret[kk] shape is [nframes, nloc, output_dim]
-                # atype shape is [nframes, nloc]
-                # We need to index out_bias_kk with atype to get [nframes, nloc, output_dim]
-                bias_for_atoms = out_bias_kk[atype]  # [nframes, nloc, output_dim]
-                ret[kk] = ret[kk] + bias_for_atoms
-
-        return ret
-
     def serialize(self, suffix: str = "") -> dict:
         """Serialize the model.
 
deepmd/tf/model/tensor.py

Lines changed: 46 additions & 0 deletions

@@ -126,6 +126,9 @@ def build(
         t_ver = tf.constant(MODEL_VERSION, name="model_version", dtype=tf.string)
         t_od = tf.constant(self.get_out_size(), name="output_dim", dtype=tf.int32)
 
+        # Initialize out_bias and out_std for tensor models (dipole/polar)
+        self.init_out_stat(suffix=suffix)
+
         natomsel = sum(natoms[2 + type_i] for type_i in self.get_sel_type())
         nout = self.get_out_size()
 
@@ -164,6 +167,31 @@ def build(
         output = self.fitting.build(
             dout, rot_mat, natoms, input_dict, reuse=reuse, suffix=suffix
         )
+
+        # Apply out_bias and out_std directly to tensor output
+        # output shape: [nframes * natomsel * nout] for tensor models
+        # t_out_bias shape: [1, ntypes, nout], t_out_std shape: [1, ntypes, nout]
+        if hasattr(self, "t_out_bias") and hasattr(self, "t_out_std"):
+            nframes = tf.shape(coord)[0]
+            # Reshape output to [nframes, natomsel, nout] for bias/std application
+            output_reshaped = tf.reshape(output, [nframes, natomsel, nout])
+
+            # Get atom types for selected atoms only (matching natomsel)
+            atype_selected = self._get_selected_atype(atype, natoms)
+
+            # Get bias and std for each selected atom type: [nframes, natomsel, nout]
+            bias_per_atom = tf.gather(
+                self.t_out_bias[0], atype_selected
+            )  # [nframes, natomsel, nout]
+            std_per_atom = tf.gather(
+                self.t_out_std[0], atype_selected
+            )  # [nframes, natomsel, nout]
+
+            # Apply bias and std: output = output * std + bias
+            output_reshaped = output_reshaped * std_per_atom + bias_per_atom
+
+            # Reshape back to original shape
+            output = tf.reshape(output_reshaped, tf.shape(output))
         framesize = nout if "global" in self.model_type else natomsel * nout
         output = tf.reshape(
             output, [-1, framesize], name="o_" + self.model_type + suffix
@@ -206,6 +234,24 @@ def build(
 
         return model_dict
 
+    def _get_selected_atype(self, atype, natoms):
+        """Get atom types for selected atoms only (matching tensor model selection)."""
+        # For tensor models, the fitting output corresponds to selected atom types
+        # atype shape: [nframes, nloc]
+        # We need to extract atom types that match the natomsel count
+
+        # Simplified approach: take the first natomsel atoms from each frame
+        # This works because natoms and descriptor arrangement should be consistent
+        nframes = tf.shape(atype)[0]
+        selected_types = self.get_sel_type()
+        natomsel = sum(natoms[2 + type_i] for type_i in selected_types)
+
+        # Take the first natomsel atoms from each frame
+        # This assumes the atom ordering is consistent with how fitting produces output
+        atype_selected = atype[:, :natomsel]  # [nframes, natomsel]
+
+        return atype_selected
+
     def init_variables(
         self,
         graph: tf.Graph,
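
The helper's "first natomsel atoms" shortcut lines up with the fitting output only when the selected types are the leading types in the type-sorted atom ordering, which is the assumption its comments state. A numpy sketch with toy numbers (the natoms layout [nloc, nall, n_type0, n_type1, ...] follows the TF backend convention visible in the diff; all values are illustrative):

import numpy as np

# Toy setup: 3 atom types, sel_type = [0, 1], atoms sorted by type.
natoms = np.array([5, 5, 2, 2, 1])                # [nloc, nall, n_type0, n_type1, n_type2]
sel_type = [0, 1]
natomsel = sum(natoms[2 + t] for t in sel_type)   # 2 + 2 = 4 selected atoms

# With type-sorted atoms and leading selected types, the first natomsel atoms
# are exactly the selected ones -- the ordering _get_selected_atype relies on.
atype = np.array([[0, 0, 1, 1, 2]])               # [nframes, nloc]
atype_selected = atype[:, :natomsel]              # [[0, 0, 1, 1]]

# Per-type bias lookup for the selected atoms, matching the fitting output length.
out_bias = np.array([[0.1], [0.2], [0.3]])        # [ntypes, nout] with nout = 1
bias_per_atom = out_bias[atype_selected]          # shape [1, natomsel, 1]
assert bias_per_atom.shape == (1, natomsel, 1)
assert np.allclose(bias_per_atom[0, :, 0], [0.1, 0.1, 0.2, 0.2])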
