|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import unittest
|
| 4 | +import warnings |
4 | 5 |
|
5 | 6 | import h5py # noqa: TID253
|
6 | 7 | import numpy as np
|
|
9 | 10 | from dpdata.data_type import Axis, DataType
|
10 | 11 |
|
11 | 12 |
|
| 13 | +class TestDataType(unittest.TestCase): |
| 14 | + """Test DataType class methods.""" |
| 15 | + |
| 16 | + def setUp(self): |
| 17 | + # Store original DTYPES to restore later |
| 18 | + self.original_dtypes = dpdata.System.DTYPES |
| 19 | + |
| 20 | + def tearDown(self): |
| 21 | + # Restore original DTYPES |
| 22 | + dpdata.System.DTYPES = self.original_dtypes |
| 23 | + |
| 24 | + def test_eq(self): |
| 25 | + """Test equality method.""" |
| 26 | + dt1 = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3)) |
| 27 | + dt2 = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3)) |
| 28 | + dt3 = DataType("other", np.ndarray, shape=(Axis.NFRAMES, 3)) |
| 29 | + |
| 30 | + self.assertTrue(dt1 == dt2) |
| 31 | + self.assertFalse(dt1 == dt3) |
| 32 | + self.assertFalse(dt1 == "not a DataType") |
| 33 | + |
| 34 | + def test_repr(self): |
| 35 | + """Test string representation.""" |
| 36 | + dt = DataType("test", np.ndarray, shape=(Axis.NFRAMES, 3)) |
| 37 | + expected = ( |
| 38 | + "DataType(name='test', dtype=ndarray, " |
| 39 | + "shape=(<Axis.NFRAMES: 'nframes'>, 3), required=True, " |
| 40 | + "deepmd_name='test')" |
| 41 | + ) |
| 42 | + self.assertEqual(repr(dt), expected) |
| 43 | + |
| 44 | + def test_register_same_data_type_no_warning(self): |
| 45 | + """Test registering identical DataType instances should not warn.""" |
| 46 | + dt1 = DataType("test_same", np.ndarray, shape=(Axis.NFRAMES, 3)) |
| 47 | + dt2 = DataType("test_same", np.ndarray, shape=(Axis.NFRAMES, 3)) |
| 48 | + |
| 49 | + # Register first time |
| 50 | + dpdata.System.register_data_type(dt1) |
| 51 | + |
| 52 | + # Register same DataType again - should not warn |
| 53 | + with warnings.catch_warnings(record=True) as w: |
| 54 | + warnings.simplefilter("always") |
| 55 | + dpdata.System.register_data_type(dt2) |
| 56 | + # Check no warnings were issued |
| 57 | + self.assertEqual(len(w), 0) |
| 58 | + |
| 59 | + def test_register_different_data_type_with_warning(self): |
| 60 | + """Test registering different DataType instances with same name should warn.""" |
| 61 | + dt1 = DataType("test_diff", np.ndarray, shape=(Axis.NFRAMES, 3)) |
| 62 | + dt2 = DataType( |
| 63 | + "test_diff", list, shape=(Axis.NFRAMES, 4) |
| 64 | + ) # Different dtype and shape |
| 65 | + |
| 66 | + # Register first time |
| 67 | + dpdata.System.register_data_type(dt1) |
| 68 | + |
| 69 | + # Register different DataType with same name - should warn |
| 70 | + with warnings.catch_warnings(record=True) as w: |
| 71 | + warnings.simplefilter("always") |
| 72 | + dpdata.System.register_data_type(dt2) |
| 73 | + # Check warning was issued |
| 74 | + self.assertEqual(len(w), 1) |
| 75 | + self.assertTrue(issubclass(w[-1].category, UserWarning)) |
| 76 | + self.assertIn( |
| 77 | + "registered twice with different definitions", str(w[-1].message) |
| 78 | + ) |
| 79 | + |
| 80 | + |
12 | 81 | class DeepmdLoadDumpCompTest:
|
13 | 82 | def setUp(self):
|
14 | 83 | self.system = self.cls(
|
@@ -49,14 +118,6 @@ def test_from_deepmd_hdf5(self):
|
49 | 118 | x = self.cls("data_foo.h5", fmt="deepmd/hdf5")
|
50 | 119 | np.testing.assert_allclose(x.data["foo"], self.foo)
|
51 | 120 |
|
52 |
| - def test_duplicated_data_type(self): |
53 |
| - dt = DataType("foo", np.ndarray, (Axis.NFRAMES, *self.shape), required=False) |
54 |
| - n_dtypes_old = len(self.cls.DTYPES) |
55 |
| - with self.assertWarns(UserWarning): |
56 |
| - self.cls.register_data_type(dt) |
57 |
| - n_dtypes_new = len(self.cls.DTYPES) |
58 |
| - self.assertEqual(n_dtypes_old, n_dtypes_new) |
59 |
| - |
60 | 121 | def test_to_deepmd_npy_mixed(self):
|
61 | 122 | ms = dpdata.MultiSystems(self.system)
|
62 | 123 | ms.to_deepmd_npy_mixed("data_foo_mixed")
|
|
0 commit comments