Skip to content

Commit eb8d56c

Browse files
committed
Fix: returns Polars DataFrame.
2 parents f1d844c + 25a8b6d commit eb8d56c

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -173,13 +173,12 @@ def _collect_model_reporters(self, current_model_step: int, batch_id: int):
173173

174174
def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
175175
"""
176-
Collect agent-level data using the agent_reporters, including unique agent IDs
176+
Collect agent-level data using the agent_reporters, including unique agent IDs.
177177
178-
Constructs a LazyFrame with one column per reporter and
179-
includes
178+
Constructs a LazyFrame with one column per reporter and includes:
180179
- agent_id : unique identifier for each agent
181-
- step, seed and batch columns for context.
182-
- Columns for all requested agent reporters.
180+
- step, seed, and batch columns for context
181+
- Columns for all requested agent reporters
183182
"""
184183
all_agent_frames = []
185184

@@ -197,20 +196,20 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
197196
agent_id = getattr(agent, "unique_id", getattr(agent, "id", None))
198197
records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)})
199198
df = pl.DataFrame(records)
199+
200200
else:
201201
result = reporter(self._model)
202202

203-
# Handle Polars DataFrame directly
203+
## Case 1: already a DataFrame
204204
if isinstance(result, pl.DataFrame):
205205
df = result
206-
elif isinstance(result, list):
207-
df = pl.DataFrame(result)
206+
## Case 2: dict or list -> convert
208207
elif isinstance(result, dict):
209208
df = pl.DataFrame([result])
210-
211-
# Handle dict, list, scalar reporters
209+
elif isinstance(result, list):
210+
df = pl.DataFrame(result)
212211
else:
213-
# Try to build per-agent data if possible
212+
## Case 3: scalar or callable reporter
214213
if hasattr(self._model, "agents"):
215214
records = []
216215
for agent in self._model.agents:
@@ -219,13 +218,13 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
219218
records.append({"agent_id": agent_id, col_name: value})
220219
df = pl.DataFrame(records)
221220
else:
222-
# Fallback for scalar or model-level reporters
223221
df = pl.DataFrame([{col_name: result}])
224222

225-
# Ensure column consistency
223+
## Ensure agent_id exists
226224
if "agent_id" not in df.columns:
227225
df = df.with_columns(pl.lit(None).alias("agent_id"))
228-
226+
227+
## Add meta columns
229228
df = df.with_columns([
230229
pl.lit(current_model_step).alias("step"),
231230
pl.lit(str(self.seed)).alias("seed"),
@@ -242,7 +241,8 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
242241

243242
agent_lazy_frame = merged_df.lazy()
244243
self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
245-
244+
245+
246246
@property
247247
def data(self) -> dict[str, pl.DataFrame]:
248248
"""

0 commit comments

Comments
 (0)