Skip to content

Commit efa51cd

Browse files
authored
Allow selecting a fraction of agents in the AgentSet (#2253)
This PR updates the `select` method in the `AgentSet` class by replacing the `n` parameter with a more versatile `at_most` parameter. The `at_most` parameter allows for selecting either a specific number of agents or a fraction of the total agents when provided as an integer or a float, respectively. Additionally, backward compatibility is maintained by supporting the deprecated `n` parameter, which will trigger a warning when used. ### Motive Previously, the `select` method only allowed users to specify a fixed number of agents (`n`) to be selected. The new `at_most` parameter extends this functionality by enabling the selection of agents based on a proportion of the total set, which is particularly useful in scenarios where relative selection is desired over absolute selection. ### Implementation - **`at_most` Parameter:** - Accepts either an integer (to select a fixed number of agents) or a float between 0.0 and 1.0 (to select a fraction of the total agents). - `at_most=1` selects one agent, while `at_most=1.0` selects all agents. - If a float is provided, it determines the maximum fraction of agents to be selected from the total set. It rounds down to the nearest number of whole agents. - **Backward Compatibility:** - The deprecated `n` parameter is still supported, but it now serves as a fallback for `at_most` and triggers a deprecation warning. - **Behavior Notes:** - `at_most` serves as an upper limit on the number of selected agents. If additional filtering criteria are provided, the final selection may include fewer agents. - For random sampling, users should shuffle the `AgentSet` before applying `at_most`.
1 parent 95e1cd2 commit efa51cd

File tree

2 files changed

+32
-7
lines changed

2 files changed

+32
-7
lines changed

mesa/agent.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,39 +114,58 @@ def __contains__(self, agent: Agent) -> bool:
114114
def select(
115115
self,
116116
filter_func: Callable[[Agent], bool] | None = None,
117-
n: int = 0,
117+
at_most: int | float = float("inf"),
118118
inplace: bool = False,
119119
agent_type: type[Agent] | None = None,
120+
n: int | None = None,
120121
) -> AgentSet:
121122
"""
122123
Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.
123124
124125
Args:
125126
filter_func (Callable[[Agent], bool], optional): A function that takes an Agent and returns True if the
126127
agent should be included in the result. Defaults to None, meaning no filtering is applied.
127-
n (int, optional): The number of agents to select. If 0, all matching agents are selected. Defaults to 0.
128+
at_most (int | float, optional): The maximum amount of agents to select. Defaults to infinity.
129+
- If an integer, at most the first number of matching agents are selected.
130+
- If a float between 0 and 1, at most that fraction of original the agents are selected.
128131
inplace (bool, optional): If True, modifies the current AgentSet; otherwise, returns a new AgentSet. Defaults to False.
129132
agent_type (type[Agent], optional): The class type of the agents to select. Defaults to None, meaning no type filtering is applied.
130133
131134
Returns:
132135
AgentSet: A new AgentSet containing the selected agents, unless inplace is True, in which case the current AgentSet is updated.
136+
137+
Notes:
138+
- at_most just return the first n or fraction of agents. To take a random sample, shuffle() beforehand.
139+
- at_most is an upper limit. When specifying other criteria, the number of agents returned can be smaller.
133140
"""
141+
if n is not None:
142+
warnings.warn(
143+
"The parameter 'n' is deprecated. Use 'at_most' instead.",
144+
DeprecationWarning,
145+
stacklevel=2,
146+
)
147+
at_most = n
134148

135-
if filter_func is None and agent_type is None and n == 0:
149+
inf = float("inf")
150+
if filter_func is None and agent_type is None and at_most == inf:
136151
return self if inplace else copy.copy(self)
137152

138-
def agent_generator(filter_func=None, agent_type=None, n=0):
153+
# Check if at_most is of type float
154+
if at_most <= 1.0 and isinstance(at_most, float):
155+
at_most = int(len(self) * at_most) # Note that it rounds down (floor)
156+
157+
def agent_generator(filter_func, agent_type, at_most):
139158
count = 0
140159
for agent in self:
160+
if count >= at_most:
161+
break
141162
if (not filter_func or filter_func(agent)) and (
142163
not agent_type or isinstance(agent, agent_type)
143164
):
144165
yield agent
145166
count += 1
146-
if 0 < n <= count:
147-
break
148167

149-
agents = agent_generator(filter_func, agent_type, n)
168+
agents = agent_generator(filter_func, agent_type, at_most)
150169

151170
return AgentSet(agents, self.model) if not inplace else self._update(agents)
152171

tests/test_agent.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ def test_agentset():
5959
def test_function(agent):
6060
return agent.unique_id > 5
6161

62+
assert len(agentset.select(at_most=0.2)) == 2 # Select 20% of agents
63+
assert len(agentset.select(at_most=0.549)) == 5 # Select 50% of agents
64+
assert len(agentset.select(at_most=0.09)) == 0 # Select 0% of agents
65+
assert len(agentset.select(at_most=1.0)) == 10 # Select 100% agents
66+
assert len(agentset.select(at_most=1)) == 1 # Select 1 agent
67+
6268
assert len(agentset.select(test_function)) == 5
6369
assert len(agentset.select(test_function, n=2)) == 2
6470
assert len(agentset.select(test_function, inplace=True)) == 5

0 commit comments

Comments
 (0)