Skip to content

Commit d9f1b6d

Browse files
committed
add fitting.reinit_exclude for dpmodel; add UT
1 parent 4e4274c commit d9f1b6d

File tree

2 files changed

+67
-16
lines changed

2 files changed

+67
-16
lines changed

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(
5151
self.type_map = type_map
5252
self.descriptor = descriptor
5353
self.fitting = fitting
54+
if hasattr(self.fitting, "reinit_exclude"):
55+
self.fitting.reinit_exclude(self.atom_exclude_types)
5456
self.type_map = type_map
5557
super().init_out_stat()
5658

@@ -191,7 +193,7 @@ def change_type_map(
191193
if model_with_new_type_stat is not None
192194
else None,
193195
)
194-
self.fitting_net.change_type_map(type_map=type_map)
196+
self.fitting.change_type_map(type_map=type_map)
195197

196198
def serialize(self) -> dict:
197199
dd = super().serialize()

source/tests/universal/dpmodel/atomc_model/test_atomic_model.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import unittest
33

4+
import numpy as np
5+
46
from deepmd.dpmodel.atomic_model import (
57
DPAtomicModel,
68
DPZBLLinearEnergyAtomicModel,
@@ -72,6 +74,13 @@
7274
)
7375

7476

77+
def make_sel_type_from_atom_exclude_types(type_map, atom_exclude_types):
78+
"""Get sel_type from complement of atom_exclude_types."""
79+
full_type_list = np.arange(len(type_map), dtype=int)
80+
sel_type = np.setdiff1d(full_type_list, atom_exclude_types, assume_unique=True)
81+
return sel_type.tolist()
82+
83+
7584
@parameterized(
7685
des_parameterized=(
7786
(
@@ -85,6 +94,7 @@
8594
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
8695
), # descrpt_class_param & class
8796
((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class
97+
([], [0]), # atom_exclude_types
8898
),
8999
fit_parameterized=(
90100
(
@@ -97,6 +107,7 @@
97107
(
98108
*[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],
99109
), # fitting_class_param & class
110+
([], [0]), # atom_exclude_types
100111
),
101112
)
102113
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -128,16 +139,22 @@ def setUpClass(cls) -> None:
128139
**cls.input_dict_ft,
129140
)
130141
cls.module = DPAtomicModel(
131-
ds,
132-
ft,
133-
type_map=cls.expected_type_map,
142+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
134143
)
135144
cls.output_def = cls.module.atomic_output_def().get_data()
136145
cls.expected_has_message_passing = ds.has_message_passing()
137146
cls.expected_sel_type = ft.get_sel_type()
138147
cls.expected_dim_fparam = ft.get_dim_fparam()
139148
cls.expected_dim_aparam = ft.get_dim_aparam()
140149

150+
def test_sel_type_from_atom_exclude_types(self):
151+
self.assertEqual(
152+
make_sel_type_from_atom_exclude_types(
153+
self.expected_type_map, self.param[2]
154+
),
155+
self.expected_sel_type,
156+
)
157+
141158

142159
@parameterized(
143160
des_parameterized=(
@@ -152,6 +169,7 @@ def setUpClass(cls) -> None:
152169
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
153170
), # descrpt_class_param & class
154171
((FittingParamDos, DOSFittingNet),), # fitting_class_param & class
172+
([], [0]), # atom_exclude_types
155173
),
156174
fit_parameterized=(
157175
(
@@ -164,6 +182,7 @@ def setUpClass(cls) -> None:
164182
(
165183
*[(param_func, DOSFittingNet) for param_func in FittingParamDosList],
166184
), # fitting_class_param & class
185+
([], [0]), # atom_exclude_types
167186
),
168187
)
169188
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -195,16 +214,22 @@ def setUpClass(cls) -> None:
195214
**cls.input_dict_ft,
196215
)
197216
cls.module = DPAtomicModel(
198-
ds,
199-
ft,
200-
type_map=cls.expected_type_map,
217+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
201218
)
202219
cls.output_def = cls.module.atomic_output_def().get_data()
203220
cls.expected_has_message_passing = ds.has_message_passing()
204221
cls.expected_sel_type = ft.get_sel_type()
205222
cls.expected_dim_fparam = ft.get_dim_fparam()
206223
cls.expected_dim_aparam = ft.get_dim_aparam()
207224

225+
def test_sel_type_from_atom_exclude_types(self):
226+
self.assertEqual(
227+
make_sel_type_from_atom_exclude_types(
228+
self.expected_type_map, self.param[2]
229+
),
230+
self.expected_sel_type,
231+
)
232+
208233

209234
@parameterized(
210235
des_parameterized=(
@@ -216,6 +241,7 @@ def setUpClass(cls) -> None:
216241
(DescriptorParamHybridMixed, DescrptHybrid),
217242
), # descrpt_class_param & class
218243
((FittingParamDipole, DipoleFitting),), # fitting_class_param & class
244+
([], [0]), # atom_exclude_types
219245
),
220246
fit_parameterized=(
221247
(
@@ -226,6 +252,7 @@ def setUpClass(cls) -> None:
226252
(
227253
*[(param_func, DipoleFitting) for param_func in FittingParamDipoleList],
228254
), # fitting_class_param & class
255+
([], [0]), # atom_exclude_types
229256
),
230257
)
231258
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -258,16 +285,22 @@ def setUpClass(cls) -> None:
258285
**cls.input_dict_ft,
259286
)
260287
cls.module = DPAtomicModel(
261-
ds,
262-
ft,
263-
type_map=cls.expected_type_map,
288+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
264289
)
265290
cls.output_def = cls.module.atomic_output_def().get_data()
266291
cls.expected_has_message_passing = ds.has_message_passing()
267292
cls.expected_sel_type = ft.get_sel_type()
268293
cls.expected_dim_fparam = ft.get_dim_fparam()
269294
cls.expected_dim_aparam = ft.get_dim_aparam()
270295

296+
def test_sel_type_from_atom_exclude_types(self):
297+
self.assertEqual(
298+
make_sel_type_from_atom_exclude_types(
299+
self.expected_type_map, self.param[2]
300+
),
301+
self.expected_sel_type,
302+
)
303+
271304

272305
@parameterized(
273306
des_parameterized=(
@@ -279,6 +312,7 @@ def setUpClass(cls) -> None:
279312
(DescriptorParamHybridMixed, DescrptHybrid),
280313
), # descrpt_class_param & class
281314
((FittingParamPolar, PolarFitting),), # fitting_class_param & class
315+
([], [0]), # atom_exclude_types
282316
),
283317
fit_parameterized=(
284318
(
@@ -289,6 +323,7 @@ def setUpClass(cls) -> None:
289323
(
290324
*[(param_func, PolarFitting) for param_func in FittingParamPolarList],
291325
), # fitting_class_param & class
326+
([], [0]), # atom_exclude_types
292327
),
293328
)
294329
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -321,16 +356,22 @@ def setUpClass(cls) -> None:
321356
**cls.input_dict_ft,
322357
)
323358
cls.module = DPAtomicModel(
324-
ds,
325-
ft,
326-
type_map=cls.expected_type_map,
359+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
327360
)
328361
cls.output_def = cls.module.atomic_output_def().get_data()
329362
cls.expected_has_message_passing = ds.has_message_passing()
330363
cls.expected_sel_type = ft.get_sel_type()
331364
cls.expected_dim_fparam = ft.get_dim_fparam()
332365
cls.expected_dim_aparam = ft.get_dim_aparam()
333366

367+
def test_sel_type_from_atom_exclude_types(self):
368+
self.assertEqual(
369+
make_sel_type_from_atom_exclude_types(
370+
self.expected_type_map, self.param[2]
371+
),
372+
self.expected_sel_type,
373+
)
374+
334375

335376
@parameterized(
336377
des_parameterized=(
@@ -415,6 +456,7 @@ def setUpClass(cls) -> None:
415456
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
416457
), # descrpt_class_param & class
417458
((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class
459+
([], [0]), # atom_exclude_types
418460
),
419461
fit_parameterized=(
420462
(
@@ -428,6 +470,7 @@ def setUpClass(cls) -> None:
428470
for param_func in FittingParamPropertyList
429471
],
430472
), # fitting_class_param & class
473+
([], [0]), # atom_exclude_types
431474
),
432475
)
433476
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -460,12 +503,18 @@ def setUpClass(cls) -> None:
460503
**cls.input_dict_ft,
461504
)
462505
cls.module = DPAtomicModel(
463-
ds,
464-
ft,
465-
type_map=cls.expected_type_map,
506+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
466507
)
467508
cls.output_def = cls.module.atomic_output_def().get_data()
468509
cls.expected_has_message_passing = ds.has_message_passing()
469510
cls.expected_sel_type = ft.get_sel_type()
470511
cls.expected_dim_fparam = ft.get_dim_fparam()
471512
cls.expected_dim_aparam = ft.get_dim_aparam()
513+
514+
def test_sel_type_from_atom_exclude_types(self):
515+
self.assertEqual(
516+
make_sel_type_from_atom_exclude_types(
517+
self.expected_type_map, self.param[2]
518+
),
519+
self.expected_sel_type,
520+
)

0 commit comments

Comments
 (0)