Commit f238736

Enhancement/async datacollector (#167)
* changed self._frames for async functionality
* test timer
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* delay for pg
* check for pg delay fix
* added frames to flush in docstring
* frames to flush type hints
* code quality
* agentreporter gets model instead of agents
* code quality
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* precommit
* code quality
* max workers
* removed default arguments from abstract and added them to concrete
* batch collection
* tests
* test_batch
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* precommit
* precommit run
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* added local batch save test case
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* use public step method

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 9ce5cd7 · commit f238736

File tree

4 files changed (+521, -74 lines)


AGENTS.md

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 # Repository Guidelines
 
 ## Project Structure & Module Organization
+
 - `mesa_frames/`: Source package.
 - `abstract/` and `concrete/`: Core APIs and implementations.
 - Key modules: `agents.py`, `agentset.py`, `space.py`, `datacollector.py`, `types_.py`.
@@ -9,6 +10,7 @@
 - `examples/`: Reproducible demo models and performance scripts.
 
 ## Build, Test, and Development Commands
+
 - Install (dev stack): `uv sync` (always use uv)
 - Lint & format: `uv run ruff check . --fix && uv run ruff format .`
 - Tests (quiet + coverage): `export MESA_FRAMES_RUNTIME_TYPECHECKING=1 && uv run pytest -q --cov=mesa_frames --cov-report=term-missing`
@@ -18,23 +20,27 @@
 Always run tools via uv: `uv run <command>`.
 
 ## Coding Style & Naming Conventions
+
 - Python 3.11+, 4-space indent, type hints required for public APIs.
 - Docstrings: NumPy style (validated by Ruff/pydoclint).
 - Formatting/linting: Ruff (formatter + lints). Fix on save if your IDE supports it.
 - Names: `CamelCase` for classes, `snake_case` for functions/attributes, tests as `test_<unit>.py` with `Test<Class>` groups.
 
 ## Testing Guidelines
+
 - Framework: Pytest; place tests under `tests/` mirroring module paths.
 - Conventions: One test module per feature; name tests `test_<method_or_behavior>`.
 - Coverage: Aim to exercise new branches and error paths; keep `--cov=mesa_frames` green.
 - Run fast locally: `pytest -q` or `uv run pytest -q`.
 
 ## Commit & Pull Request Guidelines
+
 - Commits: Imperative mood, concise subject, meaningful body when needed.
   Example: `Fix AgentsDF.sets copy binding and tests`.
 - PRs: Link issues, summarize changes, note API impacts, add/adjust tests and docs.
 - CI hygiene: Run `ruff`, `pytest`, and `pre-commit` locally before pushing.
 
 ## Security & Configuration Tips
+
 - Never commit secrets; use env vars. Example: `MESA_FRAMES_RUNTIME_TYPECHECKING=1` for stricter dev runs.
 - Treat underscored attributes as internal.

mesa_frames/abstract/datacollector.py

Lines changed: 20 additions & 10 deletions
@@ -45,10 +45,12 @@ def flush(self):
 """
 
 from abc import ABC, abstractmethod
-from typing import Dict, Optional, Union, Any, Literal, List
+from typing import Any, Literal
 from collections.abc import Callable
 from mesa_frames import ModelDF
 import polars as pl
+import threading
+from concurrent.futures import ThreadPoolExecutor
 
 
 class AbstractDataCollector(ABC):
@@ -62,21 +64,22 @@ class AbstractDataCollector(ABC):
     _model: ModelDF
     _model_reporters: dict[str, Callable] | None
     _agent_reporters: dict[str, str | Callable] | None
-    _trigger: Callable[..., bool]
+    _trigger: Callable[..., bool] | None
     _reset_memory: bool
     _storage: Literal["memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"]
     _frames: list[pl.DataFrame]
 
     def __init__(
         self,
         model: ModelDF,
-        model_reporters: dict[str, Callable] | None = None,
-        agent_reporters: dict[str, str | Callable] | None = None,
-        trigger: Callable[[Any], bool] | None = None,
-        reset_memory: bool = True,
+        model_reporters: dict[str, Callable] | None,
+        agent_reporters: dict[str, str | Callable] | None,
+        trigger: Callable[[Any], bool] | None,
+        reset_memory: bool,
         storage: Literal[
             "memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"
-        ] = "memory",
+        ],
+        max_workers: int,
     ):
         """
         Initialize a Datacollector.
@@ -95,6 +98,8 @@ def __init__(
             Whether to reset in-memory data after flushing. Default is True.
         storage : Literal["memory", "csv", "parquet", "S3-csv", "S3-parquet", "postgresql"]
             Storage backend URI (e.g. 'memory:', 'csv:', 'postgresql:').
+        max_workers : int
+            Maximum number of worker threads used for flushing collected data asynchronously.
         """
         self._model = model
         self._model_reporters = model_reporters or {}
@@ -103,6 +108,8 @@
         self._reset_memory = reset_memory
         self._storage = storage or "memory"
         self._frames = []
+        self._lock = threading.Lock()
+        self._executor = ThreadPoolExecutor(max_workers=max_workers)
 
     def collect(self) -> None:
         """
@@ -177,9 +184,12 @@ def flush(self) -> None:
         >>> datacollector.flush()
         >>> # Data is saved externally and in-memory buffers are cleared if configured
         """
-        self._flush()
-        if self._reset_memory:
-            self._reset()
+        with self._lock:
+            frames_to_flush = self._frames
+            if self._reset_memory:
+                self._reset()
+
+        self._executor.submit(self._flush, frames_to_flush)
 
     def _reset(self):
         """

mesa_frames/concrete/datacollector.py

Lines changed: 64 additions & 30 deletions
@@ -79,6 +79,7 @@ def __init__(
         ] = "memory",
         storage_uri: str | None = None,
         schema: str = "public",
+        max_worker: int = 4,
     ):
         """
         Initialize the DataCollector with configuration options.
@@ -101,15 +102,17 @@
             URI or path corresponding to the selected storage backend.
         schema : str
             Schema name used for PostgreSQL storage.
-
+        max_worker : int
+            Maximum number of worker threads used for flushing collected data asynchronously.
         """
         super().__init__(
             model=model,
             model_reporters=model_reporters,
             agent_reporters=agent_reporters,
             trigger=trigger,
             reset_memory=reset_memory,
-            storage=storage,  # literal won't work
+            storage=storage,
+            max_workers=max_worker,
         )
         self._writers = {
             "csv": self._write_csv_local,
@@ -120,6 +123,8 @@
         }
         self._storage_uri = storage_uri
         self._schema = schema
+        self._current_model_step = None
+        self._batch_id = None
 
         self._validate_inputs()
 
@@ -130,28 +135,42 @@ def _collect(self):
         This method checks for the presence of model and agent reporters
         and calls the appropriate collection routines for each.
         """
+        if (
+            self._current_model_step is None
+            or self._current_model_step != self._model.steps
+        ):
+            self._current_model_step = self._model.steps
+            self._batch_id = 0
+
         if self._model_reporters:
-            self._collect_model_reporters()
+            self._collect_model_reporters(
+                current_model_step=self._current_model_step, batch_id=self._batch_id
+            )
 
         if self._agent_reporters:
-            self._collect_agent_reporters()
+            self._collect_agent_reporters(
+                current_model_step=self._current_model_step, batch_id=self._batch_id
+            )
+
+        self._batch_id += 1
 
-    def _collect_model_reporters(self):
+    def _collect_model_reporters(self, current_model_step: int, batch_id: int):
         """
         Collect model-level data using the model_reporters.
 
         Creates a LazyFrame containing the step, seed, and values
         returned by each model reporter. Appends the LazyFrame to internal storage.
         """
         model_data_dict = {}
-        model_data_dict["step"] = self._model._steps
+        model_data_dict["step"] = current_model_step
         model_data_dict["seed"] = str(self.seed)
+        model_data_dict["batch"] = batch_id
         for column_name, reporter in self._model_reporters.items():
             model_data_dict[column_name] = reporter(self._model)
         model_lazy_frame = pl.LazyFrame([model_data_dict])
-        self._frames.append(("model", str(self._model._steps), model_lazy_frame))
+        self._frames.append(("model", current_model_step, batch_id, model_lazy_frame))
 
-    def _collect_agent_reporters(self):
+    def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
         """
         Collect agent-level data using the agent_reporters.
 
@@ -164,15 +183,16 @@ def _collect_agent_reporters(self):
                 for k, v in self._model.agents[reporter].items():
                     agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v
             else:
-                agent_data_dict[col_name] = reporter(self._model.agents)
+                agent_data_dict[col_name] = reporter(self._model)
         agent_lazy_frame = pl.LazyFrame(agent_data_dict)
         agent_lazy_frame = agent_lazy_frame.with_columns(
             [
-                pl.lit(self._model._steps).alias("step"),
+                pl.lit(current_model_step).alias("step"),
                 pl.lit(str(self.seed)).alias("seed"),
+                pl.lit(batch_id).alias("batch"),
             ]
         )
-        self._frames.append(("agent", str(self._model._steps), agent_lazy_frame))
+        self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
 
     @property
     def data(self) -> dict[str, pl.DataFrame]:
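
The step/batch bookkeeping added to `_collect` means that calling `collect()` more than once within the same model step yields distinct batch ids, and the counter resets when the step advances. A compressed sketch of just that logic (the `BatchCounter` helper is hypothetical, written to mirror the diff):

    class BatchCounter:
        def __init__(self):
            self._current_model_step = None
            self._batch_id = None

        def collect(self, model_step: int) -> tuple[int, int]:
            # Mirror of _collect: reset the batch id whenever the step changes.
            if self._current_model_step is None or self._current_model_step != model_step:
                self._current_model_step = model_step
                self._batch_id = 0
            tagged = (self._current_model_step, self._batch_id)
            self._batch_id += 1
            return tagged

    c = BatchCounter()
    assert c.collect(1) == (1, 0)  # first collect in step 1
    assert c.collect(1) == (1, 1)  # second collect in the same step
    assert c.collect(2) == (2, 0)  # new step resets the batch counter
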
@@ -185,96 +205,108 @@ def data(self) -> dict[str, pl.DataFrame]:
             A dictionary with keys "model" and "agent" mapping to concatenated DataFrames of collected data.
         """
         model_frames = [
-            lf.collect() for kind, step, lf in self._frames if kind == "model"
+            lf.collect() for kind, step, batch_id, lf in self._frames if kind == "model"
         ]
         agent_frames = [
-            lf.collect() for kind, step, lf in self._frames if kind == "agent"
+            lf.collect() for kind, step, batch_id, lf in self._frames if kind == "agent"
         ]
         return {
             "model": pl.concat(model_frames) if model_frames else pl.DataFrame(),
             "agent": pl.concat(agent_frames) if agent_frames else pl.DataFrame(),
         }
 
-    def _flush(self):
+    def _flush(self, frames_to_flush: list):
         """
         Flush the collected data to the configured external storage backend.
 
         Uses the appropriate writer function based on the specified storage option.
         """
-        self._writers[self._storage](self._storage_uri)
+        self._writers[self._storage](
+            uri=self._storage_uri, frames_to_flush=frames_to_flush
+        )
 
-    def _write_csv_local(self, uri: str):
+    def _write_csv_local(self, uri: str, frames_to_flush: list):
         """
         Write collected data to local CSV files.
 
         Parameters
         ----------
         uri : str
             Local directory path to write files into.
+        frames_to_flush : list
+            The collected data frames to write in the current thread.
         """
-        for kind, step, df in self._frames:
-            df.collect().write_csv(f"{uri}/{kind}_step{step}.csv")
+        for kind, step, batch, df in frames_to_flush:
+            df.collect().write_csv(f"{uri}/{kind}_step{step}_batch{batch}.csv")
 
-    def _write_parquet_local(self, uri: str):
+    def _write_parquet_local(self, uri: str, frames_to_flush: list):
         """
         Write collected data to local Parquet files.
 
         Parameters
         ----------
         uri : str
             Local directory path to write files into.
+        frames_to_flush : list
+            The collected data frames to write in the current thread.
         """
-        for kind, step, df in self._frames:
-            df.collect().write_parquet(f"{uri}/{kind}_step{step}.parquet")
+        for kind, step, batch, df in frames_to_flush:
+            df.collect().write_parquet(f"{uri}/{kind}_step{step}_batch{batch}.parquet")
 
-    def _write_csv_s3(self, uri: str):
+    def _write_csv_s3(self, uri: str, frames_to_flush: list):
         """
         Write collected data to AWS S3 in CSV format.
 
         Parameters
         ----------
         uri : str
             S3 URI (e.g., s3://bucket/path) to upload files to.
+        frames_to_flush : list
+            The collected data frames to write in the current thread.
         """
-        self._write_s3(uri, format_="csv")
+        self._write_s3(uri=uri, frames_to_flush=frames_to_flush, format_="csv")
 
-    def _write_parquet_s3(self, uri: str):
+    def _write_parquet_s3(self, uri: str, frames_to_flush: list):
         """
         Write collected data to AWS S3 in Parquet format.
 
         Parameters
         ----------
         uri : str
             S3 URI (e.g., s3://bucket/path) to upload files to.
+        frames_to_flush : list
+            The collected data frames to write in the current thread.
         """
-        self._write_s3(uri, format_="parquet")
+        self._write_s3(uri=uri, frames_to_flush=frames_to_flush, format_="parquet")
 
-    def _write_s3(self, uri: str, format_: str):
+    def _write_s3(self, uri: str, frames_to_flush: list, format_: str):
         """
         Upload collected data to S3 in a specified format.
 
         Parameters
         ----------
         uri : str
             S3 URI to upload to.
+        frames_to_flush : list
+            The collected data frames to write in the current thread.
         format_ : str
             Format of the output files ("csv" or "parquet").
         """
         s3 = boto3.client("s3")
         parsed = urlparse(uri)
         bucket = parsed.netloc
         prefix = parsed.path.lstrip("/")
-        for kind, step, lf in self._frames:
+        for kind, step, batch, lf in frames_to_flush:
             df = lf.collect()
             with tempfile.NamedTemporaryFile(suffix=f".{format_}") as tmp:
                 if format_ == "csv":
                     df.write_csv(tmp.name)
                 elif format_ == "parquet":
                     df.write_parquet(tmp.name)
-                key = f"{prefix}/{kind}_step{step}.{format_}"
+                key = f"{prefix}/{kind}_step{step}_batch{batch}.{format_}"
                 s3.upload_file(tmp.name, bucket, key)
 
-    def _write_postgres(self, uri: str):
+    def _write_postgres(self, uri: str, frames_to_flush: list):
         """
         Write collected data to a PostgreSQL database.
 
@@ -285,10 +317,12 @@ def _write_postgres(self, uri: str):
         ----------
         uri : str
             PostgreSQL connection URI in the form postgresql://testuser:testpass@localhost:5432/testdb
+        frames_to_flush : list
+            The collected data frames to write in the current thread.
         """
         conn = self._get_db_connection(uri=uri)
         cur = conn.cursor()
-        for kind, step, lf in self._frames:
+        for kind, step, batch, lf in frames_to_flush:
             df = lf.collect()
             table = f"{kind}_data"
             cols = df.columns
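
Since frames are now tagged `(kind, step, batch, LazyFrame)`, the local writers emit one file per `collect()` call instead of one per step. A small self-contained sketch of the resulting layout (the frames and temporary directory are made up for illustration; only the `{kind}_step{step}_batch{batch}` naming pattern comes from the diff):

    import tempfile
    from pathlib import Path

    import polars as pl

    # Frames tagged as in the diff: (kind, step, batch, LazyFrame).
    frames_to_flush = [
        ("model", 1, 0, pl.LazyFrame({"step": [1], "batch": [0], "pop": [100]})),
        ("model", 1, 1, pl.LazyFrame({"step": [1], "batch": [1], "pop": [103]})),
    ]

    out_dir = tempfile.mkdtemp()
    for kind, step, batch, lf in frames_to_flush:
        # Same pattern as _write_csv_local in the diff.
        lf.collect().write_csv(f"{out_dir}/{kind}_step{step}_batch{batch}.csv")

    print(sorted(p.name for p in Path(out_dir).iterdir()))
    # ['model_step1_batch0.csv', 'model_step1_batch1.csv']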
