Skip to content

Commit 84881e7

Browse files
authored
GroupBy: Add count and agg methods (#2290)
Added two new methods to the `GroupBy` class to enhance aggregation and group operations: - `count`: Returns the count of agents in each group. - `agg`: Performs aggregation on a specific attribute across groups, applying a function like `sum`, `min`, `max`, or `mean`. These methods improve flexibility in applying both group-level and attribute-specific operations.
1 parent 9179edb commit 84881e7

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

mesa/agent.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import warnings
1616
import weakref
1717
from collections import defaultdict
18-
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
18+
from collections.abc import Callable, Hashable, Iterable, Iterator, MutableSet, Sequence
1919
from random import Random
2020

2121
# mypy
@@ -611,6 +611,29 @@ def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:
611611

612612
return self
613613

614+
def count(self) -> dict[Any, int]:
615+
"""Return the count of agents in each group.
616+
617+
Returns:
618+
dict: A dictionary mapping group names to the number of agents in each group.
619+
"""
620+
return {k: len(v) for k, v in self.groups.items()}
621+
622+
def agg(self, attr_name: str, func: Callable) -> dict[Hashable, Any]:
623+
"""Aggregate the values of a specific attribute across each group using the provided function.
624+
625+
Args:
626+
attr_name (str): The name of the attribute to aggregate.
627+
func (Callable): The function to apply (e.g., sum, min, max, mean).
628+
629+
Returns:
630+
dict[Hashable, Any]: A dictionary mapping group names to the result of applying the aggregation function.
631+
"""
632+
return {
633+
group_name: func([getattr(agent, attr_name) for agent in group])
634+
for group_name, group in self.groups.items()
635+
}
636+
614637
def __iter__(self): # noqa: D105
615638
return iter(self.groups.items())
616639

tests/test_agent.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,7 @@ class TestAgent(Agent):
524524
def __init__(self, model):
525525
super().__init__(model)
526526
self.even = self.unique_id % 2 == 0
527+
self.value = self.unique_id * 10
527528

528529
def get_unique_identifier(self):
529530
return self.unique_id
@@ -560,6 +561,37 @@ def get_unique_identifier(self):
560561
another_ref_to_groups = groups.do(lambda x: x.do("step"))
561562
assert groups == another_ref_to_groups
562563

564+
# New tests for count() method
565+
groups = agentset.groupby("even")
566+
count_result = groups.count()
567+
assert count_result == {True: 5, False: 5}
568+
569+
# New tests for agg() method
570+
groups = agentset.groupby("even")
571+
sum_result = groups.agg("value", sum)
572+
assert sum_result[True] == sum(agent.value for agent in agents if agent.even)
573+
assert sum_result[False] == sum(agent.value for agent in agents if not agent.even)
574+
575+
max_result = groups.agg("value", max)
576+
assert max_result[True] == max(agent.value for agent in agents if agent.even)
577+
assert max_result[False] == max(agent.value for agent in agents if not agent.even)
578+
579+
min_result = groups.agg("value", min)
580+
assert min_result[True] == min(agent.value for agent in agents if agent.even)
581+
assert min_result[False] == min(agent.value for agent in agents if not agent.even)
582+
583+
# Test with a custom aggregation function
584+
def custom_agg(values):
585+
return sum(values) / len(values) if values else 0
586+
587+
custom_result = groups.agg("value", custom_agg)
588+
assert custom_result[True] == custom_agg(
589+
[agent.value for agent in agents if agent.even]
590+
)
591+
assert custom_result[False] == custom_agg(
592+
[agent.value for agent in agents if not agent.even]
593+
)
594+
563595

564596
def test_oldstyle_agent_instantiation():
565597
"""Old behavior of Agent creation with unique_id and model as positional arguments.

0 commit comments

Comments
 (0)