Skip to content

Commit 5c32147

Browse files
Feat: Add consistency test for ZBL between dp and pt (#4292)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes - **New Features** - Introduced `DPZBLModel`, enhancing energy modeling capabilities. - Added `get_zbl_model` function for creating `DPZBLModel` from input data. - New `DPZBLLinearEnergyAtomicModel` class allows for complex interactions between atomic models. - **Bug Fixes** - Corrected typographical errors in multiple test classes to improve code clarity and consistency in method names. - Updated model type attributes for `DPZBLModel` and `LinearEnergyModel` to reflect accurate classifications. - **Tests** - Added comprehensive unit tests for energy models to ensure functionality across various backends. - Enhanced existing test classes with corrected method names for improved accuracy. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a468819 commit 5c32147

File tree

14 files changed

+356
-11
lines changed

14 files changed

+356
-11
lines changed

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
)
3535

3636

37+
@BaseAtomicModel.register("linear")
3738
class LinearEnergyAtomicModel(BaseAtomicModel):
3839
"""Linear model make linear combinations of several existing models.
3940
@@ -324,6 +325,7 @@ def is_aparam_nall(self) -> bool:
324325
return False
325326

326327

328+
@BaseAtomicModel.register("zbl")
327329
class DPZBLLinearEnergyAtomicModel(LinearEnergyAtomicModel):
328330
"""Model linearly combine a list of AtomicModels.
329331
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from typing import (
3+
Optional,
4+
)
5+
6+
from deepmd.dpmodel.atomic_model.linear_atomic_model import (
7+
DPZBLLinearEnergyAtomicModel,
8+
)
9+
from deepmd.dpmodel.model.base_model import (
10+
BaseModel,
11+
)
12+
from deepmd.dpmodel.model.dp_model import (
13+
DPModelCommon,
14+
)
15+
from deepmd.utils.data_system import (
16+
DeepmdDataSystem,
17+
)
18+
19+
from .make_model import (
20+
make_model,
21+
)
22+
23+
DPZBLModel_ = make_model(DPZBLLinearEnergyAtomicModel)
24+
25+
26+
@BaseModel.register("zbl")
27+
class DPZBLModel(DPZBLModel_):
28+
model_type = "zbl"
29+
30+
def __init__(
31+
self,
32+
*args,
33+
**kwargs,
34+
):
35+
super().__init__(*args, **kwargs)
36+
37+
@classmethod
38+
def update_sel(
39+
cls,
40+
train_data: DeepmdDataSystem,
41+
type_map: Optional[list[str]],
42+
local_jdata: dict,
43+
) -> tuple[dict, Optional[float]]:
44+
"""Update the selection and perform neighbor statistics.
45+
46+
Parameters
47+
----------
48+
train_data : DeepmdDataSystem
49+
data used to do neighbor statistics
50+
type_map : list[str], optional
51+
The name of each type of atoms
52+
local_jdata : dict
53+
The local data refer to the current class
54+
55+
Returns
56+
-------
57+
dict
58+
The updated local data
59+
float
60+
The minimum distance between two atoms
61+
"""
62+
local_jdata_cpy = local_jdata.copy()
63+
local_jdata_cpy["dpmodel"], min_nbor_dist = DPModelCommon.update_sel(
64+
train_data, type_map, local_jdata["dpmodel"]
65+
)
66+
return local_jdata_cpy, min_nbor_dist

deepmd/dpmodel/model/model.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,13 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from deepmd.dpmodel.atomic_model.dp_atomic_model import (
3+
DPAtomicModel,
4+
)
5+
from deepmd.dpmodel.atomic_model.pairtab_atomic_model import (
6+
PairTabAtomicModel,
7+
)
8+
from deepmd.dpmodel.descriptor.base_descriptor import (
9+
BaseDescriptor,
10+
)
211
from deepmd.dpmodel.descriptor.se_e2_a import (
312
DescrptSeA,
413
)
@@ -8,6 +17,9 @@
817
from deepmd.dpmodel.model.base_model import (
918
BaseModel,
1019
)
20+
from deepmd.dpmodel.model.dp_zbl_model import (
21+
DPZBLModel,
22+
)
1123
from deepmd.dpmodel.model.ener_model import (
1224
EnergyModel,
1325
)
@@ -55,6 +67,45 @@ def get_standard_model(data: dict) -> EnergyModel:
5567
)
5668

5769

70+
def get_zbl_model(data: dict) -> DPZBLModel:
71+
data["descriptor"]["ntypes"] = len(data["type_map"])
72+
descriptor = BaseDescriptor(**data["descriptor"])
73+
fitting_type = data["fitting_net"].pop("type")
74+
if fitting_type == "ener":
75+
fitting = EnergyFittingNet(
76+
ntypes=descriptor.get_ntypes(),
77+
dim_descrpt=descriptor.get_dim_out(),
78+
mixed_types=descriptor.mixed_types(),
79+
**data["fitting_net"],
80+
)
81+
else:
82+
raise ValueError(f"Unknown fitting type {fitting_type}")
83+
84+
dp_model = DPAtomicModel(descriptor, fitting, type_map=data["type_map"])
85+
# pairtab
86+
filepath = data["use_srtab"]
87+
pt_model = PairTabAtomicModel(
88+
filepath,
89+
data["descriptor"]["rcut"],
90+
data["descriptor"]["sel"],
91+
type_map=data["type_map"],
92+
)
93+
94+
rmin = data["sw_rmin"]
95+
rmax = data["sw_rmax"]
96+
atom_exclude_types = data.get("atom_exclude_types", [])
97+
pair_exclude_types = data.get("pair_exclude_types", [])
98+
return DPZBLModel(
99+
dp_model,
100+
pt_model,
101+
rmin,
102+
rmax,
103+
type_map=data["type_map"],
104+
atom_exclude_types=atom_exclude_types,
105+
pair_exclude_types=pair_exclude_types,
106+
)
107+
108+
58109
def get_spin_model(data: dict) -> SpinModel:
59110
"""Get a spin model from a dictionary.
60111
@@ -100,6 +151,8 @@ def get_model(data: dict):
100151
if model_type == "standard":
101152
if "spin" in data:
102153
return get_spin_model(data)
154+
elif "use_srtab" in data:
155+
return get_zbl_model(data)
103156
else:
104157
return get_standard_model(data)
105158
else:

deepmd/pt/model/model/dp_linear_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
@BaseModel.register("linear_ener")
3232
class LinearEnergyModel(DPLinearModel_):
33-
model_type = "ener"
33+
model_type = "linear_ener"
3434

3535
def __init__(
3636
self,

deepmd/pt/model/model/dp_zbl_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
@BaseModel.register("zbl")
3232
class DPZBLModel(DPZBLModel_):
33-
model_type = "ener"
33+
model_type = "zbl"
3434

3535
def __init__(
3636
self,

source/tests/consistent/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
class CommonTest(ABC):
7676
data: ClassVar[dict]
7777
"""Arguments data."""
78-
addtional_data: ClassVar[dict] = {}
78+
additional_data: ClassVar[dict] = {}
7979
"""Additional data that will not be checked."""
8080
tf_class: ClassVar[Optional[type]]
8181
"""TensorFlow model class."""
@@ -128,7 +128,7 @@ def init_backend_cls(self, cls) -> Any:
128128

129129
def pass_data_to_cls(self, cls, data) -> Any:
130130
"""Pass data to the class."""
131-
return cls(**data, **self.addtional_data)
131+
return cls(**data, **self.additional_data)
132132

133133
@abstractmethod
134134
def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]:

source/tests/consistent/fitting/test_dipole.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def setUp(self):
104104
self.atype.sort()
105105

106106
@property
107-
def addtional_data(self) -> dict:
107+
def additional_data(self) -> dict:
108108
(
109109
resnet_dt,
110110
precision,

source/tests/consistent/fitting/test_dos.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def setUp(self):
124124
).reshape(-1, 1)
125125

126126
@property
127-
def addtional_data(self) -> dict:
127+
def additional_data(self) -> dict:
128128
(
129129
resnet_dt,
130130
precision,

source/tests/consistent/fitting/test_ener.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def setUp(self):
134134
).reshape(-1, 1)
135135

136136
@property
137-
def addtional_data(self) -> dict:
137+
def additional_data(self) -> dict:
138138
(
139139
resnet_dt,
140140
precision,

source/tests/consistent/fitting/test_polar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def setUp(self):
104104
self.atype.sort()
105105

106106
@property
107-
def addtional_data(self) -> dict:
107+
def additional_data(self) -> dict:
108108
(
109109
resnet_dt,
110110
precision,

0 commit comments

Comments
 (0)