diff --git a/backend/dynamic_metadata.py b/backend/dynamic_metadata.py
index e7763cac84..d5f7b370b5 100644
--- a/backend/dynamic_metadata.py
+++ b/backend/dynamic_metadata.py
@@ -48,8 +48,11 @@ def dynamic_metadata(
         ]
         optional_dependencies["lmp"].extend(find_libpython_requires)
         optional_dependencies["ipi"].extend(find_libpython_requires)
+        torch_static_requirement = optional_dependencies.pop("torch", ())
         return {
             **optional_dependencies,
             **get_tf_requirement(tf_version),
-            **get_pt_requirement(pt_version),
+            **get_pt_requirement(
+                pt_version, static_requirement=tuple(torch_static_requirement)
+            ),
         }
diff --git a/backend/find_pytorch.py b/backend/find_pytorch.py
index d50f57bf5e..9fedf000b5 100644
--- a/backend/find_pytorch.py
+++ b/backend/find_pytorch.py
@@ -90,7 +90,10 @@ def find_pytorch() -> tuple[str | None, list[str]]:
 
 
 @lru_cache
-def get_pt_requirement(pt_version: str = "") -> dict:
+def get_pt_requirement(
+    pt_version: str = "",
+    static_requirement: tuple[str, ...] | None = None,
+) -> dict:
     """Get PyTorch requirement when PT is not installed.
 
     If pt_version is not given and the environment variable `PYTORCH_VERSION` is set, use it as the requirement.
@@ -99,6 +102,8 @@ def get_pt_requirement(pt_version: str = "") -> dict:
     ----------
     pt_version : str, optional
         PT version
+    static_requirement : tuple[str, ...] or None, optional
+        Static requirements appended to the generated torch requirement
 
     Returns
     -------
@@ -125,6 +130,8 @@ def get_pt_requirement(pt_version: str = "") -> dict:
         mpi_requirement = ["mpich"]
     else:
         mpi_requirement = []
+    if static_requirement is None:
+        static_requirement = ()
 
     return {
         "torch": [
@@ -138,6 +145,7 @@ def get_pt_requirement(pt_version: str = "") -> dict:
             else "torch>=2.1.0",
             *mpi_requirement,
             *cibw_requirement,
+            *static_requirement,
         ],
     }
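A minimal, self-contained sketch of the merge these two hunks implement (the `merged_torch_extra` helper and the requirement strings are hypothetical; the real path is `dynamic_metadata()` feeding the popped static extra into `get_pt_requirement()`):

```python
def merged_torch_extra(
    optional_dependencies: dict[str, list[str]],
    pt_requirement: list[str],
) -> dict[str, list[str]]:
    # Pop the static "torch" extra (e.g. ["torch-admp"] declared in
    # pyproject.toml) so it is appended to the generated requirement
    # instead of being emitted as a second, conflicting "torch" key.
    static = tuple(optional_dependencies.pop("torch", ()))
    return {**optional_dependencies, "torch": [*pt_requirement, *static]}


print(merged_torch_extra({"torch": ["torch-admp"]}, ["torch>=2.1.0"]))
# {'torch': ['torch>=2.1.0', 'torch-admp']}
```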
diff --git a/deepmd/dpmodel/modifier/base_modifier.py b/deepmd/dpmodel/modifier/base_modifier.py
index febb9b75e8..5a3f266b1b 100644
--- a/deepmd/dpmodel/modifier/base_modifier.py
+++ b/deepmd/dpmodel/modifier/base_modifier.py
@@ -32,7 +32,6 @@ def serialize(self) -> dict:
         dict
             The serialized data
         """
-        pass
 
     @classmethod
     def deserialize(cls, data: dict) -> "BaseModifier":
diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py
index 46ad8a6cd0..0f68052287 100644
--- a/deepmd/pt/entrypoints/main.py
+++ b/deepmd/pt/entrypoints/main.py
@@ -1,10 +1,10 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import argparse
 import copy
-import io
 import json
 import logging
 import os
+import pickle
 from pathlib import (
     Path,
 )
@@ -401,17 +401,11 @@ def freeze(
     model.eval()
     model = torch.jit.script(model)
 
-    dm_output = "data_modifier.pth"
-    extra_files = {dm_output: ""}
-    if tester.modifier is not None:
-        dm = tester.modifier
-        dm.eval()
-        buffer = io.BytesIO()
-        torch.jit.save(
-            torch.jit.script(dm),
-            buffer,
-        )
-        extra_files = {dm_output: buffer.getvalue()}
+    extra_files = {"modifier_data": ""}
+    dm = tester.modifier
+    if dm is not None:
+        bytes_data = pickle.dumps(dm.serialize())
+        extra_files = {"modifier_data": bytes_data}
     torch.jit.save(
         model,
         output,
diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py
index 6e63ecb2fc..b909f5416d 100644
--- a/deepmd/pt/infer/deep_eval.py
+++ b/deepmd/pt/infer/deep_eval.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import io
 import json
 import logging
+import pickle
 from collections.abc import (
     Callable,
 )
@@ -49,6 +49,9 @@
 from deepmd.pt.model.network.network import (
     TypeEmbedNetConsistent,
 )
+from deepmd.pt.modifier import (
+    BaseModifier,
+)
 from deepmd.pt.train.wrapper import (
     ModelWrapper,
 )
@@ -172,19 +175,20 @@ def __init__(
             self.dp = ModelWrapper(model)
             self.dp.load_state_dict(state_dict)
         elif str(self.model_path).endswith(".pth"):
-            extra_files = {"data_modifier.pth": ""}
+            extra_files = {"modifier_data": ""}
             model = torch.jit.load(
                 model_file, map_location=env.DEVICE, _extra_files=extra_files
             )
             modifier = None
             # Load modifier if it exists in extra_files
-            if len(extra_files["data_modifier.pth"]) > 0:
-                # Create a file-like object from the in-memory data
-                modifier_data = extra_files["data_modifier.pth"]
+            if len(extra_files["modifier_data"]) > 0:
+                modifier_data = extra_files["modifier_data"]
                 if isinstance(modifier_data, bytes):
-                    modifier_data = io.BytesIO(modifier_data)
+                    modifier_data = pickle.loads(modifier_data)
                 # Reconstruct the modifier from its serialized dict
-                modifier = torch.jit.load(modifier_data, map_location=env.DEVICE)
+                modifier = BaseModifier.get_class_by_type(
+                    modifier_data["type"]
+                ).deserialize(modifier_data)
             self.dp = ModelWrapper(model, modifier=modifier)
             self.modifier = modifier
             model_def_script = self.dp.model["Default"].get_model_def_script()
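The freeze/load pair above replaces the old TorchScript side-channel with pickled `serialize()` output. A minimal round-trip sketch under that convention (`TinyModel` and the payload dict are hypothetical stand-ins; the real loader dispatches via `BaseModifier.get_class_by_type`):

```python
import pickle

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2.0


# Saving side, as in freeze(): pickle the modifier's serialized dict and
# store the bytes next to the scripted model.
payload = {"type": "dipole_charge", "@version": 3}  # hypothetical payload
extra_files = {"modifier_data": pickle.dumps(payload)}
torch.jit.save(torch.jit.script(TinyModel()), "frozen.pth", _extra_files=extra_files)

# Loading side, as in DeepEval.__init__(): request the key, then unpickle.
extra_files = {"modifier_data": ""}
model = torch.jit.load("frozen.pth", map_location="cpu", _extra_files=extra_files)
if len(extra_files["modifier_data"]) > 0:
    modifier_data = pickle.loads(extra_files["modifier_data"])
    # ... then BaseModifier.get_class_by_type(modifier_data["type"]).deserialize(...)
```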
diff --git a/deepmd/pt/modifier/__init__.py b/deepmd/pt/modifier/__init__.py
index 71d847bcbc..f196ebf000 100644
--- a/deepmd/pt/modifier/__init__.py
+++ b/deepmd/pt/modifier/__init__.py
@@ -7,9 +7,13 @@
 from .base_modifier import (
     BaseModifier,
 )
+from .dipole_charge import (
+    DipoleChargeModifier,
+)
 
 __all__ = [
     "BaseModifier",
+    "DipoleChargeModifier",
     "get_data_modifier",
 ]
 
diff --git a/deepmd/pt/modifier/base_modifier.py b/deepmd/pt/modifier/base_modifier.py
index 5a8c6538b0..db37694305 100644
--- a/deepmd/pt/modifier/base_modifier.py
+++ b/deepmd/pt/modifier/base_modifier.py
@@ -47,6 +47,7 @@ def serialize(self) -> dict:
         data = {
             "@class": "Modifier",
             "type": self.modifier_type,
+            "use_cache": self.use_cache,
             "@version": 3,
         }
         return data
diff --git a/deepmd/pt/modifier/dipole_charge.py b/deepmd/pt/modifier/dipole_charge.py
new file mode 100644
index 0000000000..8997cbee0b
--- /dev/null
+++ b/deepmd/pt/modifier/dipole_charge.py
@@ -0,0 +1,429 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+import os
+
+import numpy as np
+import torch
+from torch_admp.pme import (
+    CoulombForceModule,
+)
+from torch_admp.utils import (
+    calc_grads,
+)
+
+from deepmd.pt.model.model import (
+    DipoleModel,
+)
+from deepmd.pt.modifier.base_modifier import (
+    BaseModifier,
+)
+from deepmd.pt.utils import (
+    env,
+)
+from deepmd.pt.utils.serialization import (
+    serialize_from_file,
+)
+from deepmd.pt.utils.utils import (
+    to_numpy_array,
+    to_torch_tensor,
+)
+
+
+@BaseModifier.register("dipole_charge")
+class DipoleChargeModifier(BaseModifier):
+    """Parameters
+    ----------
+    model_name
+        The model file for the DeepDipole model
+    model_charge_map
+        Gives the amount of charge for the wfcc
+    sys_charge_map
+        Gives the amount of charge for the real atoms
+    ewald_h
+        Grid spacing of the reciprocal part of Ewald sum. Unit: A
+    ewald_beta
+        Splitting parameter of the Ewald sum. Unit: A^{-1}
+    """
+
+    def __init__(
+        self,
+        model_name: str | None,
+        model_charge_map: list[float],
+        sys_charge_map: list[float],
+        ewald_h: float = 1.0,
+        ewald_beta: float = 1.0,
+        ewald_batch_size: int = 5,
+        dp_batch_size: int | None = None,
+        model: DipoleModel | None = None,
+        use_cache: bool = True,
+    ) -> None:
+        """Constructor."""
+        super().__init__(use_cache=use_cache)
+        self.modifier_type = "dipole_charge"
+
+        if model_name is None and model is None:
+            raise AttributeError("`model_name` or `model` should be specified.")
+        if model_name is not None and model is not None:
+            raise AttributeError(
+                "`model_name` and `model` cannot be used simultaneously."
+            )
+
+        if model is not None:
+            self._model = model.to(env.DEVICE)
+        if model_name is not None:
+            data = serialize_from_file(model_name)
+            self._model = DipoleModel.deserialize(data["model"]).to(env.DEVICE)
+        self._model.eval()
+
+        # use jit model for inference
+        self.model = torch.jit.script(self._model)
+        self.rcut = self.model.get_rcut()
+        self.type_map = self.model.get_type_map()
+        sel_type = self.model.get_sel_type()
+        self.sel_type = to_torch_tensor(np.array(sel_type))
+        self.model_charge_map = to_torch_tensor(np.array(model_charge_map))
+        self.sys_charge_map = to_torch_tensor(np.array(sys_charge_map))
+        self._model_charge_map = model_charge_map
+        self._sys_charge_map = sys_charge_map
+
+        # Validate that model_charge_map and sel_type have matching lengths
+        if len(model_charge_map) != len(sel_type):
+            raise ValueError(
+                f"model_charge_map length ({len(model_charge_map)}) must match "
+                f"sel_type length ({len(sel_type)})"
+            )
+
+        # init ewald recp
+        self.ewald_h = ewald_h
+        self.ewald_beta = ewald_beta
+        self.er = CoulombForceModule(
+            rcut=self.rcut,
+            rspace=False,
+            kappa=ewald_beta,
+            spacing=ewald_h,
+        )
+        self.placeholder_pairs = torch.ones((1, 2), device=env.DEVICE, dtype=torch.long)
+        self.placeholder_ds = torch.ones((1), device=env.DEVICE, dtype=torch.float64)
+        self.placeholder_buffer_scales = torch.zeros(
+            (1), device=env.DEVICE, dtype=torch.float64
+        )
+
+        self.ewald_batch_size = ewald_batch_size
+        if dp_batch_size is None:
+            dp_batch_size = int(os.environ.get("DP_INFER_BATCH_SIZE", 1))
+        self.dp_batch_size = dp_batch_size
+
+    def serialize(self) -> dict:
+        """Serialize the modifier.
+
+        Returns
+        -------
+        dict
+            The serialized data
+        """
+        dd = BaseModifier.serialize(self)
+        dd.update(
+            {
+                "model": self._model.serialize(),
+                "model_charge_map": self._model_charge_map,
+                "sys_charge_map": self._sys_charge_map,
+                "ewald_h": self.ewald_h,
+                "ewald_beta": self.ewald_beta,
+                "ewald_batch_size": self.ewald_batch_size,
+                "dp_batch_size": self.dp_batch_size,
+            }
+        )
+        return dd
+
+    @classmethod
+    def deserialize(cls, data: dict) -> "DipoleChargeModifier":
+        data = data.copy()
+        data.pop("@class", None)
+        data.pop("type", None)
+        data.pop("@version", None)
+        model_obj = DipoleModel.deserialize(data.pop("model"))
+        data["model"] = model_obj
+        data["model_name"] = None
+        return cls(**data)
+
+    def forward(
+        self,
+        coord: torch.Tensor,
+        atype: torch.Tensor,
+        box: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+        do_atomic_virial: bool = False,
+    ) -> dict[str, torch.Tensor]:
+        """Compute energy, force, and virial corrections for dipole-charge systems.
+
+        This method extends the system with Wannier Function Charge Centers (WFCC)
+        by adding dipole vectors to atomic coordinates for selected atom types.
+        It then calculates the electrostatic interactions using Ewald reciprocal
+        summation to obtain energy, force, and virial corrections.
+
+        Parameters
+        ----------
+        coord : torch.Tensor
+            The coordinates of atoms with shape (nframes, natoms, 3)
+        atype : torch.Tensor
+            The atom types with shape (nframes, natoms)
+        box : torch.Tensor | None, optional
+            The simulation box with shape (nframes, 3, 3), by default None
+            Note: This modifier can only be applied for periodic systems
+        fparam : torch.Tensor | None, optional
+            Frame parameters with shape (nframes, nfp), by default None
+        aparam : torch.Tensor | None, optional
+            Atom parameters with shape (nframes, natoms, nap), by default None
+        do_atomic_virial : bool, optional
+            Whether to compute atomic virial, by default False
+            Note: This parameter is currently not implemented and is ignored
+
+        Returns
+        -------
+        dict[str, torch.Tensor]
+            Dictionary containing the correction terms:
+            - energy: Energy correction tensor with shape (nframes,)
+            - force: Force correction tensor with shape (nframes, natoms, 3)
+            - virial: Virial correction tensor with shape (nframes, 3, 3)
+        """
+        if box is None:
+            raise RuntimeError(
+                "dipole_charge data modifier can only be applied for periodic systems."
+            )
+        modifier_pred = {}
+        nframes = coord.shape[0]
+        natoms = coord.shape[1]
+
+        input_box = box.reshape(nframes, 9)
+        input_box.requires_grad_(True)
+
+        detached_box = input_box.detach()
+        sfactor = torch.matmul(
+            torch.inverse(detached_box.reshape(nframes, 3, 3)),
+            input_box.reshape(nframes, 3, 3),
+        )
+        input_coord = torch.matmul(coord, sfactor).reshape(nframes, -1)
+
+        extended_coord, extended_charge, _atomic_dipole = self.extend_system(
+            input_coord,
+            atype,
+            input_box,
+            fparam,
+            aparam,
+        )
+
+        # add Ewald reciprocal correction
+        tot_e: list[torch.Tensor] = []
+        chunk_coord = torch.split(
+            extended_coord.reshape(nframes, -1, 3), self.dp_batch_size, dim=0
+        )
+        chunk_box = torch.split(
+            input_box.reshape(nframes, 3, 3), self.dp_batch_size, dim=0
+        )
+        chunk_charge = torch.split(
+            extended_charge.reshape(nframes, -1), self.dp_batch_size, dim=0
+        )
+        for _coord, _box, _charge in zip(
+            chunk_coord, chunk_box, chunk_charge, strict=True
+        ):
+            self.er(
+                _coord,
+                _box,
+                self.placeholder_pairs,
+                self.placeholder_ds,
+                self.placeholder_buffer_scales,
+                {"charge": _charge},
+            )
+            tot_e.append(self.er.reciprocal_energy.unsqueeze(0))
+        # nframe,
+        tot_e = torch.concat(tot_e, dim=0)
+        # nframe, nat * 3
+        tot_f = -calc_grads(tot_e, input_coord)
+        # nframe, nat, 3
+        tot_f = torch.reshape(tot_f, (nframes, natoms, 3))
+        # nframe, 9
+        tot_v = calc_grads(tot_e, input_box)
+        tot_v = torch.reshape(tot_v, (nframes, 3, 3))
+        # nframe, 3, 3
+        tot_v = -torch.matmul(tot_v.transpose(2, 1), input_box.reshape(nframes, 3, 3))
+
+        modifier_pred["energy"] = tot_e
+        modifier_pred["force"] = tot_f
+        modifier_pred["virial"] = tot_v
+        return modifier_pred
+
+    def extend_system(
+        self,
+        coord: torch.Tensor,
+        atype: torch.Tensor,
+        box: torch.Tensor,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Extend the system with WFCC (Wannier Function Charge Centers).
+
+        Parameters
+        ----------
+        coord : torch.Tensor
+            The coordinates of atoms with shape (nframes, natoms * 3)
+        atype : torch.Tensor
+            The atom types with shape (nframes, natoms)
+        box : torch.Tensor
+            The simulation box with shape (nframes, 9)
+        fparam : torch.Tensor | None, optional
+            Frame parameters with shape (nframes, nfp), by default None
+        aparam : torch.Tensor | None, optional
+            Atom parameters with shape (nframes, natoms, nap), by default None
+
+        Returns
+        -------
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+            A tuple containing three tensors:
+            - extended_coord : torch.Tensor
+                Extended coordinates with shape (nframes, 2 * natoms, 3)
+            - extended_charge : torch.Tensor
+                Extended charges with shape (nframes, 2 * natoms)
+            - atomic_dipole : torch.Tensor
+                Atomic dipoles with shape (nframes, natoms, 3)
+        """
+        # nframes, natoms, 3
+        extended_coord, atomic_dipole = self.extend_system_coord(
+            coord,
+            atype,
+            box,
+            fparam,
+            aparam,
+        )
+        # Get ion charges based on atom types
+        # nframe x nat
+        ion_charge = self.sys_charge_map[atype]
+        # Initialize wfcc charges
+        wc_charge = torch.zeros_like(ion_charge)
+        # Assign charges to selected atom types
+        for ii, charge in enumerate(self.model_charge_map):
+            wc_charge[atype == self.sel_type[ii]] = charge
+        # Concatenate ion charges and wfcc charges
+        extended_charge = torch.cat([ion_charge, wc_charge], dim=1)
+        return extended_coord, extended_charge, atomic_dipole
+
+    def extend_system_coord(
+        self,
+        coord: torch.Tensor,
+        atype: torch.Tensor,
+        box: torch.Tensor,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Extend the system with WFCC (Wannier Function Charge Centers).
+
+        This function calculates Wannier Function Charge Centers (WFCC) by adding dipole
+        vectors to atomic coordinates for selected atom types, then concatenates these
+        WFCC coordinates with the original atomic coordinates.
+
+        Parameters
+        ----------
+        coord : torch.Tensor
+            The coordinates of atoms with shape (nframes, natoms * 3)
+        atype : torch.Tensor
+            The atom types with shape (nframes, natoms)
+        box : torch.Tensor
+            The simulation box with shape (nframes, 9)
+        fparam : torch.Tensor | None, optional
+            Frame parameters with shape (nframes, nfp), by default None
+        aparam : torch.Tensor | None, optional
+            Atom parameters with shape (nframes, natoms, nap), by default None
+
+        Returns
+        -------
+        tuple[torch.Tensor, torch.Tensor]
+            A tuple containing two tensors:
+            - all_coord : torch.Tensor
+                Extended coordinates with shape (nframes, 2 * natoms, 3)
+            - dipole_reshaped : torch.Tensor
+                Atomic dipoles with shape (nframes, natoms, 3)
+        """
+        nframes = coord.shape[0]
+        natoms = coord.shape[1] // 3
+
+        all_dipole: list[torch.Tensor] = []
+        chunk_coord = torch.split(coord, self.dp_batch_size, dim=0)
+        chunk_atype = torch.split(atype, self.dp_batch_size, dim=0)
+        chunk_box = torch.split(box, self.dp_batch_size, dim=0)
+        # use placeholder to make the jit happy when fparam/aparam is None
+        chunk_fparam = (
+            torch.split(fparam, self.dp_batch_size, dim=0)
+            if fparam is not None
+            else chunk_atype
+        )
+        chunk_aparam = (
+            torch.split(aparam, self.dp_batch_size, dim=0)
+            if aparam is not None
+            else chunk_atype
+        )
+        for _coord, _atype, _box, _fparam, _aparam in zip(
+            chunk_coord, chunk_atype, chunk_box, chunk_fparam, chunk_aparam, strict=True
+        ):
+            dipole_batch = self.model(
+                coord=_coord,
+                atype=_atype,
+                box=_box,
+                do_atomic_virial=False,
+                fparam=_fparam if fparam is not None else None,
+                aparam=_aparam if aparam is not None else None,
+            )
+            # Extract dipole from the output dictionary
+            all_dipole.append(dipole_batch["dipole"])
+
+        # nframe x natoms x 3
+        dipole = torch.cat(all_dipole, dim=0)
+        if dipole.shape[0] != nframes:
+            raise RuntimeError(
+                f"Dipole shape mismatch: expected {nframes} frames, got {dipole.shape[0]}"
+            )
+
+        dipole_reshaped = dipole.reshape(nframes, natoms, 3)
+        coord_reshaped = coord.reshape(nframes, natoms, 3)
+        wfcc_coord = coord_reshaped + dipole_reshaped
+        all_coord = torch.cat((coord_reshaped, wfcc_coord), dim=1)
+        return all_coord, dipole_reshaped
+
+    def eval_np(
+        self,
+        coord: np.ndarray,
+        box: np.ndarray,
+        atype: np.ndarray,
+        fparam: np.ndarray | None = None,
+        aparam: np.ndarray | None = None,
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        nf = coord.shape[0]
+        na = coord.reshape(nf, -1, 3).shape[1]
+
+        if fparam is not None:
+            _fparam = (
+                to_torch_tensor(fparam)
+                .reshape(nf, -1)
+                .to(env.GLOBAL_PT_FLOAT_PRECISION)
+            )
+        else:
+            _fparam = None
+        if aparam is not None:
+            _aparam = (
+                to_torch_tensor(aparam)
+                .reshape(nf, na, -1)
+                .to(env.GLOBAL_PT_FLOAT_PRECISION)
+            )
+        else:
+            _aparam = None
+        modifier_pred = self.forward(
+            to_torch_tensor(coord).reshape(nf, -1, 3).to(env.GLOBAL_PT_FLOAT_PRECISION),
+            to_torch_tensor(atype).reshape(nf, -1).to(torch.long),
+            to_torch_tensor(box).reshape(nf, 3, 3).to(env.GLOBAL_PT_FLOAT_PRECISION),
+            _fparam,
+            _aparam,
+        )
+        return (
+            to_numpy_array(modifier_pred["energy"]),
+            to_numpy_array(modifier_pred["force"]),
+            to_numpy_array(modifier_pred["virial"]),
+        )
diff --git a/deepmd/pt/train/wrapper.py b/deepmd/pt/train/wrapper.py
index ddb4a4323d..d2cef75614 100644
--- a/deepmd/pt/train/wrapper.py
+++ b/deepmd/pt/train/wrapper.py
@@ -191,7 +191,7 @@ def forward(
         if self.modifier is not None:
             modifier_pred = self.modifier(**input_dict)
             for k, v in modifier_pred.items():
-                model_pred[k] = model_pred[k] + v
+                model_pred[k] = model_pred[k] + v.reshape(model_pred[k].shape)
             return model_pred, None, None
         else:
             natoms = atype.shape[-1]
diff --git a/deepmd/tf/modifier/dipole_charge.py b/deepmd/tf/modifier/dipole_charge.py
index a2d3efb353..ad3016c677 100644
--- a/deepmd/tf/modifier/dipole_charge.py
+++ b/deepmd/tf/modifier/dipole_charge.py
@@ -52,8 +52,8 @@ class DipoleChargeModifier(DeepDipole, BaseModifier):
         Splitting parameter of the Ewald sum. Unit: A^{-1}
     """
 
-    def __new__(cls, *args, model_name=None, **kwargs):
-        return super().__new__(cls, model_name)
+    def __new__(cls, *args, model_name: str, **kwargs):
+        return super().__new__(cls, model_file=model_name)
 
     def __init__(
         self,
diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py
index 8c20bb8bf4..c82b24d52a 100644
--- a/deepmd/utils/argcheck.py
+++ b/deepmd/utils/argcheck.py
@@ -2310,7 +2310,7 @@ def model_args(exclude_hybrid: bool = False) -> list[Argument]:
             [],
             [modifier_variant_type_args()],
             optional=True,
-            doc=doc_only_tf_supported + doc_modifier,
+            doc=doc_modifier,
         ),
         Argument(
             "compress",
diff --git a/doc/model/dplr.md b/doc/model/dplr.md
index 61327bb55e..66b385f155 100644
--- a/doc/model/dplr.md
+++ b/doc/model/dplr.md
@@ -1,10 +1,10 @@
-# Deep potential long-range (DPLR) {{ tensorflow_icon }}
+# Deep potential long-range (DPLR) {{ tensorflow_icon }} {{ pytorch_icon }}
 
 :::{note}
-**Supported backends**: TensorFlow {{ tensorflow_icon }}
+**Supported backends**: TensorFlow {{ tensorflow_icon }}, PyTorch {{ pytorch_icon }}
 :::
 
-Notice: **The interfaces of DPLR are not stable and subject to change**
+Notice: **The interfaces of DPLR are not stable and subject to change. In addition, the DP/LAMMPS interface does not yet support PyTorch DPLR models.**
 
 The method of DPLR is described in [this paper][1]. One is recommended to read the paper before using the DPLR.
 
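With the `doc_only_tf_supported` prefix dropped from `argcheck`, the same `modifier` section now validates for the PyTorch backend. A sketch of that input block (keys follow `DipoleChargeModifier.__init__`; the frozen dipole model file name is a placeholder):

```python
import json

modifier = {
    "type": "dipole_charge",
    "model_name": "dw_model.pth",  # frozen DeepDipole model (placeholder)
    "model_charge_map": [-8.0],
    "sys_charge_map": [6.0, 1.0],
    "ewald_h": 1.0,
    "ewald_beta": 1.0,
}
print(json.dumps({"model": {"modifier": modifier}}, indent=4))
```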
diff --git a/pyproject.toml b/pyproject.toml
index 9c6f213cfd..ae6e3eb248 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -140,6 +140,9 @@ jax = [
     # The pinning of ml_dtypes may conflict with TF
     # 'jax-ai-stack;python_version>="3.10"',
 ]
+torch = [
+    "torch-admp",
+]
 
 [tool.deepmd_build_backend.scripts]
 dp = "deepmd.main:main"
@@ -165,9 +168,11 @@ pin_pytorch_cpu = [
     # macos x86 has been deprecated
     "torch>=2.8,<2.10; platform_machine!='x86_64' or platform_system != 'Darwin'",
     "torch; platform_machine=='x86_64' and platform_system == 'Darwin'",
+    "torch-admp==1.1.3",
 ]
 pin_pytorch_gpu = [
     "torch>=2.7,<2.10",
+    "torch-admp==1.1.3",
 ]
 pin_jax = [
     "jax==0.5.0;python_version>='3.10'",
diff --git a/source/tests/pt/modifier/__init__.py b/source/tests/pt/modifier/__init__.py
new file mode 100644
index 0000000000..6ceb116d85
--- /dev/null
+++ b/source/tests/pt/modifier/__init__.py
@@ -0,0 +1 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
diff --git a/source/tests/pt/test_data_modifier.py b/source/tests/pt/modifier/test_data_modifier.py
similarity index 89%
rename from source/tests/pt/test_data_modifier.py
rename to source/tests/pt/modifier/test_data_modifier.py
index 18d66ef2ff..ddba54cd05 100644
--- a/source/tests/pt/test_data_modifier.py
+++ b/source/tests/pt/modifier/test_data_modifier.py
@@ -36,12 +36,18 @@
     freeze,
     get_trainer,
 )
+from deepmd.pt.model.model import (
+    EnergyModel,
+)
 from deepmd.pt.modifier.base_modifier import (
     BaseModifier,
 )
 from deepmd.pt.utils import (
     env,
 )
+from deepmd.pt.utils.serialization import (
+    serialize_from_file,
+)
 from deepmd.pt.utils.utils import (
     to_numpy_array,
 )
@@ -52,7 +58,7 @@
     DeepmdData,
 )
 
-from ..consistent.common import (
+from ...consistent.common import (
     parameterized,
 )
 
@@ -85,7 +91,9 @@ def modifier_scaling_tester() -> list[Argument]:
     doc_sfactor = "The scaling factor for correction."
     doc_use_cache = "Whether to cache modified frames to improve performance by avoiding recomputation."
     return [
-        Argument("model_name", str, optional=False, doc=doc_model_name),
+        Argument(
+            "model_name", str, alias=["model"], optional=False, doc=doc_model_name
+        ),
         Argument("sfactor", float, optional=False, doc=doc_sfactor),
        Argument("use_cache", bool, optional=True, doc=doc_use_cache),
     ]
@@ -181,21 +189,69 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N
 
 @BaseModifier.register("scaling_tester")
 class ModifierScalingTester(BaseModifier):
-    def __new__(cls, *args, **kwargs):
-        return super().__new__(cls)
+    def __new__(
+        cls,
+        *args: tuple,
+        model: str | None = None,
+        model_name: str | None = None,
+        **kwargs: dict,
+    ) -> "ModifierScalingTester":
+        return super().__new__(cls, model_name if model_name is not None else model)
 
     def __init__(
         self,
-        model_name: str,
+        model: torch.nn.Module | None = None,
+        model_name: str | None = None,
         sfactor: float = 1.0,
         use_cache: bool = True,
     ) -> None:
         """Initialize a test modifier that applies scaled model predictions using a frozen model."""
         super().__init__(use_cache)
         self.modifier_type = "scaling_tester"
-        self.model_name = model_name
         self.sfactor = sfactor
-        self.model = torch.jit.load(model_name, map_location=env.DEVICE)
+
+        if model_name is None and model is None:
+            raise AttributeError("`model_name` or `model` should be specified.")
+        if model_name is not None and model is not None:
+            raise AttributeError(
+                "`model_name` and `model` cannot be used simultaneously."
+ ) + + if model is not None: + self._model = model.to(env.DEVICE) + if model_name is not None: + data = serialize_from_file(model_name) + self._model = EnergyModel.deserialize(data["model"]).to(env.DEVICE) + + # use jit model for inference + self.model = torch.jit.script(self._model) + + def serialize(self) -> dict: + """Serialize the modifier. + + Returns + ------- + dict + The serialized data + """ + dd = BaseModifier.serialize(self) + dd.update( + { + "model": self._model.serialize(), + "sfactor": self.sfactor, + } + ) + return dd + + @classmethod + def deserialize(cls, data: dict) -> "ModifierScalingTester": + data = data.copy() + data.pop("@class", None) + data.pop("type", None) + data.pop("@version", None) + model_obj = EnergyModel.deserialize(data.pop("model")) + data["model"] = model_obj + return cls(**data) def forward( self, diff --git a/source/tests/pt/modifier/test_dipole_charge.py b/source/tests/pt/modifier/test_dipole_charge.py new file mode 100644 index 0000000000..6c08f5cd1c --- /dev/null +++ b/source/tests/pt/modifier/test_dipole_charge.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +import os +import tempfile +import unittest +from pathlib import ( + Path, +) + +import numpy as np +import torch + +from deepmd.entrypoints.convert_backend import ( + convert_backend, +) +from deepmd.pt.entrypoints.main import ( + freeze, + get_trainer, +) +from deepmd.pt.modifier import DipoleChargeModifier as PTDipoleChargeModifier +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.tf.modifier import DipoleChargeModifier as TFDipoleChargeModifier + +from ...seed import ( + GLOBAL_SEED, +) + + +def ref_data(): + all_box = np.load(str(Path(__file__).parent / "water/data/data_0/set.000/box.npy")) + all_coord = np.load( + str(Path(__file__).parent / "water/data/data_0/set.000/coord.npy") + ) + nframe = len(all_box) + rng = np.random.default_rng(GLOBAL_SEED) + selected_id = rng.integers(nframe) + + coord = all_coord[selected_id].reshape(1, -1) + box = all_box[selected_id].reshape(1, -1) + atype = np.loadtxt( + str(Path(__file__).parent / "water/data/data_0/type.raw"), + dtype=int, + ).reshape(1, -1) + return coord, box, atype + + +class TestDipoleChargeModifier(unittest.TestCase): + def setUp(self) -> None: + self.test_dir = tempfile.TemporaryDirectory() + self.orig_dir = os.getcwd() + os.chdir(self.test_dir.name) + # setup parameter + # numerical consistency can only be achieved with high prec + self.ewald_h = 0.1 + self.ewald_beta = 0.5 + self.model_charge_map = [-8.0] + self.sys_charge_map = [6.0, 1.0] + self.descriptor_dict = { + "type": "se_e2_a", + "sel": [12, 24], + "rcut_smth": 0.5, + "rcut": 4.00, + "neuron": [6, 12, 24], + } + + # Train DW model + input_json = str(Path(__file__).parent / "water_tensor/se_e2_a.json") + with open(input_json, encoding="utf-8") as f: + config = json.load(f) + config["model"]["descriptor"].update(self.descriptor_dict) + config["training"]["numb_steps"] = 1 + config["training"]["save_freq"] = 1 + config["learning_rate"]["start_lr"] = 1.0 + config["training"]["training_data"]["systems"] = [ + str(Path(__file__).parent / "water_tensor/dipole/O78H156"), + ] + config["training"]["validation_data"]["systems"] = [ + str(Path(__file__).parent / "water_tensor/dipole/O96H192") + ] + + trainer = get_trainer(config) + trainer.run() + freeze( + model="model.ckpt.pt", + output="dw_model.pth", + head=None, + ) + # Convert pb model to pth model + 
+        convert_backend(INPUT="dw_model.pth", OUTPUT="dw_model.pb")
+
+        self.dm_pt = PTDipoleChargeModifier(
+            "dw_model.pth",
+            self.model_charge_map,
+            self.sys_charge_map,
+            self.ewald_h,
+            self.ewald_beta,
+        )
+        self.dm_tf = TFDipoleChargeModifier(
+            "dw_model.pb",
+            self.model_charge_map,
+            self.sys_charge_map,
+            self.ewald_h,
+            self.ewald_beta,
+        )
+
+    def test_jit(self):
+        torch.jit.script(self.dm_pt)
+
+    def test_consistency(self):
+        coord, box, atype = ref_data()
+
+        pt_data = self.dm_pt.eval_np(
+            coord=coord,
+            atype=atype,
+            box=box,
+        )
+        tf_data = self.dm_tf.eval(
+            coord=coord,
+            box=box,
+            atype=atype.reshape(-1),
+        )
+        output_names = ["energy", "force", "virial"]
+        for ii, name in enumerate(output_names):
+            np.testing.assert_allclose(
+                pt_data[ii].reshape(-1),
+                tf_data[ii].reshape(-1),
+                atol=1e-6,
+                rtol=1e-6,
+                err_msg=f"Mismatch in {name}",
+            )
+
+    def test_serialize(self):
+        """Test the serialize method of DipoleChargeModifier."""
+        coord, box, atype = ref_data()
+        # consistent with the input shape from BaseModifier.modify_data
+        t_coord = (
+            to_torch_tensor(coord).to(env.GLOBAL_PT_FLOAT_PRECISION).reshape(1, -1, 3)
+        )
+        t_box = to_torch_tensor(box).to(env.GLOBAL_PT_FLOAT_PRECISION).reshape(1, 3, 3)
+        t_atype = to_torch_tensor(atype).to(torch.long).reshape(1, -1)
+
+        dm0 = self.dm_pt.to(env.DEVICE)
+        dm1 = PTDipoleChargeModifier.deserialize(dm0.serialize()).to(env.DEVICE)
+
+        ret0 = dm0(
+            coord=t_coord,
+            atype=t_atype,
+            box=t_box,
+        )
+        ret1 = dm1(
+            coord=t_coord,
+            atype=t_atype,
+            box=t_box,
+        )
+
+        np.testing.assert_allclose(
+            to_numpy_array(ret0["energy"]), to_numpy_array(ret1["energy"])
+        )
+        np.testing.assert_allclose(
+            to_numpy_array(ret0["force"]), to_numpy_array(ret1["force"])
+        )
+        np.testing.assert_allclose(
+            to_numpy_array(ret0["virial"]), to_numpy_array(ret1["virial"])
+        )
+
+    def test_box_none_error(self):
+        """Test that a RuntimeError is raised when box is None."""
+        coord, _b, atype = ref_data()
+        # consistent with the input shape from BaseModifier.modify_data
+        t_coord = (
+            to_torch_tensor(coord).to(env.GLOBAL_PT_FLOAT_PRECISION).reshape(1, -1, 3)
+        )
+        t_atype = to_torch_tensor(atype).to(torch.long).reshape(1, -1)
+
+        with self.assertRaises(RuntimeError) as context:
+            self.dm_pt(
+                coord=t_coord,
+                atype=t_atype,
+                box=None,  # Pass None to trigger the error
+            )
+
+        self.assertIn(
+            "dipole_charge data modifier can only be applied for periodic systems",
+            str(context.exception),
+        )
+
+    def test_train(self):
+        input_json = str(Path(__file__).parent / "water/se_e2_a.json")
+        with open(input_json, encoding="utf-8") as f:
+            config = json.load(f)
+        config["model"]["descriptor"].update(self.descriptor_dict)
+        config["training"]["save_freq"] = 1
+        config["learning_rate"]["start_lr"] = 1.0
+        config["training"]["training_data"]["systems"] = [
+            str(Path(__file__).parent / "water/data/data_0"),
+            str(Path(__file__).parent / "water/data/data_1"),
+        ]
+        config["training"]["validation_data"]["systems"] = [
+            str(Path(__file__).parent / "water/data/single"),
+        ]
+        config["training"]["numb_steps"] = 1
+
+        trainer = get_trainer(config)
+        trainer.run()
+        # Verify model checkpoint was created
+        self.assertTrue(
+            Path("model.ckpt.pt").exists(),
+            "Training should produce a model checkpoint",
+        )
+
+    def tearDown(self) -> None:
+        os.chdir(self.orig_dir)
+        self.test_dir.cleanup()
diff --git a/source/tests/pt/modifier/water b/source/tests/pt/modifier/water
new file mode 120000
index 0000000000..b4ce4e224a
--- /dev/null
+++ b/source/tests/pt/modifier/water
@@ -0,0 +1 @@
+../model/water/
\ No newline at end of file
diff --git a/source/tests/pt/modifier/water_tensor b/source/tests/pt/modifier/water_tensor
new file mode 120000
index 0000000000..a8c63dbb30
--- /dev/null
+++ b/source/tests/pt/modifier/water_tensor
@@ -0,0 +1 @@
+../water_tensor/
\ No newline at end of file
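For reference, the contract these tests exercise, written out as a hedged sketch (it assumes `BaseModifier.register` and `get_class_by_type` behave as they are used in `deep_eval.py`):

```python
import pickle

from deepmd.pt.modifier import BaseModifier


def roundtrip(modifier: BaseModifier) -> BaseModifier:
    # What freeze() stores in the .pth extra files ...
    data = pickle.loads(pickle.dumps(modifier.serialize()))
    # ... and how DeepEval rebuilds the modifier from it.
    return BaseModifier.get_class_by_type(data["type"]).deserialize(data)
```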