Skip to content

Commit e0d1156

Browse files
authored
AgentSet: Add agg method (#2266)
This PR introduces the `agg` method to the `AgentSet` class, allowing users to apply aggregation functions (e.g., `min`, `max`, `sum`, `np.mean`) to attributes of agents within the `AgentSet`. This enhancement makes it easier to compute summary statistics across agent attributes directly within the `AgentSet` interface. This will be useful in both the model operation itself as well as for future DataCollector use.
1 parent 221084d commit e0d1156

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

mesa/agent.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,20 @@ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
304304

305305
return res
306306

307+
def agg(self, attribute: str, func: Callable) -> Any:
308+
"""
309+
Aggregate an attribute of all agents in the AgentSet using a specified function.
310+
311+
Args:
312+
attribute (str): The name of the attribute to aggregate.
313+
func (Callable): The function to apply to the attribute values (e.g., min, max, sum, np.mean).
314+
315+
Returns:
316+
Any: The result of applying the function to the attribute values. Often a single value.
317+
"""
318+
values = self.get(attribute)
319+
return func(values)
320+
307321
def get(self, attr_names: str | list[str]) -> list[Any]:
308322
"""
309323
Retrieve the specified attribute(s) from each agent in the AgentSet.

tests/test_agent.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pickle
22

3+
import numpy as np
34
import pytest
45

56
from mesa.agent import Agent, AgentSet
@@ -276,6 +277,41 @@ def remove_function(agent):
276277
assert len(agentset) == 0
277278

278279

280+
def test_agentset_agg():
281+
model = Model()
282+
agents = [TestAgent(i, model) for i in range(10)]
283+
284+
# Assign some values to attributes
285+
for i, agent in enumerate(agents):
286+
agent.energy = i + 1
287+
agent.wealth = 10 * (i + 1)
288+
289+
agentset = AgentSet(agents, model)
290+
291+
# Test min aggregation
292+
min_energy = agentset.agg("energy", min)
293+
assert min_energy == 1
294+
295+
# Test max aggregation
296+
max_energy = agentset.agg("energy", max)
297+
assert max_energy == 10
298+
299+
# Test sum aggregation
300+
total_energy = agentset.agg("energy", sum)
301+
assert total_energy == sum(range(1, 11))
302+
303+
# Test mean aggregation using numpy
304+
avg_wealth = agentset.agg("wealth", np.mean)
305+
assert avg_wealth == 55.0
306+
307+
# Test aggregation with a custom function
308+
def custom_func(values):
309+
return sum(values) / len(values)
310+
311+
custom_avg_energy = agentset.agg("energy", custom_func)
312+
assert custom_avg_energy == 5.5
313+
314+
279315
def test_agentset_set_method():
280316
# Initialize the model and agents with and without existing attributes
281317
class TestAgentWithAttribute(Agent):

0 commit comments

Comments
 (0)