Skip to content

Commit 0b9c109

Browse files
committed
tests for compress_result
1 parent 0c68212 commit 0b9c109

File tree

1 file changed

+91
-0
lines changed

1 file changed

+91
-0
lines changed

tests/core/test_simulation_result.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import random
44
import os
55
from pylattica.core import SimulationResult, SimulationState
6+
from pylattica.core.simulation_result import compress_result
67

78

89
@pytest.fixture
@@ -206,3 +207,93 @@ def test_max_history_none_unlimited(initial_state):
206207
assert len(result._diffs) == 500
207208
assert result._checkpoint_state is None
208209
assert result.earliest_available_step == 0
210+
211+
212+
def test_max_history_steps_generator(initial_state):
213+
"""Test that steps() works correctly with checkpoints."""
214+
result = SimulationResult(initial_state, max_history=50)
215+
216+
for step in range(100):
217+
updates = {0: {"value": step}}
218+
result.add_step(updates)
219+
220+
# Iterate through available steps
221+
steps_list = list(result.steps())
222+
223+
# Should have steps from checkpoint onward
224+
expected_count = len(result._diffs) + 1 # diffs + checkpoint state
225+
assert len(steps_list) == expected_count
226+
227+
# Each step should be a separate object (copies)
228+
assert steps_list[0] is not steps_list[1]
229+
230+
231+
def test_max_history_load_steps(initial_state):
232+
"""Test that load_steps() works correctly with checkpoints."""
233+
result = SimulationResult(initial_state, max_history=50)
234+
235+
for step in range(100):
236+
updates = {0: {"value": step}}
237+
result.add_step(updates)
238+
239+
# Load steps at interval
240+
result.load_steps(interval=10)
241+
242+
# Should have cached states
243+
assert len(result._stored_states) > 0
244+
245+
# Cached states should be after checkpoint
246+
for step_no in result._stored_states:
247+
assert step_no >= result.earliest_available_step
248+
249+
250+
def test_original_length(initial_state):
251+
"""Test the original_length property."""
252+
result = SimulationResult(initial_state, compress_freq=1)
253+
254+
for step in range(10):
255+
updates = {0: {"value": step}}
256+
result.add_step(updates)
257+
258+
# With compress_freq=1, original_length should equal len
259+
assert result.original_length == len(result)
260+
261+
# With compress_freq=2, original_length should be doubled
262+
result_compressed = SimulationResult(initial_state, compress_freq=2)
263+
for step in range(10):
264+
updates = {0: {"value": step}}
265+
result_compressed.add_step(updates)
266+
267+
assert result_compressed.original_length == len(result_compressed) * 2
268+
269+
270+
def test_compress_result(initial_state):
271+
"""Test the compress_result function."""
272+
result = SimulationResult(initial_state)
273+
274+
# Add 100 steps with deterministic values
275+
for step in range(100):
276+
updates = {0: {"value": step}}
277+
result.add_step(updates)
278+
279+
# Compress to 20 steps
280+
compressed = compress_result(result, 20)
281+
282+
# Should have fewer steps
283+
assert len(compressed) <= 25 # Some margin for sampling
284+
285+
# compress_freq should be updated
286+
assert compressed.compress_freq > 1
287+
288+
289+
def test_compress_result_invalid_size(initial_state):
290+
"""Test that compress_result raises error for invalid target size."""
291+
result = SimulationResult(initial_state)
292+
293+
for step in range(10):
294+
updates = {0: {"value": step}}
295+
result.add_step(updates)
296+
297+
# Can't compress to more steps than we have
298+
with pytest.raises(ValueError, match="Cannot compress"):
299+
compress_result(result, 100)

0 commit comments

Comments
 (0)