Skip to content

Commit 8043bc0

Browse files
authored
Add default values and missing value handling to agentset.get (#2279)
1 parent 046cd97 commit 8043bc0

File tree

2 files changed

+87
-13
lines changed

2 files changed

+87
-13
lines changed

mesa/agent.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from random import Random
2121

2222
# mypy
23-
from typing import TYPE_CHECKING, Any
23+
from typing import TYPE_CHECKING, Any, Literal
2424

2525
if TYPE_CHECKING:
2626
# We ensure that these are not imported during runtime to prevent cyclic
@@ -348,29 +348,58 @@ def agg(self, attribute: str, func: Callable) -> Any:
348348
values = self.get(attribute)
349349
return func(values)
350350

351-
def get(self, attr_names: str | list[str]) -> list[Any]:
351+
def get(
352+
self,
353+
attr_names: str | list[str],
354+
handle_missing: Literal["error", "default"] = "error",
355+
default_value: Any = None,
356+
) -> list[Any] | list[list[Any]]:
352357
"""
353358
Retrieve the specified attribute(s) from each agent in the AgentSet.
354359
355360
Args:
356361
attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
362+
handle_missing (str, optional): How to handle missing attributes. Can be:
363+
- 'error' (default): raises an AttributeError if attribute is missing.
364+
- 'default': returns the specified default_value.
365+
default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default'
366+
and the agent does not have the attribute.
357367
358368
Returns:
359-
list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
360-
list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
369+
list[Any]: A list with the attribute value for each agent if attr_names is a str.
370+
list[list[Any]]: A list with a lists of attribute values for each agent if attr_names is a list of str.
361371
362372
Raises:
363-
AttributeError if an agent does not have the specified attribute(s)
364-
365-
"""
373+
AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s).
374+
ValueError: If an unknown 'handle_missing' option is provided.
375+
"""
376+
is_single_attr = isinstance(attr_names, str)
377+
378+
if handle_missing == "error":
379+
if is_single_attr:
380+
return [getattr(agent, attr_names) for agent in self._agents]
381+
else:
382+
return [
383+
[getattr(agent, attr) for attr in attr_names]
384+
for agent in self._agents
385+
]
386+
387+
elif handle_missing == "default":
388+
if is_single_attr:
389+
return [
390+
getattr(agent, attr_names, default_value) for agent in self._agents
391+
]
392+
else:
393+
return [
394+
[getattr(agent, attr, default_value) for attr in attr_names]
395+
for agent in self._agents
396+
]
366397

367-
if isinstance(attr_names, str):
368-
return [getattr(agent, attr_names) for agent in self._agents]
369398
else:
370-
return [
371-
[getattr(agent, attr_name) for attr_name in attr_names]
372-
for agent in self._agents
373-
]
399+
raise ValueError(
400+
f"Unknown handle_missing option: {handle_missing}, "
401+
"should be one of 'error' or 'default'"
402+
)
374403

375404
def set(self, attr_name: str, value: Any) -> AgentSet:
376405
"""

tests/test_agent.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,51 @@ def remove_function(agent):
276276
assert len(agentset) == 0
277277

278278

279+
def test_agentset_get():
280+
model = Model()
281+
_ = [TestAgent(i, model) for i in range(10)]
282+
283+
agentset = model.agents
284+
285+
agentset.set("a", 5)
286+
agentset.set("b", 6)
287+
288+
# Case 1: Normal retrieval of existing attributes
289+
values = agentset.get(["a", "b"])
290+
assert all((a == 5) & (b == 6) for a, b in values)
291+
292+
# Case 2: Raise AttributeError when attribute doesn't exist
293+
with pytest.raises(AttributeError):
294+
agentset.get("unknown_attribute")
295+
296+
# Case 3: Use default value when attribute is missing
297+
results = agentset.get(
298+
"unknown_attribute", handle_missing="default", default_value=True
299+
)
300+
assert all(results) is True
301+
302+
# Case 4: Retrieve mixed attributes with default value for missing ones
303+
values = agentset.get(
304+
["a", "unknown_attribute"], handle_missing="default", default_value=True
305+
)
306+
assert all((a == 5) & (unknown is True) for a, unknown in values)
307+
308+
# Case 5: Invalid handle_missing value raises ValueError
309+
with pytest.raises(ValueError):
310+
agentset.get("unknown_attribute", handle_missing="some nonsense value")
311+
312+
# Case 6: Retrieve multiple attributes with mixed existence and 'default' handling
313+
values = agentset.get(
314+
["a", "b", "unknown_attribute"], handle_missing="default", default_value=0
315+
)
316+
assert all((a == 5) & (b == 6) & (unknown == 0) for a, b, unknown in values)
317+
318+
# Case 7: 'default' handling when one attribute is completely missing from some agents
319+
agentset.select(at_most=0.5).set("c", 8) # Only some agents have attribute 'c'
320+
values = agentset.get(["a", "c"], handle_missing="default", default_value=-1)
321+
assert all((a == 5) & (c in [8, -1]) for a, c in values)
322+
323+
279324
def test_agentset_agg():
280325
model = Model()
281326
agents = [TestAgent(model) for i in range(10)]

0 commit comments

Comments
 (0)