Skip to content

Commit 3847799

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 2b5e822 + 94bef06 commit 3847799

File tree

8 files changed

+89
-19
lines changed

8 files changed

+89
-19
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ ci:
44
repos:
55
- repo: https://github.com/astral-sh/ruff-pre-commit
66
# Ruff version.
7-
rev: v0.5.6
7+
rev: v0.6.3
88
hooks:
99
# Run the linter.
1010
- id: ruff

benchmarks/WolfSheep/wolf_sheep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,8 @@ def __init__(
227227
patch.move_to(cell)
228228

229229
def step(self):
230-
self.get_agents_of_type(Sheep).shuffle(inplace=True).do("step")
231-
self.get_agents_of_type(Wolf).shuffle(inplace=True).do("step")
230+
self.agents_by_type[Sheep].shuffle(inplace=True).do("step")
231+
self.agents_by_type[Wolf].shuffle(inplace=True).do("step")
232232

233233

234234
if __name__ == "__main__":

mesa/agent.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,20 @@ def map(self, method: str | Callable, *args, **kwargs) -> list[Any]:
304304

305305
return res
306306

307+
def agg(self, attribute: str, func: Callable) -> Any:
308+
"""
309+
Aggregate an attribute of all agents in the AgentSet using a specified function.
310+
311+
Args:
312+
attribute (str): The name of the attribute to aggregate.
313+
func (Callable): The function to apply to the attribute values (e.g., min, max, sum, np.mean).
314+
315+
Returns:
316+
Any: The result of applying the function to the attribute values. Often a single value.
317+
"""
318+
values = self.get(attribute)
319+
return func(values)
320+
307321
def get(self, attr_names: str | list[str]) -> list[Any]:
308322
"""
309323
Retrieve the specified attribute(s) from each agent in the AgentSet.

mesa/experimental/devs/examples/wolf_sheep.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,8 @@ def __init__(
231231
self.grid.place_agent(patch, pos)
232232

233233
def step(self):
234-
self.get_agents_of_type(Sheep).shuffle(inplace=True).do("step")
235-
self.get_agents_of_type(Wolf).shuffle(inplace=True).do("step")
234+
self.agents_by_type[Sheep].shuffle(inplace=True).do("step")
235+
self.agents_by_type[Wolf].shuffle(inplace=True).do("step")
236236

237237

238238
if __name__ == "__main__":

mesa/model.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,13 @@ class Model:
3333
Properties:
3434
agents: An AgentSet containing all agents in the model
3535
agent_types: A list of different agent types present in the model.
36+
agents_by_type: A dictionary where the keys are agent types and the values are the corresponding AgentSets.
3637
steps: An integer representing the number of steps the model has taken.
3738
It increases automatically at the start of each step() call.
3839
3940
Methods:
4041
get_agents_of_type: Returns an AgentSet of agents of the specified type.
42+
Deprecated: Use agents_by_type[agenttype] instead.
4143
run_model: Runs the model's simulation until a defined end condition is reached.
4244
step: Executes a single step of the model's simulation process.
4345
next_id: Generates and returns the next unique identifier for an agent.
@@ -106,24 +108,26 @@ def agent_types(self) -> list[type]:
106108
"""Return a list of all unique agent types registered with the model."""
107109
return list(self._agents_by_type.keys())
108110

109-
def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet:
110-
"""Retrieves an AgentSet containing all agents of the specified type.
111-
112-
Args:
113-
agenttype: The type of agent to retrieve.
114-
115-
Raises:
116-
KeyError: If agenttype does not exist
117-
111+
@property
112+
def agents_by_type(self) -> dict[type[Agent], AgentSet]:
113+
"""A dictionary where the keys are agent types and the values are the corresponding AgentSets."""
114+
return self._agents_by_type
118115

119-
"""
120-
return self._agents_by_type[agenttype]
116+
def get_agents_of_type(self, agenttype: type[Agent]) -> AgentSet:
117+
"""Deprecated: Retrieves an AgentSet containing all agents of the specified type."""
118+
warnings.warn(
119+
f"Model.get_agents_of_type() is deprecated, please replace get_agents_of_type({agenttype})"
120+
f"with the property agents_by_type[{agenttype}].",
121+
DeprecationWarning,
122+
stacklevel=2,
123+
)
124+
return self.agents_by_type[agenttype]
121125

122126
def _setup_agent_registration(self):
123127
"""helper method to initialize the agent registration datastructures"""
124128
self._agents = {} # the hard references to all agents in the model
125129
self._agents_by_type: dict[
126-
type, AgentSet
130+
type[Agent], AgentSet
127131
] = {} # a dict with an agentset for each class of agents
128132
self._all_agents = AgentSet([], self) # an agenset with all agents
129133

tests/test_agent.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pickle
22

3+
import numpy as np
34
import pytest
45

56
from mesa.agent import Agent, AgentSet
@@ -276,6 +277,41 @@ def remove_function(agent):
276277
assert len(agentset) == 0
277278

278279

280+
def test_agentset_agg():
281+
model = Model()
282+
agents = [TestAgent(i, model) for i in range(10)]
283+
284+
# Assign some values to attributes
285+
for i, agent in enumerate(agents):
286+
agent.energy = i + 1
287+
agent.wealth = 10 * (i + 1)
288+
289+
agentset = AgentSet(agents, model)
290+
291+
# Test min aggregation
292+
min_energy = agentset.agg("energy", min)
293+
assert min_energy == 1
294+
295+
# Test max aggregation
296+
max_energy = agentset.agg("energy", max)
297+
assert max_energy == 10
298+
299+
# Test sum aggregation
300+
total_energy = agentset.agg("energy", sum)
301+
assert total_energy == sum(range(1, 11))
302+
303+
# Test mean aggregation using numpy
304+
avg_wealth = agentset.agg("wealth", np.mean)
305+
assert avg_wealth == 55.0
306+
307+
# Test aggregation with a custom function
308+
def custom_func(values):
309+
return sum(values) / len(values)
310+
311+
custom_avg_energy = agentset.agg("energy", custom_func)
312+
assert custom_avg_energy == 5.5
313+
314+
279315
def test_agentset_set_method():
280316
# Initialize the model and agents with and without existing attributes
281317
class TestAgentWithAttribute(Agent):

tests/test_model.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from mesa.agent import Agent
1+
from mesa.agent import Agent, AgentSet
22
from mesa.model import Model
33

44

@@ -53,3 +53,19 @@ class TestAgent(Agent):
5353
test_agent = TestAgent(model.next_id(), model)
5454
assert test_agent in model.agents
5555
assert type(test_agent) in model.agent_types
56+
57+
58+
def test_agents_by_type():
59+
class Wolf(Agent):
60+
pass
61+
62+
class Sheep(Agent):
63+
pass
64+
65+
model = Model()
66+
wolf = Wolf(1, model)
67+
sheep = Sheep(2, model)
68+
69+
assert model.agents_by_type[Wolf] == AgentSet([wolf], model)
70+
assert model.agents_by_type[Sheep] == AgentSet([sheep], model)
71+
assert len(model.agents_by_type) == 2

tests/test_time.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def test_random_activation_counts(self):
320320
agent_types = model.agent_types
321321
for agent_type in agent_types:
322322
assert model.schedule.get_type_count(agent_type) == len(
323-
model.get_agents_of_type(agent_type)
323+
model.agents_by_type[agent_type]
324324
)
325325

326326
# def test_add_non_unique_ids(self):

0 commit comments

Comments
 (0)