Skip to content

Commit 9b21f6f

Browse files
replace the old data type with the same name and throw warning (#541)
When a data type is unexpectedly registered twice, the behavior of some methods will be strange (for example, `append`). For this reason, the old one is replaced, and a warning is thrown. --------- Signed-off-by: Jinzhe Zeng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 037c2b8 commit 9b21f6f

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

dpdata/system.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# %%
22
import glob
33
import os
4+
import warnings
45
from copy import deepcopy
56
from typing import Any, Dict, Optional, Tuple, Union
67

@@ -963,7 +964,16 @@ def register_data_type(cls, *data_type: Tuple[DataType]):
963964
*data_type : tuple[DataType]
964965
data type to be regiestered
965966
"""
966-
cls.DTYPES = cls.DTYPES + tuple(data_type)
967+
all_dtypes = cls.DTYPES + tuple(data_type)
968+
dtypes_dict = {}
969+
for dt in all_dtypes:
970+
if dt.name in dtypes_dict:
971+
warnings.warn(
972+
f"Data type {dt.name} is registered twice; only the newly registered one will be used.",
973+
UserWarning,
974+
)
975+
dtypes_dict[dt.name] = dt
976+
cls.DTYPES = tuple(dtypes_dict.values())
967977

968978

969979
def get_cell_perturb_matrix(cell_pert_fraction):

tests/test_custom_data_type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
import dpdata
7+
from dpdata.data_type import Axis, DataType
78

89

910
class TestDeepmdLoadDumpComp(unittest.TestCase):
@@ -44,6 +45,14 @@ def test_from_deepmd_hdf5(self):
4445
x = dpdata.LabeledSystem("data_foo.h5", fmt="deepmd/hdf5")
4546
np.testing.assert_allclose(x.data["foo"], self.foo)
4647

48+
def test_duplicated_data_type(self):
49+
dt = DataType("foo", np.ndarray, (Axis.NFRAMES, 2, 4), required=False)
50+
n_dtypes_old = len(dpdata.LabeledSystem.DTYPES)
51+
with self.assertWarns(UserWarning):
52+
dpdata.LabeledSystem.register_data_type(dt)
53+
n_dtypes_new = len(dpdata.LabeledSystem.DTYPES)
54+
self.assertEqual(n_dtypes_old, n_dtypes_new)
55+
4756

4857
class TestDeepmdLoadDumpCompAny(unittest.TestCase):
4958
def setUp(self):

0 commit comments

Comments
 (0)