Skip to content

Commit d3b61f3

Browse files
committed
feat: Adding agent_id
1 parent 33555b4 commit d3b61f3

File tree

1 file changed

+42
-13
lines changed

1 file changed

+42
-13
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -172,28 +172,57 @@ def _collect_model_reporters(self, current_model_step: int, batch_id: int):
172172

173173
def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
174174
"""
175-
Collect agent-level data using the agent_reporters.
175+
Collect agent-level data using the agent_reporters, including unique agent IDs
176176
177177
Constructs a LazyFrame with one column per reporter and
178-
includes `step` and `seed` metadata. Appends it to internal storage.
178+
includes
179+
- agent_id : unique identifier for each agent
180+
- step, seed and batch columns for context.
181+
- Columns for all requested agent reporters.
179182
"""
180-
agent_data_dict = {}
183+
all_agent_frames = []
184+
181185
for col_name, reporter in self._agent_reporters.items():
182186
if isinstance(reporter, str):
183-
for k, v in self._model.sets[reporter].items():
184-
agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v
187+
agent_set = self._model.sets[reporter]
188+
189+
if hasattr(agent_set, "df"):
190+
df = agent_set.df.select(["id", col_name]).rename({"id": "agent_id"})
191+
elif hasattr(agent_set, "to_polars"):
192+
df = agent_set.to_polars().select(["id", col_name]).rename({"id": "agent_id"})
193+
else:
194+
records = []
195+
for agent in agent_set.values():
196+
if hasattr(agent, "unique_id"):
197+
agent_id = agent.unique_id
198+
elif hasattr(agent, "id"):
199+
agent_id = agent.id
200+
else:
201+
agent_id = None
202+
records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)})
203+
df = pl.DataFrame(records)
185204
else:
186-
agent_data_dict[col_name] = reporter(self._model)
187-
agent_lazy_frame = pl.LazyFrame(agent_data_dict)
188-
agent_lazy_frame = agent_lazy_frame.with_columns(
189-
[
205+
df = reporter(self._model)
206+
if not isinstance(df, pl.DataFrame):
207+
raise TypeError(f"Agent reporter {col_name} must return a Polars DataFrame")
208+
209+
df = df.with_columns([
190210
pl.lit(current_model_step).alias("step"),
191211
pl.lit(str(self.seed)).alias("seed"),
192212
pl.lit(batch_id).alias("batch"),
193-
]
194-
)
195-
self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
196-
213+
])
214+
all_agent_frames.append(df)
215+
216+
if all_agent_frames:
217+
merged_df = all_agent_frames[0]
218+
for next_df in all_agent_frames[1:]:
219+
if "agent_id" not in next_df.columns:
220+
continue
221+
merged_df = merged_df.join(next_df, on=["agent_id", "step", "seed", "batch"], how="outer")
222+
223+
agent_lazy_frame = merged_df.lazy()
224+
self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
225+
197226
@property
198227
def data(self) -> dict[str, pl.DataFrame]:
199228
"""

0 commit comments

Comments
 (0)