Skip to content

Commit f1d844c

Browse files
committed
Fix: Add agent_id handling
1 parent d3b61f3 commit f1d844c

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def step(self):
5454
self.dc.flush()
5555
"""
5656

57+
from unittest import result
5758
import polars as pl
5859
import boto3
5960
from urllib.parse import urlparse
@@ -193,18 +194,37 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
193194
else:
194195
records = []
195196
for agent in agent_set.values():
196-
if hasattr(agent, "unique_id"):
197-
agent_id = agent.unique_id
198-
elif hasattr(agent, "id"):
199-
agent_id = agent.id
200-
else:
201-
agent_id = None
197+
agent_id = getattr(agent, "unique_id", getattr(agent, "id", None))
202198
records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)})
203199
df = pl.DataFrame(records)
204200
else:
205-
df = reporter(self._model)
206-
if not isinstance(df, pl.DataFrame):
207-
raise TypeError(f"Agent reporter {col_name} must return a Polars DataFrame")
201+
result = reporter(self._model)
202+
203+
# Handle Polars DataFrame directly
204+
if isinstance(result, pl.DataFrame):
205+
df = result
206+
elif isinstance(result, list):
207+
df = pl.DataFrame(result)
208+
elif isinstance(result, dict):
209+
df = pl.DataFrame([result])
210+
211+
# Handle dict, list, scalar reporters
212+
else:
213+
# Try to build per-agent data if possible
214+
if hasattr(self._model, "agents"):
215+
records = []
216+
for agent in self._model.agents:
217+
agent_id = getattr(agent, "unique_id", getattr(agent, "id", None))
218+
value = getattr(agent, col_name, result if not callable(result) else None)
219+
records.append({"agent_id": agent_id, col_name: value})
220+
df = pl.DataFrame(records)
221+
else:
222+
# Fallback for scalar or model-level reporters
223+
df = pl.DataFrame([{col_name: result}])
224+
225+
# Ensure column consistency
226+
if "agent_id" not in df.columns:
227+
df = df.with_columns(pl.lit(None).alias("agent_id"))
208228

209229
df = df.with_columns([
210230
pl.lit(current_model_step).alias("step"),

0 commit comments

Comments
 (0)