Skip to content

Commit 3d75d30

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 45184a4 + 50b2964 commit 3d75d30

File tree

5 files changed

+267
-16
lines changed

5 files changed

+267
-16
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
## Summary
2+
<!-- Provide a brief summary of the bug and its impact. -->
3+
4+
## Bug / Issue
5+
<!-- Link to the related issue(s) and describe the bug. Include details like the context, what was expected, and what actually happened. -->
6+
7+
## Implementation
8+
<!-- Describe the changes made to resolve the issue. Highlight any important parts of the code that were modified. -->
9+
10+
## Testing
11+
<!-- Detail the testing performed to verify the fix. Include information on test cases, steps taken, and any relevant results.
12+
13+
If you're fixing the visualisation, add before/after screenshots. -->
14+
15+
## Additional Notes
16+
<!-- Add any additional information that may be relevant for the reviewers, such as potential side effects, dependencies, or related work.
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
## Summary
2+
<!-- Provide a concise summary of the feature and its purpose. -->
3+
4+
## Motive
5+
<!-- Explain the reasoning behind this feature. Include details on the problem it addresses or the enhancement it provides. -->
6+
7+
## Implementation
8+
<!-- Describe how the feature was implemented. Include details on the approach taken, important decisions made, and code changes. -->
9+
10+
## Usage Examples
11+
<!-- Provide code snippets or examples demonstrating how to use the new feature. Highlight key scenarios where this feature will be beneficial.
12+
13+
If you're modifying the visualisation, add before/after screenshots. -->
14+
15+
## Additional Notes
16+
<!-- Add any additional information that may be relevant for the reviewers, such as potential side effects, dependencies, or related work. -->

mesa/agent.py

Lines changed: 159 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
import contextlib
1212
import copy
1313
import operator
14+
import warnings
1415
import weakref
16+
from collections import defaultdict
1517
from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence
1618
from random import Random
1719

@@ -216,25 +218,64 @@ def _update(self, agents: Iterable[Agent]):
216218
self._agents = weakref.WeakKeyDictionary({agent: None for agent in agents})
217219
return self
218220

219-
def do(
220-
self, method: str | Callable, *args, return_results: bool = False, **kwargs
221-
) -> AgentSet | list[Any]:
221+
def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
222222
"""
223223
Invoke a method or function on each agent in the AgentSet.
224224
225225
Args:
226-
method (str, callable): the callable to do on each agents
226+
method (str, callable): the callable to do on each agent
227227
228228
* in case of str, the name of the method to call on each agent.
229229
* in case of callable, the function to be called with each agent as first argument
230230
231-
return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
232231
*args: Variable length argument list passed to the callable being called.
233232
**kwargs: Arbitrary keyword arguments passed to the callable being called.
234233
235234
Returns:
236235
AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
237236
"""
237+
try:
238+
return_results = kwargs.pop("return_results")
239+
except KeyError:
240+
return_results = False
241+
else:
242+
warnings.warn(
243+
"Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and "
244+
"AgentSet.map in case of return_results=True",
245+
stacklevel=2,
246+
)
247+
248+
if return_results:
249+
return self.map(method, *args, **kwargs)
250+
251+
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
252+
if isinstance(method, str):
253+
for agentref in self._agents.keyrefs():
254+
if (agent := agentref()) is not None:
255+
getattr(agent, method)(*args, **kwargs)
256+
else:
257+
for agentref in self._agents.keyrefs():
258+
if (agent := agentref()) is not None:
259+
method(agent, *args, **kwargs)
260+
261+
return self
262+
263+
def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
264+
"""
265+
Invoke a method or function on each agent in the AgentSet and return the results.
266+
267+
Args:
268+
method (str, callable): the callable to apply on each agent
269+
270+
* in case of str, the name of the method to call on each agent.
271+
* in case of callable, the function to be called with each agent as first argument
272+
273+
*args: Variable length argument list passed to the callable being called.
274+
**kwargs: Arbitrary keyword arguments passed to the callable being called.
275+
276+
Returns:
277+
list[Any]: The results of the callable calls
278+
"""
238279
# we iterate over the actual weakref keys and check if weakref is alive before calling the method
239280
if isinstance(method, str):
240281
res = [
@@ -249,7 +290,7 @@ def do(
249290
if (agent := agentref()) is not None
250291
]
251292

252-
return res if return_results else self
293+
return res
253294

254295
def get(self, attr_names: str | list[str]) -> list[Any]:
255296
"""
@@ -357,7 +398,116 @@ def random(self) -> Random:
357398
"""
358399
return self.model.random
359400

401+
def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy:
402+
"""
403+
Group agents by the specified attribute or return from the callable
404+
405+
Args:
406+
by (Callable, str): used to determine what to group agents by
407+
408+
* if ``by`` is a callable, it will be called for each agent and the return is used
409+
for grouping
410+
* if ``by`` is a str, it should refer to an attribute on the agent and the value
411+
of this attribute will be used for grouping
412+
result_type (str, optional): The datatype for the resulting groups {"agentset", "list"}
413+
Returns:
414+
GroupBy
415+
416+
417+
Notes:
418+
There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality
419+
of an AgentSet.
420+
421+
"""
422+
groups = defaultdict(list)
423+
424+
if isinstance(by, Callable):
425+
for agent in self:
426+
groups[by(agent)].append(agent)
427+
else:
428+
for agent in self:
429+
groups[getattr(agent, by)].append(agent)
430+
431+
if result_type == "agentset":
432+
return GroupBy(
433+
{k: AgentSet(v, model=self.model) for k, v in groups.items()}
434+
)
435+
else:
436+
return GroupBy(groups)
437+
438+
# consider adding for performance reasons
439+
# for Sequence: __reversed__, index, and count
440+
# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
441+
442+
443+
class GroupBy:
444+
"""Helper class for AgentSet.groupby
445+
446+
447+
Attributes:
448+
groups (dict): A dictionary with the group_name as key and group as values
449+
450+
"""
451+
452+
def __init__(self, groups: dict[Any, list | AgentSet]):
453+
self.groups: dict[Any, list | AgentSet] = groups
454+
455+
def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]:
456+
"""Apply the specified callable to each group and return the results.
457+
458+
Args:
459+
method (Callable, str): The callable to apply to each group,
460+
461+
* if ``method`` is a callable, it will be called it will be called with the group as first argument
462+
* if ``method`` is a str, it should refer to a method on the group
463+
464+
Additional arguments and keyword arguments will be passed on to the callable.
465+
466+
Returns:
467+
dict with group_name as key and the return of the method as value
468+
469+
Notes:
470+
this method is useful for methods or functions that do return something. It
471+
will break method chaining. For that, use ``do`` instead.
472+
473+
"""
474+
if isinstance(method, str):
475+
return {
476+
k: getattr(v, method)(*args, **kwargs) for k, v in self.groups.items()
477+
}
478+
else:
479+
return {k: method(v, *args, **kwargs) for k, v in self.groups.items()}
480+
481+
def do(self, method: Callable | str, *args, **kwargs) -> GroupBy:
482+
"""Apply the specified callable to each group
483+
484+
Args:
485+
method (Callable, str): The callable to apply to each group,
486+
487+
* if ``method`` is a callable, it will be called it will be called with the group as first argument
488+
* if ``method`` is a str, it should refer to a method on the group
489+
490+
Additional arguments and keyword arguments will be passed on to the callable.
491+
492+
Returns:
493+
the original GroupBy instance
494+
495+
Notes:
496+
this method is useful for methods or functions that don't return anything and/or
497+
if you want to chain multiple do calls
498+
499+
"""
500+
if isinstance(method, str):
501+
for v in self.groups.values():
502+
getattr(v, method)(*args, **kwargs)
503+
else:
504+
for v in self.groups.values():
505+
method(v, *args, **kwargs)
506+
507+
return self
508+
509+
def __iter__(self):
510+
return iter(self.groups.items())
360511

361-
# consider adding for performance reasons
362-
# for Sequence: __reversed__, index, and count
363-
# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
512+
def __len__(self):
513+
return len(self.groups)

mesa/visualization/solara_viz.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,11 @@ def do_reseed():
158158
"""Update the random seed for the model."""
159159
reactive_seed.value = model.random.random()
160160

161-
dependencies = [current_step.value, reactive_seed.value]
161+
dependencies = [
162+
*list(model_parameters.values()),
163+
current_step.value,
164+
reactive_seed.value,
165+
]
162166

163167
# if space drawer is disabled, do not include it
164168
layout_types = [{"Space": "default"}] if space_drawer else []

tests/test_agent.py

Lines changed: 71 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,6 @@ def test_function(agent):
8383
assert all(
8484
a1 == a2.unique_id for a1, a2 in zip(agentset.get("unique_id"), agentset)
8585
)
86-
assert all(
87-
a1 == a2.unique_id
88-
for a1, a2 in zip(
89-
agentset.do("get_unique_identifier", return_results=True), agentset
90-
)
91-
)
9286
assert agentset == agentset.do("get_unique_identifier")
9387

9488
agentset.discard(agents[0])
@@ -276,6 +270,35 @@ def remove_function(agent):
276270
assert len(agentset) == 0
277271

278272

273+
def test_agentset_map_str():
274+
model = Model()
275+
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
276+
agentset = AgentSet(agents, model)
277+
278+
with pytest.raises(AttributeError):
279+
agentset.do("non_existing_method")
280+
281+
results = agentset.map("get_unique_identifier")
282+
assert all(i == entry for i, entry in zip(results, range(1, 11)))
283+
284+
285+
def test_agentset_map_callable():
286+
model = Model()
287+
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
288+
agentset = AgentSet(agents, model)
289+
290+
# Test callable with non-existent function
291+
with pytest.raises(AttributeError):
292+
agentset.map(lambda agent: agent.non_existing_method())
293+
294+
# tests for addition and removal in do using callables
295+
# do iterates, so no error should be raised to change size while iterating
296+
# related to issue #1595
297+
298+
results = agentset.map(lambda agent: agent.unique_id)
299+
assert all(i == entry for i, entry in zip(results, range(1, 11)))
300+
301+
279302
def test_agentset_get_attribute():
280303
model = Model()
281304
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
@@ -348,3 +371,45 @@ def test_agentset_shuffle():
348371
agentset = AgentSet(test_agents, model=model)
349372
agentset.shuffle(inplace=True)
350373
assert not all(a1 == a2 for a1, a2 in zip(test_agents, agentset))
374+
375+
376+
def test_agentset_groupby():
377+
class TestAgent(Agent):
378+
def __init__(self, unique_id, model):
379+
super().__init__(unique_id, model)
380+
self.even = self.unique_id % 2 == 0
381+
382+
def get_unique_identifier(self):
383+
return self.unique_id
384+
385+
model = Model()
386+
agents = [TestAgent(model.next_id(), model) for _ in range(10)]
387+
agentset = AgentSet(agents, model)
388+
389+
groups = agentset.groupby("even")
390+
assert len(groups.groups[True]) == 5
391+
assert len(groups.groups[False]) == 5
392+
393+
groups = agentset.groupby(lambda a: a.unique_id % 2 == 0)
394+
assert len(groups.groups[True]) == 5
395+
assert len(groups.groups[False]) == 5
396+
assert len(groups) == 2
397+
398+
for group_name, group in groups:
399+
assert len(group) == 5
400+
assert group_name in {True, False}
401+
402+
sizes = agentset.groupby("even", result_type="list").map(len)
403+
assert sizes == {True: 5, False: 5}
404+
405+
attributes = agentset.groupby("even", result_type="agentset").map("get", "even")
406+
for group_name, group in attributes.items():
407+
assert all(group_name == entry for entry in group)
408+
409+
groups = agentset.groupby("even", result_type="agentset")
410+
another_ref_to_groups = groups.do("do", "step")
411+
assert groups == another_ref_to_groups
412+
413+
groups = agentset.groupby("even", result_type="agentset")
414+
another_ref_to_groups = groups.do(lambda x: x.do("step"))
415+
assert groups == another_ref_to_groups

0 commit comments

Comments
 (0)