Skip to content

Commit 6ed3c44

Browse files
add a public API to register data types dynamically (#532)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e239d0b commit 6ed3c44

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

dpdata/system.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,17 @@ def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None):
954954
idx = pick_by_amber_mask(parm, maskstr)
955955
return self.pick_atom_idx(idx, nopbc=nopbc)
956956

957+
@classmethod
958+
def register_data_type(cls, *data_type: Tuple[DataType]):
959+
"""Register data type.
960+
961+
Parameters
962+
----------
963+
*data_type : tuple[DataType]
964+
data type to be regiestered
965+
"""
966+
cls.DTYPES = cls.DTYPES + tuple(data_type)
967+
957968

958969
def get_cell_perturb_matrix(cell_pert_fraction):
959970
if cell_pert_fraction < 0:
@@ -1599,9 +1610,9 @@ def to_format(self, *args, **kwargs):
15991610
setattr(MultiSystems, method, get_func(formatcls))
16001611

16011612
# at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized
1602-
System.DTYPES = System.DTYPES + get_data_types(labeled=False)
1603-
LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=False)
1604-
LabeledSystem.DTYPES = LabeledSystem.DTYPES + get_data_types(labeled=True)
1613+
System.register_data_type(*get_data_types(labeled=False))
1614+
LabeledSystem.register_data_type(*get_data_types(labeled=False))
1615+
LabeledSystem.register_data_type(*get_data_types(labeled=True))
16051616

16061617

16071618
add_format_methods()

0 commit comments

Comments
 (0)