Skip to content

Commit efa66a7

Browse files
fix: skip datatype registration warning for duplicate types (#856)
Currently dpdata emits a warning if an existing data type is registered again, even when it has the same definition. This PR adds a check that skips the warning if the registered data types are identical. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Improved object comparison and display for data types, allowing equality checks and clearer string representations. * **Bug Fixes** * Reduced unnecessary warnings when registering identical data types; warnings are now only shown if duplicate registrations differ. * **Tests** * Added tests for data type comparison, representation, and registration behavior. * Removed outdated test for duplicate data type registration. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chun Cai <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5aade0d commit efa66a7

File tree

3 files changed

+108
-10
lines changed

3 files changed

+108
-10
lines changed

dpdata/data_type.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,43 @@ def __init__(
6464
self.required = required
6565
self.deepmd_name = name if deepmd_name is None else deepmd_name
6666

67+
def __eq__(self, other) -> bool:
    """Check if two DataType instances are equal.

    Two instances are equal when their ``name``, ``dtype``, ``shape``,
    ``required`` and ``deepmd_name`` attributes all compare equal.

    Parameters
    ----------
    other : object
        object to compare with

    Returns
    -------
    bool
        True if equal, False otherwise. For objects that are not
        DataType instances ``NotImplemented`` is returned, so Python can
        try the reflected comparison; ``==`` then evaluates to False.
    """
    if not isinstance(other, DataType):
        return NotImplemented
    return (
        self.name == other.name
        and self.dtype == other.dtype
        and self.shape == other.shape
        and self.required == other.required
        and self.deepmd_name == other.deepmd_name
    )


def __hash__(self) -> int:
    """Return a hash consistent with ``__eq__``.

    A class that defines ``__eq__`` without ``__hash__`` becomes
    unhashable; defining it here keeps instances usable in sets and as
    dictionary keys.

    Returns
    -------
    int
        hash built from the same attributes that ``__eq__`` compares
    """
    # Normalise shape to a tuple so a list-valued shape does not make
    # hashing fail; equal shapes still produce equal hashes.
    shape = None if self.shape is None else tuple(self.shape)
    return hash((self.name, self.dtype, shape, self.required, self.deepmd_name))
89+
90+
def __repr__(self) -> str:
    """Return string representation of DataType.

    Returns
    -------
    str
        string representation listing ``name``, ``dtype``, ``shape``,
        ``required`` and ``deepmd_name``
    """
    # repr() is used while debugging and must not raise: fall back to
    # repr(dtype) when dtype has no __name__ (i.e. it is not a class).
    dtype_name = getattr(self.dtype, "__name__", repr(self.dtype))
    # !r quotes the string fields and escapes embedded quotes correctly.
    return (
        f"DataType(name={self.name!r}, dtype={dtype_name}, "
        f"shape={self.shape}, required={self.required}, "
        f"deepmd_name={self.deepmd_name!r})"
    )
103+
67104
def real_shape(self, system: System) -> tuple[int]:
68105
"""Returns expected real shape of a system."""
69106
assert self.shape is not None

dpdata/system.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1102,9 +1102,9 @@ def register_data_type(cls, *data_type: DataType):
11021102
all_dtypes = cls.DTYPES + tuple(data_type)
11031103
dtypes_dict = {}
11041104
for dt in all_dtypes:
1105-
if dt.name in dtypes_dict:
1105+
if dt.name in dtypes_dict and dt != dtypes_dict[dt.name]:
11061106
warnings.warn(
1107-
f"Data type {dt.name} is registered twice; only the newly registered one will be used.",
1107+
f"Data type {dt.name} is registered twice with different definitions; only the newly registered one will be used.",
11081108
UserWarning,
11091109
)
11101110
dtypes_dict[dt.name] = dt

tests/test_custom_data_type.py

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import unittest
4+
import warnings
45

56
import h5py # noqa: TID253
67
import numpy as np
@@ -9,6 +10,74 @@
910
from dpdata.data_type import Axis, DataType
1011

1112

13+
class TestDataType(unittest.TestCase):
    """Test DataType equality, representation and duplicate registration."""

    def setUp(self):
        # Store original DTYPES so registrations made by a test can be undone
        self.original_dtypes = dpdata.System.DTYPES

    def tearDown(self):
        # Restore original DTYPES
        dpdata.System.DTYPES = self.original_dtypes

    def _register_and_collect_duplicate_warnings(self, data_type):
        """Register ``data_type`` and return the duplicate-registration warnings.

        Only ``UserWarning`` entries mentioning "registered twice" are
        returned, so unrelated warnings (e.g. deprecations raised by other
        libraries) cannot make the assertions below fail spuriously.
        """
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            dpdata.System.register_data_type(data_type)
        return [
            item
            for item in caught
            if issubclass(item.category, UserWarning)
            and "registered twice" in str(item.message)
        ]

    def test_eq(self):
        """Test equality method."""
        dt1 = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3))
        dt2 = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3))
        dt3 = DataType("other", np.ndarray, shape=(Axis.NFRAMES, 3))

        self.assertEqual(dt1, dt2)
        self.assertNotEqual(dt1, dt3)
        self.assertNotEqual(dt1, "not a DataType")

    def test_repr(self):
        """Test string representation."""
        dt = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3))
        expected = (
            "DataType(name='test', dtype=ndarray, "
            "shape=(<Axis.NFRAMES: 'nframes'>, 3), required=True, "
            "deepmd_name='test')"
        )
        self.assertEqual(repr(dt), expected)

    def test_register_same_data_type_no_warning(self):
        """Test registering identical DataType instances should not warn."""
        dt1 = DataType("test_same", np.ndarray, shape=(Axis.NFRAMES, 3))
        dt2 = DataType("test_same", np.ndarray, shape=(Axis.NFRAMES, 3))

        # Register first time
        dpdata.System.register_data_type(dt1)

        # Register same DataType again - should not warn
        self.assertEqual(self._register_and_collect_duplicate_warnings(dt2), [])

    def test_register_different_data_type_with_warning(self):
        """Test registering different DataType instances with same name should warn."""
        dt1 = DataType("test_diff", np.ndarray, shape=(Axis.NFRAMES, 3))
        dt2 = DataType(
            "test_diff", list, shape=(Axis.NFRAMES, 4)
        )  # Different dtype and shape

        # Register first time
        dpdata.System.register_data_type(dt1)

        # Register different DataType with same name - should warn exactly once
        caught = self._register_and_collect_duplicate_warnings(dt2)
        self.assertEqual(len(caught), 1)
        self.assertIn(
            "registered twice with different definitions", str(caught[0].message)
        )
79+
80+
1281
class DeepmdLoadDumpCompTest:
1382
def setUp(self):
1483
self.system = self.cls(
@@ -49,14 +118,6 @@ def test_from_deepmd_hdf5(self):
49118
x = self.cls("data_foo.h5", fmt="deepmd/hdf5")
50119
np.testing.assert_allclose(x.data["foo"], self.foo)
51120

52-
def test_duplicated_data_type(self):
53-
dt = DataType("foo", np.ndarray, (Axis.NFRAMES, *self.shape), required=False)
54-
n_dtypes_old = len(self.cls.DTYPES)
55-
with self.assertWarns(UserWarning):
56-
self.cls.register_data_type(dt)
57-
n_dtypes_new = len(self.cls.DTYPES)
58-
self.assertEqual(n_dtypes_old, n_dtypes_new)
59-
60121
def test_to_deepmd_npy_mixed(self):
61122
ms = dpdata.MultiSystems(self.system)
62123
ms.to_deepmd_npy_mixed("data_foo_mixed")

0 commit comments

Comments
 (0)