Skip to content

Commit 3020ff7

Browse files
authored
fix(pt): get correct sel_type in pt model (#5097)
1. Allow backend-convert from pt to tf with `atom_exclude_types` 2. Fix bug in `get_sel_type` method of pt model (by calling `fitting.reinit_exclude` in atomic model init). Fix #5096 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Fitting now reinitializes exclusion/selection settings during model initialization and after type-map changes; deserialization injects selection info when atom exclusions exist to ensure correct selection behavior. * **Tests** * Added unit and cross-backend tests to verify atom-exclude → selection computation and parity between backends; expanded coverage for exclusion scenarios. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 4c8880c commit 3020ff7

File tree

7 files changed

+112
-19
lines changed

7 files changed

+112
-19
lines changed

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def __init__(
5151
self.type_map = type_map
5252
self.descriptor = descriptor
5353
self.fitting = fitting
54+
if hasattr(self.fitting, "reinit_exclude"):
55+
self.fitting.reinit_exclude(self.atom_exclude_types)
5456
self.type_map = type_map
5557
super().init_out_stat()
5658

@@ -191,7 +193,7 @@ def change_type_map(
191193
if model_with_new_type_stat is not None
192194
else None,
193195
)
194-
self.fitting_net.change_type_map(type_map=type_map)
196+
self.fitting.change_type_map(type_map=type_map)
195197

196198
def serialize(self) -> dict:
197199
dd = super().serialize()

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def __init__(
6464
self.rcut = self.descriptor.get_rcut()
6565
self.sel = self.descriptor.get_sel()
6666
self.fitting_net = fitting
67+
if hasattr(self.fitting_net, "reinit_exclude"):
68+
self.fitting_net.reinit_exclude(self.atom_exclude_types)
6769
super().init_out_stat()
6870
self.enable_eval_descriptor_hook = False
6971
self.enable_eval_fitting_last_layer_hook = False
@@ -151,6 +153,9 @@ def change_type_map(
151153
else None,
152154
)
153155
self.fitting_net.change_type_map(type_map=type_map)
156+
# Reinitialize fitting to get correct sel_type
157+
if hasattr(self.fitting_net, "reinit_exclude"):
158+
self.fitting_net.reinit_exclude(self.atom_exclude_types)
154159

155160
def has_message_passing(self) -> bool:
156161
"""Returns whether the atomic model has message passing."""

deepmd/tf/model/model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,16 +1003,23 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor":
10031003
check_version_compatibility(data.pop("@version", 2), 2, 1)
10041004
descriptor = Descriptor.deserialize(data.pop("descriptor"), suffix=suffix)
10051005
# bias_atom_e and out_bias are now completely independent - no conversion needed
1006-
fitting = Fitting.deserialize(data.pop("fitting"), suffix=suffix)
1006+
fitting_dict = data.pop("fitting", {})
1007+
atom_exclude_types = data.pop("atom_exclude_types", [])
1008+
if len(atom_exclude_types) > 0:
1009+
# get sel_type from complement of atom_exclude_types
1010+
full_type_list = np.arange(len(data["type_map"]), dtype=int)
1011+
sel_type = np.setdiff1d(
1012+
full_type_list, atom_exclude_types, assume_unique=True
1013+
)
1014+
fitting_dict["sel_type"] = sel_type.tolist()
1015+
fitting = Fitting.deserialize(fitting_dict, suffix=suffix)
10071016
# pass descriptor type embedding to model
10081017
if descriptor.explicit_ntypes:
10091018
type_embedding = descriptor.type_embedding
10101019
fitting.dim_descrpt -= type_embedding.neuron[-1]
10111020
else:
10121021
type_embedding = None
10131022
# BEGINE not supported keys
1014-
if len(data.pop("atom_exclude_types")) > 0:
1015-
raise NotImplementedError("atom_exclude_types is not supported")
10161023
if len(data.pop("pair_exclude_types")) > 0:
10171024
raise NotImplementedError("pair_exclude_types is not supported")
10181025
data.pop("rcond", None)

source/tests/consistent/model/test_dipole.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,15 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
204204
ret[1].ravel(),
205205
)
206206
raise ValueError(f"Unknown backend: {backend}")
207+
208+
def test_atom_exclude_types(self):
209+
if self.skip_pt:
210+
self.skipTest("Unsupported backend")
211+
if self.skip_tf:
212+
self.skipTest("Unsupported backend")
213+
_ret, data = self.get_reference_ret_serialization(self.RefBackend.PT)
214+
data["atom_exclude_types"] = [1]
215+
self.reset_unique_id()
216+
tf_obj = self.tf_class.deserialize(data, suffix=self.unique_id)
217+
pt_obj = self.pt_class.deserialize(data)
218+
self.assertEqual(tf_obj.get_sel_type(), pt_obj.get_sel_type())

source/tests/consistent/model/test_polar.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,15 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]:
198198
ret[1].ravel(),
199199
)
200200
raise ValueError(f"Unknown backend: {backend}")
201+
202+
def test_atom_exclude_types(self):
203+
if self.skip_pt:
204+
self.skipTest("Unsupported backend")
205+
if self.skip_tf:
206+
self.skipTest("Unsupported backend")
207+
_ret, data = self.get_reference_ret_serialization(self.RefBackend.PT)
208+
data["atom_exclude_types"] = [1]
209+
self.reset_unique_id()
210+
tf_obj = self.tf_class.deserialize(data, suffix=self.unique_id)
211+
pt_obj = self.pt_class.deserialize(data)
212+
self.assertEqual(tf_obj.get_sel_type(), pt_obj.get_sel_type())

source/tests/pt/model/test_get_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ def test_model_attr(self) -> None:
6060
]
6161
},
6262
)
63+
full_type_list = np.arange(len(atomic_model.type_map), dtype=int)
64+
atom_exclude_types = np.setdiff1d(
65+
full_type_list,
66+
self.model.get_sel_type(),
67+
).tolist()
68+
self.assertEqual(atom_exclude_types, [1])
6369
self.assertEqual(atomic_model.atom_exclude_types, [1])
6470
self.assertEqual(atomic_model.pair_exclude_types, [[1, 2]])
6571

source/tests/universal/dpmodel/atomc_model/test_atomic_model.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import unittest
33

4+
import numpy as np
5+
46
from deepmd.dpmodel.atomic_model import (
57
DPAtomicModel,
68
DPZBLLinearEnergyAtomicModel,
@@ -72,6 +74,13 @@
7274
)
7375

7476

77+
def make_sel_type_from_atom_exclude_types(type_map, atom_exclude_types):
78+
"""Get sel_type from complement of atom_exclude_types."""
79+
full_type_list = np.arange(len(type_map), dtype=int)
80+
sel_type = np.setdiff1d(full_type_list, atom_exclude_types, assume_unique=True)
81+
return sel_type.tolist()
82+
83+
7584
@parameterized(
7685
des_parameterized=(
7786
(
@@ -85,6 +94,7 @@
8594
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
8695
), # descrpt_class_param & class
8796
((FittingParamEnergy, EnergyFittingNet),), # fitting_class_param & class
97+
([], [0]), # atom_exclude_types
8898
),
8999
fit_parameterized=(
90100
(
@@ -97,6 +107,7 @@
97107
(
98108
*[(param_func, EnergyFittingNet) for param_func in FittingParamEnergyList],
99109
), # fitting_class_param & class
110+
([], [0]), # atom_exclude_types
100111
),
101112
)
102113
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -128,16 +139,22 @@ def setUpClass(cls) -> None:
128139
**cls.input_dict_ft,
129140
)
130141
cls.module = DPAtomicModel(
131-
ds,
132-
ft,
133-
type_map=cls.expected_type_map,
142+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
134143
)
135144
cls.output_def = cls.module.atomic_output_def().get_data()
136145
cls.expected_has_message_passing = ds.has_message_passing()
137146
cls.expected_sel_type = ft.get_sel_type()
138147
cls.expected_dim_fparam = ft.get_dim_fparam()
139148
cls.expected_dim_aparam = ft.get_dim_aparam()
140149

150+
def test_sel_type_from_atom_exclude_types(self):
151+
self.assertEqual(
152+
make_sel_type_from_atom_exclude_types(
153+
self.expected_type_map, self.param[2]
154+
),
155+
self.expected_sel_type,
156+
)
157+
141158

142159
@parameterized(
143160
des_parameterized=(
@@ -152,6 +169,7 @@ def setUpClass(cls) -> None:
152169
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
153170
), # descrpt_class_param & class
154171
((FittingParamDos, DOSFittingNet),), # fitting_class_param & class
172+
([], [0]), # atom_exclude_types
155173
),
156174
fit_parameterized=(
157175
(
@@ -164,6 +182,7 @@ def setUpClass(cls) -> None:
164182
(
165183
*[(param_func, DOSFittingNet) for param_func in FittingParamDosList],
166184
), # fitting_class_param & class
185+
([], [0]), # atom_exclude_types
167186
),
168187
)
169188
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -195,16 +214,22 @@ def setUpClass(cls) -> None:
195214
**cls.input_dict_ft,
196215
)
197216
cls.module = DPAtomicModel(
198-
ds,
199-
ft,
200-
type_map=cls.expected_type_map,
217+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
201218
)
202219
cls.output_def = cls.module.atomic_output_def().get_data()
203220
cls.expected_has_message_passing = ds.has_message_passing()
204221
cls.expected_sel_type = ft.get_sel_type()
205222
cls.expected_dim_fparam = ft.get_dim_fparam()
206223
cls.expected_dim_aparam = ft.get_dim_aparam()
207224

225+
def test_sel_type_from_atom_exclude_types(self):
226+
self.assertEqual(
227+
make_sel_type_from_atom_exclude_types(
228+
self.expected_type_map, self.param[2]
229+
),
230+
self.expected_sel_type,
231+
)
232+
208233

209234
@parameterized(
210235
des_parameterized=(
@@ -216,6 +241,7 @@ def setUpClass(cls) -> None:
216241
(DescriptorParamHybridMixed, DescrptHybrid),
217242
), # descrpt_class_param & class
218243
((FittingParamDipole, DipoleFitting),), # fitting_class_param & class
244+
([], [0]), # atom_exclude_types
219245
),
220246
fit_parameterized=(
221247
(
@@ -226,6 +252,7 @@ def setUpClass(cls) -> None:
226252
(
227253
*[(param_func, DipoleFitting) for param_func in FittingParamDipoleList],
228254
), # fitting_class_param & class
255+
([], [0]), # atom_exclude_types
229256
),
230257
)
231258
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -258,16 +285,22 @@ def setUpClass(cls) -> None:
258285
**cls.input_dict_ft,
259286
)
260287
cls.module = DPAtomicModel(
261-
ds,
262-
ft,
263-
type_map=cls.expected_type_map,
288+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
264289
)
265290
cls.output_def = cls.module.atomic_output_def().get_data()
266291
cls.expected_has_message_passing = ds.has_message_passing()
267292
cls.expected_sel_type = ft.get_sel_type()
268293
cls.expected_dim_fparam = ft.get_dim_fparam()
269294
cls.expected_dim_aparam = ft.get_dim_aparam()
270295

296+
def test_sel_type_from_atom_exclude_types(self):
297+
self.assertEqual(
298+
make_sel_type_from_atom_exclude_types(
299+
self.expected_type_map, self.param[2]
300+
),
301+
self.expected_sel_type,
302+
)
303+
271304

272305
@parameterized(
273306
des_parameterized=(
@@ -279,6 +312,7 @@ def setUpClass(cls) -> None:
279312
(DescriptorParamHybridMixed, DescrptHybrid),
280313
), # descrpt_class_param & class
281314
((FittingParamPolar, PolarFitting),), # fitting_class_param & class
315+
([], [0]), # atom_exclude_types
282316
),
283317
fit_parameterized=(
284318
(
@@ -289,6 +323,7 @@ def setUpClass(cls) -> None:
289323
(
290324
*[(param_func, PolarFitting) for param_func in FittingParamPolarList],
291325
), # fitting_class_param & class
326+
([], [0]), # atom_exclude_types
292327
),
293328
)
294329
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -321,16 +356,22 @@ def setUpClass(cls) -> None:
321356
**cls.input_dict_ft,
322357
)
323358
cls.module = DPAtomicModel(
324-
ds,
325-
ft,
326-
type_map=cls.expected_type_map,
359+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
327360
)
328361
cls.output_def = cls.module.atomic_output_def().get_data()
329362
cls.expected_has_message_passing = ds.has_message_passing()
330363
cls.expected_sel_type = ft.get_sel_type()
331364
cls.expected_dim_fparam = ft.get_dim_fparam()
332365
cls.expected_dim_aparam = ft.get_dim_aparam()
333366

367+
def test_sel_type_from_atom_exclude_types(self):
368+
self.assertEqual(
369+
make_sel_type_from_atom_exclude_types(
370+
self.expected_type_map, self.param[2]
371+
),
372+
self.expected_sel_type,
373+
)
374+
334375

335376
@parameterized(
336377
des_parameterized=(
@@ -415,6 +456,7 @@ def setUpClass(cls) -> None:
415456
(DescriptorParamHybridMixedTTebd, DescrptHybrid),
416457
), # descrpt_class_param & class
417458
((FittingParamProperty, PropertyFittingNet),), # fitting_class_param & class
459+
([], [0]), # atom_exclude_types
418460
),
419461
fit_parameterized=(
420462
(
@@ -428,6 +470,7 @@ def setUpClass(cls) -> None:
428470
for param_func in FittingParamPropertyList
429471
],
430472
), # fitting_class_param & class
473+
([], [0]), # atom_exclude_types
431474
),
432475
)
433476
@unittest.skipIf(TEST_DEVICE != "cpu" and CI, "Only test on CPU.")
@@ -460,12 +503,18 @@ def setUpClass(cls) -> None:
460503
**cls.input_dict_ft,
461504
)
462505
cls.module = DPAtomicModel(
463-
ds,
464-
ft,
465-
type_map=cls.expected_type_map,
506+
ds, ft, type_map=cls.expected_type_map, atom_exclude_types=cls.param[2]
466507
)
467508
cls.output_def = cls.module.atomic_output_def().get_data()
468509
cls.expected_has_message_passing = ds.has_message_passing()
469510
cls.expected_sel_type = ft.get_sel_type()
470511
cls.expected_dim_fparam = ft.get_dim_fparam()
471512
cls.expected_dim_aparam = ft.get_dim_aparam()
513+
514+
def test_sel_type_from_atom_exclude_types(self):
515+
self.assertEqual(
516+
make_sel_type_from_atom_exclude_types(
517+
self.expected_type_map, self.param[2]
518+
),
519+
self.expected_sel_type,
520+
)

0 commit comments

Comments
 (0)