11# SPDX-License-Identifier: LGPL-3.0-or-later
22import unittest
33
4+ import numpy as np
5+
46from deepmd .dpmodel .atomic_model import (
57 DPAtomicModel ,
68 DPZBLLinearEnergyAtomicModel ,
7274)
7375
7476
77+ def make_sel_type_from_atom_exclude_types (type_map , atom_exclude_types ):
78+ """Get sel_type from complement of atom_exclude_types."""
79+ full_type_list = np .arange (len (type_map ), dtype = int )
80+ sel_type = np .setdiff1d (full_type_list , atom_exclude_types , assume_unique = True )
81+ return sel_type .tolist ()
82+
83+
7584@parameterized (
7685 des_parameterized = (
7786 (
8594 (DescriptorParamHybridMixedTTebd , DescrptHybrid ),
8695 ), # descrpt_class_param & class
8796 ((FittingParamEnergy , EnergyFittingNet ),), # fitting_class_param & class
97+ ([], [0 ]), # atom_exclude_types
8898 ),
8999 fit_parameterized = (
90100 (
97107 (
98108 * [(param_func , EnergyFittingNet ) for param_func in FittingParamEnergyList ],
99109 ), # fitting_class_param & class
110+ ([], [0 ]), # atom_exclude_types
100111 ),
101112)
102113@unittest .skipIf (TEST_DEVICE != "cpu" and CI , "Only test on CPU." )
@@ -128,16 +139,22 @@ def setUpClass(cls) -> None:
128139 ** cls .input_dict_ft ,
129140 )
130141 cls .module = DPAtomicModel (
131- ds ,
132- ft ,
133- type_map = cls .expected_type_map ,
142+ ds , ft , type_map = cls .expected_type_map , atom_exclude_types = cls .param [2 ]
134143 )
135144 cls .output_def = cls .module .atomic_output_def ().get_data ()
136145 cls .expected_has_message_passing = ds .has_message_passing ()
137146 cls .expected_sel_type = ft .get_sel_type ()
138147 cls .expected_dim_fparam = ft .get_dim_fparam ()
139148 cls .expected_dim_aparam = ft .get_dim_aparam ()
140149
150+ def test_sel_type_from_atom_exclude_types (self ):
151+ self .assertEqual (
152+ make_sel_type_from_atom_exclude_types (
153+ self .expected_type_map , self .param [2 ]
154+ ),
155+ self .expected_sel_type ,
156+ )
157+
141158
142159@parameterized (
143160 des_parameterized = (
@@ -152,6 +169,7 @@ def setUpClass(cls) -> None:
152169 (DescriptorParamHybridMixedTTebd , DescrptHybrid ),
153170 ), # descrpt_class_param & class
154171 ((FittingParamDos , DOSFittingNet ),), # fitting_class_param & class
172+ ([], [0 ]), # atom_exclude_types
155173 ),
156174 fit_parameterized = (
157175 (
@@ -164,6 +182,7 @@ def setUpClass(cls) -> None:
164182 (
165183 * [(param_func , DOSFittingNet ) for param_func in FittingParamDosList ],
166184 ), # fitting_class_param & class
185+ ([], [0 ]), # atom_exclude_types
167186 ),
168187)
169188@unittest .skipIf (TEST_DEVICE != "cpu" and CI , "Only test on CPU." )
@@ -195,16 +214,22 @@ def setUpClass(cls) -> None:
195214 ** cls .input_dict_ft ,
196215 )
197216 cls .module = DPAtomicModel (
198- ds ,
199- ft ,
200- type_map = cls .expected_type_map ,
217+ ds , ft , type_map = cls .expected_type_map , atom_exclude_types = cls .param [2 ]
201218 )
202219 cls .output_def = cls .module .atomic_output_def ().get_data ()
203220 cls .expected_has_message_passing = ds .has_message_passing ()
204221 cls .expected_sel_type = ft .get_sel_type ()
205222 cls .expected_dim_fparam = ft .get_dim_fparam ()
206223 cls .expected_dim_aparam = ft .get_dim_aparam ()
207224
225+ def test_sel_type_from_atom_exclude_types (self ):
226+ self .assertEqual (
227+ make_sel_type_from_atom_exclude_types (
228+ self .expected_type_map , self .param [2 ]
229+ ),
230+ self .expected_sel_type ,
231+ )
232+
208233
209234@parameterized (
210235 des_parameterized = (
@@ -216,6 +241,7 @@ def setUpClass(cls) -> None:
216241 (DescriptorParamHybridMixed , DescrptHybrid ),
217242 ), # descrpt_class_param & class
218243 ((FittingParamDipole , DipoleFitting ),), # fitting_class_param & class
244+ ([], [0 ]), # atom_exclude_types
219245 ),
220246 fit_parameterized = (
221247 (
@@ -226,6 +252,7 @@ def setUpClass(cls) -> None:
226252 (
227253 * [(param_func , DipoleFitting ) for param_func in FittingParamDipoleList ],
228254 ), # fitting_class_param & class
255+ ([], [0 ]), # atom_exclude_types
229256 ),
230257)
231258@unittest .skipIf (TEST_DEVICE != "cpu" and CI , "Only test on CPU." )
@@ -258,16 +285,22 @@ def setUpClass(cls) -> None:
258285 ** cls .input_dict_ft ,
259286 )
260287 cls .module = DPAtomicModel (
261- ds ,
262- ft ,
263- type_map = cls .expected_type_map ,
288+ ds , ft , type_map = cls .expected_type_map , atom_exclude_types = cls .param [2 ]
264289 )
265290 cls .output_def = cls .module .atomic_output_def ().get_data ()
266291 cls .expected_has_message_passing = ds .has_message_passing ()
267292 cls .expected_sel_type = ft .get_sel_type ()
268293 cls .expected_dim_fparam = ft .get_dim_fparam ()
269294 cls .expected_dim_aparam = ft .get_dim_aparam ()
270295
296+ def test_sel_type_from_atom_exclude_types (self ):
297+ self .assertEqual (
298+ make_sel_type_from_atom_exclude_types (
299+ self .expected_type_map , self .param [2 ]
300+ ),
301+ self .expected_sel_type ,
302+ )
303+
271304
272305@parameterized (
273306 des_parameterized = (
@@ -279,6 +312,7 @@ def setUpClass(cls) -> None:
279312 (DescriptorParamHybridMixed , DescrptHybrid ),
280313 ), # descrpt_class_param & class
281314 ((FittingParamPolar , PolarFitting ),), # fitting_class_param & class
315+ ([], [0 ]), # atom_exclude_types
282316 ),
283317 fit_parameterized = (
284318 (
@@ -289,6 +323,7 @@ def setUpClass(cls) -> None:
289323 (
290324 * [(param_func , PolarFitting ) for param_func in FittingParamPolarList ],
291325 ), # fitting_class_param & class
326+ ([], [0 ]), # atom_exclude_types
292327 ),
293328)
294329@unittest .skipIf (TEST_DEVICE != "cpu" and CI , "Only test on CPU." )
@@ -321,16 +356,22 @@ def setUpClass(cls) -> None:
321356 ** cls .input_dict_ft ,
322357 )
323358 cls .module = DPAtomicModel (
324- ds ,
325- ft ,
326- type_map = cls .expected_type_map ,
359+ ds , ft , type_map = cls .expected_type_map , atom_exclude_types = cls .param [2 ]
327360 )
328361 cls .output_def = cls .module .atomic_output_def ().get_data ()
329362 cls .expected_has_message_passing = ds .has_message_passing ()
330363 cls .expected_sel_type = ft .get_sel_type ()
331364 cls .expected_dim_fparam = ft .get_dim_fparam ()
332365 cls .expected_dim_aparam = ft .get_dim_aparam ()
333366
367+ def test_sel_type_from_atom_exclude_types (self ):
368+ self .assertEqual (
369+ make_sel_type_from_atom_exclude_types (
370+ self .expected_type_map , self .param [2 ]
371+ ),
372+ self .expected_sel_type ,
373+ )
374+
334375
335376@parameterized (
336377 des_parameterized = (
@@ -415,6 +456,7 @@ def setUpClass(cls) -> None:
415456 (DescriptorParamHybridMixedTTebd , DescrptHybrid ),
416457 ), # descrpt_class_param & class
417458 ((FittingParamProperty , PropertyFittingNet ),), # fitting_class_param & class
459+ ([], [0 ]), # atom_exclude_types
418460 ),
419461 fit_parameterized = (
420462 (
@@ -428,6 +470,7 @@ def setUpClass(cls) -> None:
428470 for param_func in FittingParamPropertyList
429471 ],
430472 ), # fitting_class_param & class
473+ ([], [0 ]), # atom_exclude_types
431474 ),
432475)
433476@unittest .skipIf (TEST_DEVICE != "cpu" and CI , "Only test on CPU." )
@@ -460,12 +503,18 @@ def setUpClass(cls) -> None:
460503 ** cls .input_dict_ft ,
461504 )
462505 cls .module = DPAtomicModel (
463- ds ,
464- ft ,
465- type_map = cls .expected_type_map ,
506+ ds , ft , type_map = cls .expected_type_map , atom_exclude_types = cls .param [2 ]
466507 )
467508 cls .output_def = cls .module .atomic_output_def ().get_data ()
468509 cls .expected_has_message_passing = ds .has_message_passing ()
469510 cls .expected_sel_type = ft .get_sel_type ()
470511 cls .expected_dim_fparam = ft .get_dim_fparam ()
471512 cls .expected_dim_aparam = ft .get_dim_aparam ()
513+
514+ def test_sel_type_from_atom_exclude_types (self ):
515+ self .assertEqual (
516+ make_sel_type_from_atom_exclude_types (
517+ self .expected_type_map , self .param [2 ]
518+ ),
519+ self .expected_sel_type ,
520+ )
0 commit comments