Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions deepmd/pt/loss/ener_spin.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,22 @@ def forward(
rmse_ae.detach(), find_atom_ener
)

if self.has_v and "virial" in model_pred and "virial" in label:
find_virial = label.get("find_virial", 0.0)
pref_v = pref_v * find_virial
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
loss += atom_norm * (pref_v * l2_virial_loss)
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)

if not self.inference:
more_loss["rmse"] = torch.sqrt(loss.detach())
return model_pred, loss, more_loss
Expand Down
23 changes: 20 additions & 3 deletions deepmd/pt/model/model/make_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def __init__(
self,
*args: Any,
# underscore to prevent conflict with normal inputs
atomic_model_: T_AtomicModel | None = None,
atomic_model_: T_AtomicModel | None = None, # type: ignore
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
if atomic_model_ is not None:
self.atomic_model: T_AtomicModel = atomic_model_
self.atomic_model: T_AtomicModel = atomic_model_ # type: ignore
else:
self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs)
self.atomic_model: T_AtomicModel = T_AtomicModel(*args, **kwargs) # type: ignore
self.precision_dict = PRECISION_DICT
self.reverse_precision_dict = RESERVED_PRECISION_DICT
self.global_pt_float_precision = GLOBAL_PT_FLOAT_PRECISION
Expand Down Expand Up @@ -138,6 +138,7 @@ def forward_common(
fparam: torch.Tensor | None = None,
aparam: torch.Tensor | None = None,
do_atomic_virial: bool = False,
coord_corr_for_virial: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return model prediction.

Expand All @@ -156,6 +157,9 @@ def forward_common(
atomic parameter. nf x nloc x nda
do_atomic_virial
If calculate the atomic virial.
coord_corr_for_virial
The coordinates correction of the atoms for virial.
shape: nf x (nloc x 3)

Returns
-------
Expand Down Expand Up @@ -183,6 +187,14 @@ def forward_common(
mixed_types=True,
box=bb,
)
if coord_corr_for_virial is not None:
coord_corr_for_virial = coord_corr_for_virial.to(cc.dtype)
extended_coord_corr = torch.gather(
coord_corr_for_virial, 1, mapping.unsqueeze(-1).expand(-1, -1, 3)
)
else:
extended_coord_corr = None

model_predict_lower = self.forward_common_lower(
extended_coord,
extended_atype,
Expand All @@ -191,6 +203,7 @@ def forward_common(
do_atomic_virial=do_atomic_virial,
fparam=fp,
aparam=ap,
extended_coord_corr=extended_coord_corr,
)
model_predict = communicate_extended_output(
model_predict_lower,
Expand Down Expand Up @@ -247,6 +260,7 @@ def forward_common_lower(
do_atomic_virial: bool = False,
comm_dict: dict[str, torch.Tensor] | None = None,
extra_nlist_sort: bool = False,
extended_coord_corr: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Return model prediction. Lower interface that takes
extended atomic coordinates and types, nlist, and mapping
Expand All @@ -273,6 +287,8 @@ def forward_common_lower(
The data needed for communication for parallel inference.
extra_nlist_sort
whether to forcibly sort the nlist.
extended_coord_corr
coordinates correction for virial in extended region. nf x (nall x 3)

Returns
-------
Expand Down Expand Up @@ -305,6 +321,7 @@ def forward_common_lower(
do_atomic_virial=do_atomic_virial,
create_graph=self.training,
mask=atomic_ret["mask"] if "mask" in atomic_ret else None,
extended_coord_corr=extended_coord_corr,
)
model_predict = self.output_type_cast(model_predict, input_prec)
return model_predict
Expand Down
86 changes: 71 additions & 15 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,31 @@ def __init__(

def process_spin_input(
self, coord: torch.Tensor, atype: torch.Tensor, spin: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Generate virtual coordinates and types, concat into the input."""
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Generate virtual coordinates and types, concat into the input.

Returns
-------
coord_spin : torch.Tensor
Concatenated coordinates with shape (nframes, 2*nloc, 3).
atype_spin : torch.Tensor
Concatenated atom types with shape (nframes, 2*nloc).
coord_corr : torch.Tensor
Coordinate correction for virial with shape (nframes, 2*nloc, 3).
"""
nframes, nloc = atype.shape
coord = coord.reshape(nframes, nloc, 3)
spin = spin.reshape(nframes, nloc, 3)
atype_spin = torch.concat([atype, atype + self.ntypes_real], dim=-1)
virtual_coord = coord + spin * (self.virtual_scale_mask.to(atype.device))[
atype
].reshape([nframes, nloc, 1])
# spin_dist = s_i * \mu_i
spin_dist = spin * (self.virtual_scale_mask.to(atype.device))[atype].reshape(
[nframes, nloc, 1]
)
virtual_coord = coord + spin_dist
coord_spin = torch.concat([coord, virtual_coord], dim=-2)
return coord_spin, atype_spin
# for spin virial corr
coord_corr = torch.concat([torch.zeros_like(coord), -spin_dist], dim=-2)
return coord_spin, atype_spin, coord_corr

def process_spin_input_lower(
self,
Expand All @@ -72,24 +86,47 @@ def process_spin_input_lower(
extended_spin: torch.Tensor,
nlist: torch.Tensor,
mapping: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]:
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor
]:
"""
Add `extended_spin` into `extended_coord` to generate virtual atoms, and extend `nlist` and `mapping`.
Note that the final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order:

Returns
-------
extended_coord_updated : torch.Tensor
Updated coordinates with virtual atoms, shape (nframes, 2*nall, 3).
extended_atype_updated : torch.Tensor
Updated atom types with virtual atoms, shape (nframes, 2*nall).
nlist_updated : torch.Tensor
Updated neighbor list including virtual atoms.
mapping_updated : torch.Tensor or None
Updated mapping indices, or None if input mapping is None.
extended_coord_corr : torch.Tensor
Coordinate correction for virial with shape (nframes, 2*nall, 3).

Notes
-----
The final `extended_coord_updated` with shape [nframes, nall + nall, 3] has the following order:
- [:, :nloc]: original nloc real atoms.
- [:, nloc: nloc + nloc]: virtual atoms corresponding to nloc real atoms.
- [:, nloc + nloc: nloc + nall]: ghost real atoms.
- [:, nloc + nall: nall + nall]: virtual atoms corresponding to ghost real atoms.
"""
nframes, nall = extended_coord.shape[:2]
nloc = nlist.shape[1]
virtual_extended_coord = extended_coord + extended_spin * (
extended_spin_dist = extended_spin * (
self.virtual_scale_mask.to(extended_atype.device)
)[extended_atype].reshape([nframes, nall, 1])
virtual_extended_coord = extended_coord + extended_spin_dist
virtual_extended_atype = extended_atype + self.ntypes_real
extended_coord_updated = concat_switch_virtual(
extended_coord, virtual_extended_coord, nloc
)
# for spin virial corr
extended_coord_corr = concat_switch_virtual(
torch.zeros_like(extended_coord), -extended_spin_dist, nloc
)
extended_atype_updated = concat_switch_virtual(
extended_atype, virtual_extended_atype, nloc
)
Expand All @@ -105,6 +142,7 @@ def process_spin_input_lower(
extended_atype_updated,
nlist_updated,
mapping_updated,
extended_coord_corr,
)

def process_spin_output(
Expand Down Expand Up @@ -376,7 +414,7 @@ def spin_sampled_func() -> list[dict[str, Any]]:
sampled = sampled_func()
spin_sampled = []
for sys in sampled:
coord_updated, atype_updated = self.process_spin_input(
coord_updated, atype_updated, _ = self.process_spin_input(
sys["coord"], sys["atype"], sys["spin"]
)
tmp_dict = {
Expand Down Expand Up @@ -407,7 +445,9 @@ def forward_common(
do_atomic_virial: bool = False,
) -> dict[str, torch.Tensor]:
nframes, nloc = atype.shape
coord_updated, atype_updated = self.process_spin_input(coord, atype, spin)
coord_updated, atype_updated, coord_corr_for_virial = self.process_spin_input(
coord, atype, spin
)
if aparam is not None:
aparam = self.expand_aparam(aparam, nloc * 2)
model_ret = self.backbone_model.forward_common(
Expand All @@ -417,6 +457,7 @@ def forward_common(
fparam=fparam,
aparam=aparam,
do_atomic_virial=do_atomic_virial,
coord_corr_for_virial=coord_corr_for_virial,
)
model_output_type = self.backbone_model.model_output_type()
if "mask" in model_output_type:
Expand All @@ -439,7 +480,7 @@ def forward_common(
) = self.process_spin_output(
atype,
model_ret[f"{var_name}_derv_c"],
add_mag=False,
add_mag=True,
virtual_scale=False,
)
return model_ret
Expand All @@ -463,6 +504,7 @@ def forward_common_lower(
extended_atype_updated,
nlist_updated,
mapping_updated,
extended_coord_corr_for_virial,
) = self.process_spin_input_lower(
extended_coord, extended_atype, extended_spin, nlist, mapping=mapping
)
Expand All @@ -478,6 +520,7 @@ def forward_common_lower(
do_atomic_virial=do_atomic_virial,
comm_dict=comm_dict,
extra_nlist_sort=extra_nlist_sort,
extended_coord_corr=extended_coord_corr_for_virial,
)
model_output_type = self.backbone_model.model_output_type()
if "mask" in model_output_type:
Expand All @@ -503,7 +546,7 @@ def forward_common_lower(
extended_atype,
model_ret[f"{var_name}_derv_c"],
nloc,
add_mag=False,
add_mag=True,
virtual_scale=False,
)
return model_ret
Expand Down Expand Up @@ -550,6 +593,11 @@ def translated_output_def(self) -> dict[str, Any]:
output_def["force"].squeeze(-2)
output_def["force_mag"] = deepcopy(out_def_data["energy_derv_r_mag"])
output_def["force_mag"].squeeze(-2)
if self.do_grad_c("energy"):
output_def["virial"] = deepcopy(out_def_data["energy_derv_c_redu"])
output_def["virial"].squeeze(-2)
output_def["atom_virial"] = deepcopy(out_def_data["energy_derv_c"])
output_def["atom_virial"].squeeze(-3)
return output_def

def forward(
Expand Down Expand Up @@ -578,7 +626,10 @@ def forward(
if self.backbone_model.do_grad_r("energy"):
model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2)
model_predict["force_mag"] = model_ret["energy_derv_r_mag"].squeeze(-2)
# not support virial by far
if self.backbone_model.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3)
return model_predict

@torch.jit.export
Expand Down Expand Up @@ -615,5 +666,10 @@ def forward_lower(
model_predict["extended_force_mag"] = model_ret[
"energy_derv_r_mag"
].squeeze(-2)
# not support virial by far
if self.backbone_model.do_grad_c("energy"):
model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2)
if do_atomic_virial:
model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze(
-3
)
return model_predict
7 changes: 7 additions & 0 deletions deepmd/pt/model/model/transform_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def fit_output_to_model_output(
do_atomic_virial: bool = False,
create_graph: bool = True,
mask: torch.Tensor | None = None,
extended_coord_corr: torch.Tensor | None = None,
) -> dict[str, torch.Tensor]:
"""Transform the output of the fitting network to
the model output.
Expand Down Expand Up @@ -192,6 +193,12 @@ def fit_output_to_model_output(
model_ret[kk_derv_r] = dr
if vdef.c_differentiable:
assert dc is not None
if extended_coord_corr is not None:
dc_corr = (
dr.squeeze(-2).unsqueeze(-1)
@ extended_coord_corr.unsqueeze(-2)
).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005
dc = dc + dc_corr
model_ret[kk_derv_c] = dc
model_ret[kk_derv_c + "_redu"] = torch.sum(
model_ret[kk_derv_c].to(redu_prec), dim=1
Expand Down
12 changes: 6 additions & 6 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2602,9 +2602,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
for (int j = 0; j < natoms * 3; j++) {
force_mag[i][j] = force_mag_flat[i * natoms * 3 + j];
}
// for (int j = 0; j < 9; j++) {
// virial[i][j] = virial_flat[i * 9 + j];
// }
for (int j = 0; j < 9; j++) {
virial[i][j] = virial_flat[i * 9 + j];
}
}
};
/**
Expand Down Expand Up @@ -2705,9 +2705,9 @@ class DeepSpinModelDevi : public DeepBaseModelDevi {
for (int j = 0; j < natoms * 3; j++) {
force_mag[i][j] = force_mag_flat[i * natoms * 3 + j];
}
// for (int j = 0; j < 9; j++) {
// virial[i][j] = virial_flat[i * 9 + j];
// }
for (int j = 0; j < 9; j++) {
virial[i][j] = virial_flat[i * 9 + j];
}
for (int j = 0; j < natoms; j++) {
atom_energy[i][j] = atom_energy_flat[i * natoms + j];
}
Expand Down
10 changes: 5 additions & 5 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,11 @@ void DP_DeepSpinModelDeviCompute_variant(DP_DeepSpinModelDevi* dp,
flatten_vector(fm_flat, fm);
std::copy(fm_flat.begin(), fm_flat.end(), force_mag);
}
// if (virial) {
// std::vector<VALUETYPE> v_flat;
// flatten_vector(v_flat, v);
// std::copy(v_flat.begin(), v_flat.end(), virial);
// }
if (virial) {
std::vector<VALUETYPE> v_flat;
flatten_vector(v_flat, v);
std::copy(v_flat.begin(), v_flat.end(), virial);
}
if (atomic_energy) {
std::vector<VALUETYPE> ae_flat;
flatten_vector(ae_flat, ae);
Expand Down
Loading