Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 23 additions & 31 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import json
import logging
from collections.abc import (
Callable,
)
Expand Down Expand Up @@ -51,6 +52,8 @@
run_sess,
)

log = logging.getLogger(__name__)

if TYPE_CHECKING:
from pathlib import (
Path,
Expand Down Expand Up @@ -137,37 +140,25 @@ def __init__(
self.has_aparam = self.tensors["aparam"] is not None
self.has_spin = self.ntypes_spin > 0

# looks ugly...
if self.modifier_type == "dipole_charge":
from deepmd.tf.modifier import (
DipoleChargeModifier,
)
if kwargs.get("skip_modifier", False):
self.modifier_type = None

t_mdl_name = self._get_tensor("modifier_attr/mdl_name:0")
t_mdl_charge_map = self._get_tensor("modifier_attr/mdl_charge_map:0")
t_sys_charge_map = self._get_tensor("modifier_attr/sys_charge_map:0")
t_ewald_h = self._get_tensor("modifier_attr/ewald_h:0")
t_ewald_beta = self._get_tensor("modifier_attr/ewald_beta:0")
[mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(
self.sess,
[
t_mdl_name,
t_mdl_charge_map,
t_sys_charge_map,
t_ewald_h,
t_ewald_beta,
],
)
mdl_name = mdl_name.decode("UTF-8")
mdl_charge_map = [int(ii) for ii in mdl_charge_map.decode("UTF-8").split()]
sys_charge_map = [int(ii) for ii in sys_charge_map.decode("UTF-8").split()]
self.dm = DipoleChargeModifier(
mdl_name,
mdl_charge_map,
sys_charge_map,
ewald_h=ewald_h,
ewald_beta=ewald_beta,
)
from deepmd.tf.modifier import (
BaseModifier,
)

self.dm = None
if self.modifier_type is not None:
try:
modifier = BaseModifier.get_class_by_type(self.modifier_type)
modifier_params = modifier.get_params_from_frozen_model(self)
self.dm = modifier.get_modifier(modifier_params)
except Exception as exc:
raise RuntimeError(
f"Failed to load data modifier '{self.modifier_type}'. "
f"Use skip_modifier=True to load the model without the modifier. "
f"Error: {exc}"
) from exc

def _init_tensors(self) -> None:
tensor_names = {
Expand Down Expand Up @@ -684,7 +675,8 @@ def _get_natoms_and_nframes(
coords: np.ndarray,
atom_types: list[int] | np.ndarray,
) -> tuple[int, int]:
natoms = len(atom_types[0])
# (natoms,) or (nframes, natoms,)
natoms = np.shape(atom_types)[-1]
if natoms == 0:
assert coords.size == 0
else:
Expand Down
24 changes: 24 additions & 0 deletions deepmd/tf/modifier/base_modifier.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
abstractmethod,
)

from deepmd.dpmodel.modifier.base_modifier import (
make_base_modifier,
)
Expand All @@ -11,3 +15,23 @@ class BaseModifier(DeepPot, make_base_modifier()):
def __init__(self, *args, **kwargs) -> None:
"""Construct a basic model for different tasks."""
DeepPot.__init__(self, *args, **kwargs)

@staticmethod
@abstractmethod
def get_params_from_frozen_model(model) -> dict:
"""Extract the modifier parameters from a model.

This method should extract the necessary parameters from a model
to create an instance of this modifier.

Parameters
----------
model
The model from which to extract parameters

Returns
-------
dict
The modifier parameters
"""
pass
57 changes: 57 additions & 0 deletions deepmd/tf/modifier/dipole_charge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import os
from typing import (
TYPE_CHECKING,
)

import numpy as np

Expand Down Expand Up @@ -27,6 +30,11 @@
run_sess,
)

if TYPE_CHECKING:
from deepmd.tf.infer import (
DeepEval,
)


@BaseModifier.register("dipole_charge")
class DipoleChargeModifier(DeepDipole, BaseModifier):
Expand Down Expand Up @@ -487,3 +495,52 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
data["force"] -= tot_f.reshape(data["force"].shape)
if "find_virial" in data and data["find_virial"] == 1.0:
data["virial"] -= tot_v.reshape(data["virial"].shape)

@staticmethod
def get_params_from_frozen_model(model: "DeepEval") -> dict:
"""Extract modifier parameters from a DeepEval model.

Parameters
----------
model : DeepEval
The DeepEval model instance containing the modifier tensors.

Returns
-------
dict
Dictionary containing modifier parameters:
- model_name : str
- model_charge_map : list[int]
- sys_charge_map : list[int]
- ewald_h : float
- ewald_beta : float
"""
t_mdl_name = model._get_tensor("modifier_attr/mdl_name:0")
t_mdl_charge_map = model._get_tensor("modifier_attr/mdl_charge_map:0")
t_sys_charge_map = model._get_tensor("modifier_attr/sys_charge_map:0")
t_ewald_h = model._get_tensor("modifier_attr/ewald_h:0")
t_ewald_beta = model._get_tensor("modifier_attr/ewald_beta:0")
[mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(
model.sess,
[
t_mdl_name,
t_mdl_charge_map,
t_sys_charge_map,
t_ewald_h,
t_ewald_beta,
],
)
model_charge_map = [
int(float(ii)) for ii in mdl_charge_map.decode("UTF-8").split()
]
sys_charge_map = [
int(float(ii)) for ii in sys_charge_map.decode("UTF-8").split()
]
modifier_params = {
"model_name": mdl_name.decode("UTF-8"),
"model_charge_map": model_charge_map,
"sys_charge_map": sys_charge_map,
"ewald_h": ewald_h,
"ewald_beta": ewald_beta,
}
return modifier_params