diff --git a/deepmd/tf/infer/deep_eval.py b/deepmd/tf/infer/deep_eval.py index d1eca83055..20f513cd75 100644 --- a/deepmd/tf/infer/deep_eval.py +++ b/deepmd/tf/infer/deep_eval.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import json +import logging from collections.abc import ( Callable, ) @@ -51,6 +52,8 @@ run_sess, ) +log = logging.getLogger(__name__) + if TYPE_CHECKING: from pathlib import ( Path, @@ -137,37 +140,25 @@ def __init__( self.has_aparam = self.tensors["aparam"] is not None self.has_spin = self.ntypes_spin > 0 - # looks ugly... - if self.modifier_type == "dipole_charge": - from deepmd.tf.modifier import ( - DipoleChargeModifier, - ) + if kwargs.get("skip_modifier", False): + self.modifier_type = None - t_mdl_name = self._get_tensor("modifier_attr/mdl_name:0") - t_mdl_charge_map = self._get_tensor("modifier_attr/mdl_charge_map:0") - t_sys_charge_map = self._get_tensor("modifier_attr/sys_charge_map:0") - t_ewald_h = self._get_tensor("modifier_attr/ewald_h:0") - t_ewald_beta = self._get_tensor("modifier_attr/ewald_beta:0") - [mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess( - self.sess, - [ - t_mdl_name, - t_mdl_charge_map, - t_sys_charge_map, - t_ewald_h, - t_ewald_beta, - ], - ) - mdl_name = mdl_name.decode("UTF-8") - mdl_charge_map = [int(ii) for ii in mdl_charge_map.decode("UTF-8").split()] - sys_charge_map = [int(ii) for ii in sys_charge_map.decode("UTF-8").split()] - self.dm = DipoleChargeModifier( - mdl_name, - mdl_charge_map, - sys_charge_map, - ewald_h=ewald_h, - ewald_beta=ewald_beta, - ) + from deepmd.tf.modifier import ( + BaseModifier, + ) + + self.dm = None + if self.modifier_type is not None: + try: + modifier = BaseModifier.get_class_by_type(self.modifier_type) + modifier_params = modifier.get_params_from_frozen_model(self) + self.dm = modifier.get_modifier(modifier_params) + except Exception as exc: + raise RuntimeError( + f"Failed to load data modifier '{self.modifier_type}'. " + f"Use skip_modifier=True to load the model without the modifier. " + f"Error: {exc}" + ) from exc def _init_tensors(self) -> None: tensor_names = { @@ -684,7 +675,8 @@ def _get_natoms_and_nframes( coords: np.ndarray, atom_types: list[int] | np.ndarray, ) -> tuple[int, int]: - natoms = len(atom_types[0]) + # (natoms,) or (nframes, natoms,) + natoms = np.shape(atom_types)[-1] if natoms == 0: assert coords.size == 0 else: diff --git a/deepmd/tf/modifier/base_modifier.py b/deepmd/tf/modifier/base_modifier.py index 4e214e0835..167811cd2a 100644 --- a/deepmd/tf/modifier/base_modifier.py +++ b/deepmd/tf/modifier/base_modifier.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from abc import ( + abstractmethod, +) + from deepmd.dpmodel.modifier.base_modifier import ( make_base_modifier, ) @@ -11,3 +15,23 @@ class BaseModifier(DeepPot, make_base_modifier()): def __init__(self, *args, **kwargs) -> None: """Construct a basic model for different tasks.""" DeepPot.__init__(self, *args, **kwargs) + + @staticmethod + @abstractmethod + def get_params_from_frozen_model(model) -> dict: + """Extract the modifier parameters from a model. + + This method should extract the necessary parameters from a model + to create an instance of this modifier. + + Parameters + ---------- + model + The model from which to extract parameters + + Returns + ------- + dict + The modifier parameters + """ + pass diff --git a/deepmd/tf/modifier/dipole_charge.py b/deepmd/tf/modifier/dipole_charge.py index d40c9ccd2f..a2d3efb353 100644 --- a/deepmd/tf/modifier/dipole_charge.py +++ b/deepmd/tf/modifier/dipole_charge.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import os +from typing import ( + TYPE_CHECKING, +) import numpy as np @@ -27,6 +30,11 @@ run_sess, ) +if TYPE_CHECKING: + from deepmd.tf.infer import ( + DeepEval, + ) + @BaseModifier.register("dipole_charge") class DipoleChargeModifier(DeepDipole, BaseModifier): @@ -487,3 +495,52 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None: data["force"] -= tot_f.reshape(data["force"].shape) if "find_virial" in data and data["find_virial"] == 1.0: data["virial"] -= tot_v.reshape(data["virial"].shape) + + @staticmethod + def get_params_from_frozen_model(model: "DeepEval") -> dict: + """Extract modifier parameters from a DeepEval model. + + Parameters + ---------- + model : DeepEval + The DeepEval model instance containing the modifier tensors. + + Returns + ------- + dict + Dictionary containing modifier parameters: + - model_name : str + - model_charge_map : list[int] + - sys_charge_map : list[int] + - ewald_h : float + - ewald_beta : float + """ + t_mdl_name = model._get_tensor("modifier_attr/mdl_name:0") + t_mdl_charge_map = model._get_tensor("modifier_attr/mdl_charge_map:0") + t_sys_charge_map = model._get_tensor("modifier_attr/sys_charge_map:0") + t_ewald_h = model._get_tensor("modifier_attr/ewald_h:0") + t_ewald_beta = model._get_tensor("modifier_attr/ewald_beta:0") + [mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess( + model.sess, + [ + t_mdl_name, + t_mdl_charge_map, + t_sys_charge_map, + t_ewald_h, + t_ewald_beta, + ], + ) + model_charge_map = [ + int(float(ii)) for ii in mdl_charge_map.decode("UTF-8").split() + ] + sys_charge_map = [ + int(float(ii)) for ii in sys_charge_map.decode("UTF-8").split() + ] + modifier_params = { + "model_name": mdl_name.decode("UTF-8"), + "model_charge_map": model_charge_map, + "sys_charge_map": sys_charge_map, + "ewald_h": ewald_h, + "ewald_beta": ewald_beta, + } + return modifier_params