Commit f736ab2

feat(tf): optimize data modifier calling in deepeval (#5120)
This PR includes refactoring changes to improve the modifier architecture in the DeepMD TensorFlow implementation.

## Modified Files

### 1. deepmd/tf/infer/deep_eval.py

- Refactored the modifier initialization logic to use a more generic approach
- Replaced the hardcoded DipoleChargeModifier initialization with dynamic class resolution via BaseModifier.get_class_by_type()
- Simplified modifier parameter extraction by moving it into the modifier class itself
- Fixed an atom_types reshaping issue in the _get_natoms_and_nframes method

### 2. deepmd/tf/modifier/base_modifier.py

- Declared an abstract static method get_params_from_frozen_model() that every modifier must implement

### 3. deepmd/tf/modifier/dipole_charge.py

- Added the static method get_params_from_frozen_model() to handle parameter extraction from the model
- This method encapsulates the logic for retrieving modifier-specific parameters from the model's tensors
- Improved charge-map parsing with proper type conversion

## Summary

These changes make the modifier system more extensible by:

- Introducing a plugin-like architecture for modifiers (a sketch follows below)
- Moving modifier-specific initialization logic to the respective modifier classes
- Making it easier to add new modifier types without modifying the core DeepEval class
- Fixing a bug in atom-type handling

The refactoring follows the DRY principle and improves code maintainability while preserving existing functionality.

## Summary by CodeRabbit

* **Refactor**
  * Switched to a unified, dynamic modifier system for modular handling.
  * Normalized atom-count derivation to support multi-frame and varied input shapes.
  * Added module-level logging and stricter error handling that prevents silent state mutation on modifier load failure.
* **New Features**
  * Added a standard API for extracting modifier parameters from a frozen model, giving consistent parameter retrieval across modifiers.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
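As a rough sketch of the plugin-like flow described above (the `ExampleModifier` class, the `"example"` type string, and its parameter dict are hypothetical, not part of this PR; `register`, `get_class_by_type`, `get_modifier`, and `get_params_from_frozen_model` are the hooks visible in the diffs below):

```python
# Hypothetical sketch only: "example" and ExampleModifier are invented names.
from deepmd.tf.modifier import BaseModifier


@BaseModifier.register("example")  # same decorator used by DipoleChargeModifier
class ExampleModifier(BaseModifier):
    @staticmethod
    def get_params_from_frozen_model(model) -> dict:
        # Read whatever modifier_attr/* tensors this modifier needs from the
        # frozen graph and return them as plain Python values.
        return {"example_param": 1.0}


# DeepEval no longer hardcodes a modifier class; it resolves one from the
# type string stored in the model (see deep_eval.py below):
# cls = BaseModifier.get_class_by_type(self.modifier_type)
# self.dm = cls.get_modifier(cls.get_params_from_frozen_model(self))
```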
1 parent dfeba54 commit f736ab2

File tree: 3 files changed, +104 −31 lines

deepmd/tf/infer/deep_eval.py

Lines changed: 23 additions & 31 deletions
```diff
@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import json
+import logging
 from collections.abc import (
     Callable,
 )
@@ -51,6 +52,8 @@
     run_sess,
 )

+log = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from pathlib import (
         Path,
@@ -137,37 +140,25 @@ def __init__(
         self.has_aparam = self.tensors["aparam"] is not None
         self.has_spin = self.ntypes_spin > 0

-        # looks ugly...
-        if self.modifier_type == "dipole_charge":
-            from deepmd.tf.modifier import (
-                DipoleChargeModifier,
-            )
+        if kwargs.get("skip_modifier", False):
+            self.modifier_type = None

-            t_mdl_name = self._get_tensor("modifier_attr/mdl_name:0")
-            t_mdl_charge_map = self._get_tensor("modifier_attr/mdl_charge_map:0")
-            t_sys_charge_map = self._get_tensor("modifier_attr/sys_charge_map:0")
-            t_ewald_h = self._get_tensor("modifier_attr/ewald_h:0")
-            t_ewald_beta = self._get_tensor("modifier_attr/ewald_beta:0")
-            [mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(
-                self.sess,
-                [
-                    t_mdl_name,
-                    t_mdl_charge_map,
-                    t_sys_charge_map,
-                    t_ewald_h,
-                    t_ewald_beta,
-                ],
-            )
-            mdl_name = mdl_name.decode("UTF-8")
-            mdl_charge_map = [int(ii) for ii in mdl_charge_map.decode("UTF-8").split()]
-            sys_charge_map = [int(ii) for ii in sys_charge_map.decode("UTF-8").split()]
-            self.dm = DipoleChargeModifier(
-                mdl_name,
-                mdl_charge_map,
-                sys_charge_map,
-                ewald_h=ewald_h,
-                ewald_beta=ewald_beta,
-            )
+        from deepmd.tf.modifier import (
+            BaseModifier,
+        )
+
+        self.dm = None
+        if self.modifier_type is not None:
+            try:
+                modifier = BaseModifier.get_class_by_type(self.modifier_type)
+                modifier_params = modifier.get_params_from_frozen_model(self)
+                self.dm = modifier.get_modifier(modifier_params)
+            except Exception as exc:
+                raise RuntimeError(
+                    f"Failed to load data modifier '{self.modifier_type}'. "
+                    f"Use skip_modifier=True to load the model without the modifier. "
+                    f"Error: {exc}"
+                ) from exc

     def _init_tensors(self) -> None:
         tensor_names = {
```
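A minimal usage sketch of the error path added in the hunk above, assuming the constructor accepts the frozen-model path and forwards extra keyword arguments (as `kwargs.get("skip_modifier", ...)` suggests); the model filename is hypothetical:

```python
from deepmd.tf.infer.deep_eval import DeepEval

try:
    # Loads the model and, if one is recorded in the graph, its data modifier.
    dp = DeepEval("frozen_model.pb")
except RuntimeError:
    # A broken modifier now raises instead of leaving self.dm half-set;
    # the model itself can still be loaded without it:
    dp = DeepEval("frozen_model.pb", skip_modifier=True)
```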
```diff
@@ -684,7 +675,8 @@ def _get_natoms_and_nframes(
         coords: np.ndarray,
         atom_types: list[int] | np.ndarray,
     ) -> tuple[int, int]:
-        natoms = len(atom_types[0])
+        # (natoms,) or (nframes, natoms,)
+        natoms = np.shape(atom_types)[-1]
         if natoms == 0:
             assert coords.size == 0
         else:
```
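The one-line shape fix above can be checked with plain NumPy: the old `len(atom_types[0])` assumed a nested (nframes, natoms) layout and broke on a flat (natoms,) list, while `np.shape(...)[-1]` handles both:

```python
import numpy as np

flat = [0, 0, 1]                      # (natoms,): one frame of per-atom types
framed = np.array([[0, 0, 1],
                   [0, 0, 1]])        # (nframes, natoms)

# New code: the last axis is natoms in either layout.
assert np.shape(flat)[-1] == 3
assert np.shape(framed)[-1] == 3

# Old code: len(atom_types[0]) works for `framed` but raises
# TypeError on `flat`, because flat[0] is the integer 0.
```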

deepmd/tf/modifier/base_modifier.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -1,4 +1,8 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from abc import (
+    abstractmethod,
+)
+
 from deepmd.dpmodel.modifier.base_modifier import (
     make_base_modifier,
 )
@@ -11,3 +15,23 @@ class BaseModifier(DeepPot, make_base_modifier()):
     def __init__(self, *args, **kwargs) -> None:
         """Construct a basic model for different tasks."""
         DeepPot.__init__(self, *args, **kwargs)
+
+    @staticmethod
+    @abstractmethod
+    def get_params_from_frozen_model(model) -> dict:
+        """Extract the modifier parameters from a model.
+
+        This method should extract the necessary parameters from a model
+        to create an instance of this modifier.
+
+        Parameters
+        ----------
+        model
+            The model from which to extract parameters
+
+        Returns
+        -------
+        dict
+            The modifier parameters
+        """
+        pass
```
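A self-contained demo of the decorator pattern used above: `abstractmethod` innermost and `staticmethod` outermost, which is the order Python requires for an abstract static method. The `Base`/`Concrete` names are illustrative, not DeepMD classes:

```python
from abc import ABC, abstractmethod


class Base(ABC):
    @staticmethod
    @abstractmethod  # abstractmethod must be the innermost decorator
    def get_params_from_frozen_model(model) -> dict: ...


class Concrete(Base):
    @staticmethod
    def get_params_from_frozen_model(model) -> dict:
        return {"name": "demo"}


print(Concrete.get_params_from_frozen_model(None))  # {'name': 'demo'}
# Instantiating a subclass that forgets the override raises TypeError.
```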

deepmd/tf/modifier/dipole_charge.py

Lines changed: 57 additions & 0 deletions
```diff
@@ -1,5 +1,8 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 import os
+from typing import (
+    TYPE_CHECKING,
+)

 import numpy as np

@@ -27,6 +30,11 @@
     run_sess,
 )

+if TYPE_CHECKING:
+    from deepmd.tf.infer import (
+        DeepEval,
+    )
+

 @BaseModifier.register("dipole_charge")
 class DipoleChargeModifier(DeepDipole, BaseModifier):
@@ -487,3 +495,52 @@ def modify_data(self, data: dict, data_sys: DeepmdData) -> None:
         data["force"] -= tot_f.reshape(data["force"].shape)
         if "find_virial" in data and data["find_virial"] == 1.0:
             data["virial"] -= tot_v.reshape(data["virial"].shape)
+
+    @staticmethod
+    def get_params_from_frozen_model(model: "DeepEval") -> dict:
+        """Extract modifier parameters from a DeepEval model.
+
+        Parameters
+        ----------
+        model : DeepEval
+            The DeepEval model instance containing the modifier tensors.
+
+        Returns
+        -------
+        dict
+            Dictionary containing modifier parameters:
+            - model_name : str
+            - model_charge_map : list[int]
+            - sys_charge_map : list[int]
+            - ewald_h : float
+            - ewald_beta : float
+        """
+        t_mdl_name = model._get_tensor("modifier_attr/mdl_name:0")
+        t_mdl_charge_map = model._get_tensor("modifier_attr/mdl_charge_map:0")
+        t_sys_charge_map = model._get_tensor("modifier_attr/sys_charge_map:0")
+        t_ewald_h = model._get_tensor("modifier_attr/ewald_h:0")
+        t_ewald_beta = model._get_tensor("modifier_attr/ewald_beta:0")
+        [mdl_name, mdl_charge_map, sys_charge_map, ewald_h, ewald_beta] = run_sess(
+            model.sess,
+            [
+                t_mdl_name,
+                t_mdl_charge_map,
+                t_sys_charge_map,
+                t_ewald_h,
+                t_ewald_beta,
+            ],
+        )
+        model_charge_map = [
+            int(float(ii)) for ii in mdl_charge_map.decode("UTF-8").split()
+        ]
+        sys_charge_map = [
+            int(float(ii)) for ii in sys_charge_map.decode("UTF-8").split()
+        ]
+        modifier_params = {
+            "model_name": mdl_name.decode("UTF-8"),
+            "model_charge_map": model_charge_map,
+            "sys_charge_map": sys_charge_map,
+            "ewald_h": ewald_h,
+            "ewald_beta": ewald_beta,
+        }
+        return modifier_params
```
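The `int(float(ii))` round-trip above is the "proper type conversion" the commit message mentions: a charge map serialized with float formatting (the raw string below is a hypothetical example) would crash the bare `int()` used in the old deep_eval.py code:

```python
raw = "1.0 -2.0"  # hypothetical serialized charge map

# Former parsing fails on float-formatted entries:
# [int(ii) for ii in raw.split()]  -> ValueError: invalid literal for int()

# The new parsing accepts both "1" and "1.0":
charges = [int(float(ii)) for ii in raw.split()]
assert charges == [1, -2]
```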
