Skip to content

Commit 2b68dd4

Browse files
committed
Check consistency between multiple dump calls in NPYWriter
1 parent f11d0c4 commit 2b68dd4

File tree

3 files changed

+44
-11
lines changed

3 files changed

+44
-11
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616

1717
- By default, run only 100 optimization steps in `build_random_cell`.
1818
- Wrap atoms back into the cell when writing PDB trajectory files, for nicer visual.
19+
- Stricter consistency checking between multiple `dump` calls in `NPYWriter`.
1920

2021

2122
## [0.0.0] - 2024-10-06

src/tinyff/trajectory.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,15 +153,24 @@ def dump(self, **kwargs):
153153

154154
def dump_each(self, **kwargs):
155155
"""Write data to NPY files without considering `self.stride`."""
156-
for key, value in kwargs.items():
157-
# Check array properties
158-
shape, dtype = self.fields.get(key, (None, None))
159-
arvalue = np.asarray(value)
160-
if shape is None:
161-
shape = arvalue.shape
162-
dtype = arvalue.dtype
163-
self.fields[key] = (shape, dtype)
164-
else:
156+
converted = {}
157+
if len(self.fields) == 0:
158+
# No checking, just record the given shapes and types
159+
for key, value in kwargs.items():
160+
arvalue = np.asarray(value)
161+
converted[key] = arvalue
162+
self.fields[key] = (arvalue.shape, arvalue.dtype)
163+
else:
164+
# Check kwargs
165+
if set(self.fields) != set(kwargs):
166+
raise TypeError(
167+
f"Received keys: {list(kwargs.keys())}. "
168+
f"Expected: {list(self.fields.keys())}"
169+
)
170+
for key, value in kwargs.items():
171+
arvalue = np.asarray(value)
172+
converted[key] = arvalue
173+
shape, dtype = self.fields[key]
165174
if shape != arvalue.shape:
166175
raise TypeError(
167176
f"The shape of {key}, {arvalue.shape}, differs from the first one, {shape}"
@@ -170,7 +179,9 @@ def dump_each(self, **kwargs):
170179
raise TypeError(
171180
f"The dtype of {key}, {arvalue.dtype}, differs from the first one, {dtype}"
172181
)
173-
# Append to NPY file
182+
183+
# Write only once all checks have passed
184+
for key, value in converted.items():
174185
path = os.path.join(self.dir_out, f"{key}.npy")
175186
with NpyAppendArray(path, delete_if_exists=False) as npaa:
176-
npaa.append(arvalue.reshape(1, *arvalue.shape))
187+
npaa.append(value.reshape(1, *value.shape))

tests/test_trajectory.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,24 @@ def test_npy_traj_stride(tmpdir):
120120
for foo in traj_foo:
121121
npy_writer.dump(foo=foo)
122122
assert (np.load(os.path.join(tmpdir, "out/foo.npy")) == [1, 4, 7]).all()
123+
124+
125+
def test_npy_traj_consistent_names(tmpdir):
126+
npy_writer = NPYWriter(os.path.join(tmpdir, "out"))
127+
npy_writer.dump(a=1, b=2)
128+
with pytest.raises(TypeError):
129+
npy_writer.dump(b=3, c=4)
130+
131+
132+
def test_npy_traj_consistent_shapes(tmpdir):
133+
npy_writer = NPYWriter(os.path.join(tmpdir, "out"))
134+
npy_writer.dump(a=[1, 2], b=3)
135+
with pytest.raises(TypeError):
136+
npy_writer.dump(a=3, b=4)
137+
138+
139+
def test_npy_traj_consistent_dtypes(tmpdir):
140+
npy_writer = NPYWriter(os.path.join(tmpdir, "out"))
141+
npy_writer.dump(a=1.0, b=3)
142+
with pytest.raises(TypeError):
143+
npy_writer.dump(a=3, b=4)

0 commit comments

Comments
 (0)