Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 10 additions & 31 deletions deepmd/tf/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,37 +137,15 @@ def __init__(
self.has_aparam = self.tensors["aparam"] is not None
self.has_spin = self.ntypes_spin > 0

# looks ugly...
if self.modifier_type == "dipole_charge":
from deepmd.tf.modifier import (
DipoleChargeModifier,
)
from deepmd.tf.modifier import (
BaseModifier,
)

t_mdl_name = self._get_tensor("modifier_attr/mdl_name:0")
t_mdl_charge_map = self._get_tensor("modifier_attr/mdl_charge_map:0")
t_sys_charge_map = self._get_tensor("modifier_attr/sys_charge_map:0")
t_ewald_h = self._get_tensor("modifier_attr/ewald_h:0")
t_ewald_beta = self._get_tensor("modifier_attr/ewald_beta:0")
[mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(
self.sess,
[
t_mdl_name,
t_mdl_charge_map,
t_sys_charge_map,
t_ewald_h,
t_ewald_beta,
],
)
mdl_name = mdl_name.decode("UTF-8")
mdl_charge_map = [int(ii) for ii in mdl_charge_map.decode("UTF-8").split()]
sys_charge_map = [int(ii) for ii in sys_charge_map.decode("UTF-8").split()]
self.dm = DipoleChargeModifier(
mdl_name,
mdl_charge_map,
sys_charge_map,
ewald_h=ewald_h,
ewald_beta=ewald_beta,
)
self.dm = None
if self.modifier_type is not None:
modifier = BaseModifier.get_class_by_type(self.modifier_type)
modifier_params = modifier.get_params(self)
self.dm = modifier.get_modifier(modifier_params)

def _init_tensors(self) -> None:
tensor_names = {
Expand Down Expand Up @@ -684,7 +662,8 @@ def _get_natoms_and_nframes(
coords: np.ndarray,
atom_types: list[int] | np.ndarray,
) -> tuple[int, int]:
natoms = len(atom_types[0])
atom_types = np.reshape(atom_types, (-1))
natoms = len(atom_types)
if natoms == 0:
assert coords.size == 0
else:
Expand Down
35 changes: 35 additions & 0 deletions deepmd/tf/modifier/dipole_charge.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
op_module,
tf,
)
from deepmd.tf.infer import (
DeepEval,
)
from deepmd.tf.infer.deep_dipole import DeepDipoleOld as DeepDipole
from deepmd.tf.infer.ewald_recp import (
EwaldRecp,
Expand Down Expand Up @@ -487,3 +490,35 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
data["force"] -= tot_f.reshape(data["force"].shape)
if "find_virial" in data and data["find_virial"] == 1.0:
data["virial"] -= tot_v.reshape(data["virial"].shape)

@staticmethod
def get_params(model: DeepEval):
t_mdl_name = model._get_tensor("modifier_attr/mdl_name:0")
t_mdl_charge_map = model._get_tensor("modifier_attr/mdl_charge_map:0")
t_sys_charge_map = model._get_tensor("modifier_attr/sys_charge_map:0")
t_ewald_h = model._get_tensor("modifier_attr/ewald_h:0")
t_ewald_beta = model._get_tensor("modifier_attr/ewald_beta:0")
[mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(
model.sess,
[
t_mdl_name,
t_mdl_charge_map,
t_sys_charge_map,
t_ewald_h,
t_ewald_beta,
],
)
model_charge_map = [
int(float(ii)) for ii in mdl_charge_map.decode("UTF-8").split()
]
sys_charge_map = [
int(float(ii)) for ii in sys_charge_map.decode("UTF-8").split()
]
modifier_params = {
"model_name": mdl_name.decode("UTF-8"),
"model_charge_map": model_charge_map,
"sys_charge_map": sys_charge_map,
"ewald_h": ewald_h,
"ewald_beta": ewald_beta,
}
return modifier_params
Loading