Skip to content

Commit 25a8b6d

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent d3b61f3 commit 25a8b6d

File tree

1 file changed

+33
-14
lines changed

1 file changed

+33
-14
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
175175
Collect agent-level data using the agent_reporters, including unique agent IDs
176176
177177
Constructs a LazyFrame with one column per reporter and
178-
includes
178+
includes
179179
- agent_id : unique identifier for each agent
180180
- step, seed and batch columns for context.
181181
- Columns for all requested agent reporters.
@@ -187,9 +187,15 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
187187
agent_set = self._model.sets[reporter]
188188

189189
if hasattr(agent_set, "df"):
190-
df = agent_set.df.select(["id", col_name]).rename({"id": "agent_id"})
190+
df = agent_set.df.select(["id", col_name]).rename(
191+
{"id": "agent_id"}
192+
)
191193
elif hasattr(agent_set, "to_polars"):
192-
df = agent_set.to_polars().select(["id", col_name]).rename({"id": "agent_id"})
194+
df = (
195+
agent_set.to_polars()
196+
.select(["id", col_name])
197+
.rename({"id": "agent_id"})
198+
)
193199
else:
194200
records = []
195201
for agent in agent_set.values():
@@ -199,30 +205,43 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
199205
agent_id = agent.id
200206
else:
201207
agent_id = None
202-
records.append({"agent_id": agent_id, col_name: getattr(agent, col_name, None)})
208+
records.append(
209+
{
210+
"agent_id": agent_id,
211+
col_name: getattr(agent, col_name, None),
212+
}
213+
)
203214
df = pl.DataFrame(records)
204215
else:
205216
df = reporter(self._model)
206217
if not isinstance(df, pl.DataFrame):
207-
raise TypeError(f"Agent reporter {col_name} must return a Polars DataFrame")
208-
209-
df = df.with_columns([
210-
pl.lit(current_model_step).alias("step"),
211-
pl.lit(str(self.seed)).alias("seed"),
212-
pl.lit(batch_id).alias("batch"),
213-
])
218+
raise TypeError(
219+
f"Agent reporter {col_name} must return a Polars DataFrame"
220+
)
221+
222+
df = df.with_columns(
223+
[
224+
pl.lit(current_model_step).alias("step"),
225+
pl.lit(str(self.seed)).alias("seed"),
226+
pl.lit(batch_id).alias("batch"),
227+
]
228+
)
214229
all_agent_frames.append(df)
215230

216231
if all_agent_frames:
217232
merged_df = all_agent_frames[0]
218233
for next_df in all_agent_frames[1:]:
219234
if "agent_id" not in next_df.columns:
220235
continue
221-
merged_df = merged_df.join(next_df, on=["agent_id", "step", "seed", "batch"], how="outer")
236+
merged_df = merged_df.join(
237+
next_df, on=["agent_id", "step", "seed", "batch"], how="outer"
238+
)
222239

223240
agent_lazy_frame = merged_df.lazy()
224-
self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))
225-
241+
self._frames.append(
242+
("agent", current_model_step, batch_id, agent_lazy_frame)
243+
)
244+
226245
@property
227246
def data(self) -> dict[str, pl.DataFrame]:
228247
"""

0 commit comments

Comments
 (0)