Skip to content

Commit 0f123cf

Browse files
committed
Remove the need for posthoc compression by accumulating compressed frames live
1 parent 340cf9c commit 0f123cf

File tree

5 files changed

+247
-18
lines changed

5 files changed

+247
-18
lines changed

src/pylattica/core/runner/asynchronous_runner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def _run( # pylint: disable=too-many-positional-arguments
3131
self,
3232
_: SimulationState,
3333
result: SimulationResult,
34-
live_state: SimulationState,
3534
controller: BasicController,
3635
num_steps: int,
3736
verbose: bool = False,
@@ -58,6 +57,7 @@ def _run( # pylint: disable=too-many-positional-arguments
5857
"""
5958

6059
site_queue = deque()
60+
live_state = result.live_state
6161

6262
def _add_sites_to_queue():
6363
next_site = controller.get_random_site(live_state)
@@ -84,7 +84,6 @@ def _add_sites_to_queue():
8484
state_updates = controller_response
8585

8686
state_updates = merge_updates(state_updates, site_id=site_id)
87-
live_state.batch_update(state_updates)
8887
site_queue.extend(next_sites)
8988

9089
result.add_step(state_updates)
@@ -95,5 +94,4 @@ def _add_sites_to_queue():
9594
if len(site_queue) == 0:
9695
break
9796

98-
result.set_output(live_state)
9997
return result

src/pylattica/core/runner/base_runner.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,7 @@ def run(
6262

6363
result = controller.instantiate_result(initial_state.copy())
6464
controller.pre_run(initial_state)
65-
live_state = initial_state.copy()
6665

67-
self._run(initial_state, result, live_state, controller, num_steps, verbose)
66+
self._run(initial_state, result, controller, num_steps, verbose)
6867

69-
result.set_output(live_state)
7068
return result

src/pylattica/core/runner/synchronous_runner.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def _run( # pylint: disable=too-many-positional-arguments
4242
self,
4343
initial_state: SimulationState,
4444
result: SimulationResult,
45-
live_state: SimulationState,
4645
controller: BasicController,
4746
num_steps: int,
4847
verbose: bool = False,
@@ -74,11 +73,9 @@ def _run( # pylint: disable=too-many-positional-arguments
7473
else:
7574
printif(verbose, "Running in series.")
7675
for _ in tqdm(range(num_steps)):
77-
updates = self._take_step(live_state, controller)
78-
live_state.batch_update(updates)
76+
updates = self._take_step(result.live_state, controller)
7977
result.add_step(updates)
8078

81-
result.set_output(live_state)
8279
return result
8380

8481
def _take_step_parallel(self, updates: dict, pool, chunk_size) -> SimulationState:

src/pylattica/core/simulation_result.py

Lines changed: 123 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ class SimulationResult:
2121
Maximum number of diffs to keep in memory. When exceeded, older diffs
2222
are dropped and a checkpoint is created. Set to None for unlimited
2323
history (default, but may cause memory issues for long simulations).
24+
live_compress : bool, optional
25+
If True, store full state snapshots at compress_freq intervals during
26+
simulation instead of diffs. This avoids the expensive O(n) reconstruction
27+
in load_steps() at the cost of more memory per frame. When enabled,
28+
load_steps() becomes a no-op since frames are already stored.
2429
"""
2530

2631
@classmethod
@@ -29,13 +34,15 @@ def from_file(cls, fpath):
2934

3035
@classmethod
3136
def from_dict(cls, res_dict):
32-
diffs = res_dict["diffs"]
37+
diffs = res_dict.get("diffs", [])
3338
compress_freq = res_dict.get("compress_freq", 1)
3439
max_history = res_dict.get("max_history", None)
40+
live_compress = res_dict.get("live_compress", False)
3541
res = cls(
3642
SimulationState.from_dict(res_dict["initial_state"]),
3743
compress_freq=compress_freq,
3844
max_history=max_history,
45+
live_compress=live_compress,
3946
)
4047
# Restore checkpoint if present
4148
if "checkpoint_state" in res_dict and res_dict["checkpoint_state"] is not None:
@@ -44,6 +51,11 @@ def from_dict(cls, res_dict):
4451
)
4552
res._checkpoint_step = res_dict.get("checkpoint_step", 0)
4653

54+
# Restore frames if present (for live_compress mode)
55+
if "frames" in res_dict and res_dict["frames"]:
56+
for step_str, state_dict in res_dict["frames"].items():
57+
res._frames[int(step_str)] = SimulationState.from_dict(state_dict)
58+
4759
for diff in diffs:
4860
if SITES in diff:
4961
diff[SITES] = {int(k): v for k, v in diff[SITES].items()}
@@ -56,13 +68,26 @@ def from_dict(cls, res_dict):
5668
"total_steps", res._checkpoint_step + len(diffs)
5769
)
5870

71+
# Reconstruct live_state to reflect the final state
72+
if res._frames:
73+
# In live_compress mode, use the last frame
74+
last_step = max(res._frames.keys())
75+
res._live_state = res._frames[last_step].copy()
76+
elif res._diffs:
77+
# Replay all diffs to get final state
78+
if res._checkpoint_state is not None:
79+
res._live_state = res._checkpoint_state.copy()
80+
for diff in res._diffs:
81+
res._live_state.batch_update(diff)
82+
5983
return res
6084

6185
def __init__(
6286
self,
6387
starting_state: SimulationState,
6488
compress_freq: int = 1,
6589
max_history: int = None,
90+
live_compress: bool = False,
6691
):
6792
"""Initializes a SimulationResult with the specified starting_state.
6893
@@ -71,24 +96,39 @@ def __init__(
7196
starting_state : SimulationState
7297
The state with which the simulation started.
7398
compress_freq : int, optional
74-
Compression frequency for sampling, by default 1.
99+
Compression frequency for sampling, by default 1. When live_compress
100+
is True, this controls how often full state snapshots are stored.
75101
max_history : int, optional
76102
Maximum number of diffs to keep in memory. When exceeded, a
77103
checkpoint is created and old diffs are dropped. This prevents
78104
unbounded memory growth during long simulations. Set to None
79105
(default) for unlimited history. Recommended: 1000-10000 for
80-
long simulations.
106+
long simulations. Ignored when live_compress is True.
107+
live_compress : bool, optional
108+
If True, store full state snapshots at compress_freq intervals
109+
during simulation instead of storing diffs. This avoids the O(n)
110+
reconstruction cost of load_steps() but uses more memory per stored
111+
frame. Default is False (store diffs, reconstruct post-hoc).
81112
"""
82113
self.initial_state = starting_state
83114
self.compress_freq = compress_freq
84115
self.max_history = max_history
116+
self.live_compress = live_compress
85117
self._diffs: list[dict] = []
86118
self._stored_states = {}
119+
self._frames: Dict[int, SimulationState] = {} # For live_compress mode
87120
# Checkpoint support for bounded history
88121
self._checkpoint_state: SimulationState = None
89122
self._checkpoint_step: int = 0
90123
self._total_steps: int = 0
91124

125+
# Live state that gets updated with each step
126+
self._live_state: SimulationState = starting_state.copy()
127+
128+
# Store initial state as frame 0 if live_compress is enabled
129+
if self.live_compress:
130+
self._frames[0] = starting_state.copy()
131+
92132
def get_diffs(self) -> list[dict]:
93133
"""Returns the list of diffs.
94134
@@ -99,6 +139,15 @@ def get_diffs(self) -> list[dict]:
99139
"""
100140
return self._diffs
101141

142+
@property
143+
def live_state(self) -> SimulationState:
144+
"""The current live state of the simulation.
145+
146+
This state is updated with each call to add_step(). Use this to access
147+
the current simulation state during a run.
148+
"""
149+
return self._live_state
150+
102151
def add_step(self, updates: Dict[int, Dict]) -> None:
103152
"""Takes a set of updates as a dictionary mapping site IDs
104153
to the new values for various state parameters. For instance, if at the
@@ -111,14 +160,30 @@ def add_step(self, updates: Dict[int, Dict]) -> None:
111160
}
112161
}
113162
163+
This method:
164+
1. Applies the updates to the internal live_state
165+
2. Increments the step counter
166+
3. In live_compress mode: stores frames at compress_freq intervals
167+
4. In normal mode: stores the diff for later reconstruction
168+
114169
Parameters
115170
----------
116171
updates : dict
117172
The changes associated with a new simulation step.
118173
"""
119-
self._diffs.append(updates)
174+
# Update the live state
175+
self._live_state.batch_update(updates)
120176
self._total_steps += 1
121177

178+
# In live_compress mode, store frames at intervals instead of diffs
179+
if self.live_compress:
180+
if self._total_steps % self.compress_freq == 0:
181+
self._frames[self._total_steps] = self._live_state.copy()
182+
return
183+
184+
# Normal mode: store diffs
185+
self._diffs.append(updates)
186+
122187
# Check if we need to create a checkpoint and drop old diffs
123188
if self.max_history is not None and len(self._diffs) > self.max_history:
124189
self._create_checkpoint()
@@ -173,13 +238,20 @@ def steps(self) -> List[SimulationState]:
173238
"""Yields all available steps from this simulation.
174239
175240
Note: When max_history is set, only steps from the checkpoint onward
176-
are available. Use earliest_available_step to check.
241+
are available. When live_compress is set, only frames at compress_freq
242+
intervals are available. Use earliest_available_step to check.
177243
178244
Yields
179245
------
180246
SimulationState
181247
Each step's state (as a copy to avoid mutation issues).
182248
"""
249+
# If frames exist (live_compress mode), yield them in order
250+
if self._frames:
251+
for step_no in sorted(self._frames.keys()):
252+
yield self._frames[step_no].copy()
253+
return
254+
183255
# Start from checkpoint or initial state
184256
if self._checkpoint_state is not None:
185257
live_state = self._checkpoint_state.copy()
@@ -206,17 +278,38 @@ def last_step(self) -> SimulationState:
206278
def first_step(self):
207279
return self.get_step(0)
208280

209-
def set_output(self, step: SimulationState):
210-
self.output = step
281+
@property
282+
def output(self) -> SimulationState:
283+
"""The final output state of the simulation (alias for live_state)."""
284+
return self._live_state
211285

212286
def load_steps(self, interval=1):
213287
"""Pre-loads steps into memory at the specified interval for faster access.
214288
289+
When live_compress is enabled, this is a no-op since frames are already
290+
stored during simulation. If a different interval is requested than what
291+
was used during simulation (compress_freq), an error is raised.
292+
215293
Parameters
216294
----------
217295
interval : int, optional
218296
Store every Nth step in memory, by default 1.
297+
298+
Raises
299+
------
300+
ValueError
301+
If live_compress was used and requested interval doesn't match compress_freq.
219302
"""
303+
# If frames already exist (live_compress mode), no reconstruction needed
304+
if self._frames:
305+
if interval != self.compress_freq:
306+
raise ValueError(
307+
f"Cannot load steps with interval={interval}. This result was "
308+
f"created with live_compress=True and compress_freq={self.compress_freq}. "
309+
f"Only interval={self.compress_freq} is available."
310+
)
311+
return
312+
220313
# Clear old cache first
221314
self._stored_states.clear()
222315

@@ -255,8 +348,20 @@ def get_step(self, step_no) -> SimulationState:
255348
Raises
256349
------
257350
ValueError
258-
If step_no is before the earliest available step (when using max_history).
351+
If step_no is before the earliest available step (when using max_history),
352+
or if step_no is not available in live_compress mode.
259353
"""
354+
# Check frames first (live_compress mode)
355+
if self._frames:
356+
if step_no in self._frames:
357+
return self._frames[step_no]
358+
# In live_compress mode, only frames at compress_freq intervals exist
359+
raise ValueError(
360+
f"Cannot retrieve step {step_no}. This result was created with "
361+
f"live_compress=True and compress_freq={self.compress_freq}. "
362+
f"Available steps: {sorted(self._frames.keys())}"
363+
)
364+
260365
if step_no < self._checkpoint_step:
261366
raise ValueError(
262367
f"Cannot retrieve step {step_no}. Earliest available step is "
@@ -288,6 +393,7 @@ def as_dict(self):
288393
"diffs": self._diffs,
289394
"compress_freq": self.compress_freq,
290395
"max_history": self.max_history,
396+
"live_compress": self.live_compress,
291397
"total_steps": self._total_steps,
292398
"@module": self.__class__.__module__,
293399
"@class": self.__class__.__name__,
@@ -299,6 +405,15 @@ def as_dict(self):
299405
else:
300406
result["checkpoint_state"] = None
301407
result["checkpoint_step"] = 0
408+
409+
# Include frames if present (live_compress mode)
410+
if self._frames:
411+
result["frames"] = {
412+
str(step): state.as_dict() for step, state in self._frames.items()
413+
}
414+
else:
415+
result["frames"] = {}
416+
302417
return result
303418

304419
def to_file(self, fpath: str = None) -> None:

0 commit comments

Comments
 (0)