|
20 | 20 | from random import Random |
21 | 21 |
|
22 | 22 | # mypy |
23 | | -from typing import TYPE_CHECKING, Any |
| 23 | +from typing import TYPE_CHECKING, Any, Literal |
24 | 24 |
|
25 | 25 | if TYPE_CHECKING: |
26 | 26 | # We ensure that these are not imported during runtime to prevent cyclic |
@@ -348,29 +348,58 @@ def agg(self, attribute: str, func: Callable) -> Any: |
348 | 348 | values = self.get(attribute) |
349 | 349 | return func(values) |
350 | 350 |
|
351 | | - def get(self, attr_names: str | list[str]) -> list[Any]: |
| 351 | + def get( |
| 352 | + self, |
| 353 | + attr_names: str | list[str], |
| 354 | + handle_missing: Literal["error", "default"] = "error", |
| 355 | + default_value: Any = None, |
| 356 | + ) -> list[Any] | list[list[Any]]: |
352 | 357 | """ |
353 | 358 | Retrieve the specified attribute(s) from each agent in the AgentSet. |
354 | 359 |
|
355 | 360 | Args: |
356 | 361 | attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent. |
| 362 | + handle_missing (str, optional): How to handle missing attributes. Can be: |
| 363 | + - 'error' (default): raises an AttributeError if attribute is missing. |
| 364 | + - 'default': returns the specified default_value. |
| 365 | + default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default' |
| 366 | + and the agent does not have the attribute. |
357 | 367 |
|
358 | 368 | Returns: |
359 | | - list[Any]: A list with the attribute value for each agent in the set if attr_names is a str |
360 | | - list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str |
| 369 | + list[Any]: A list with the attribute value for each agent if attr_names is a str. |
| 370 | + list[list[Any]]: A list with a lists of attribute values for each agent if attr_names is a list of str. |
361 | 371 |
|
362 | 372 | Raises: |
363 | | - AttributeError if an agent does not have the specified attribute(s) |
364 | | -
|
365 | | - """ |
| 373 | + AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s). |
| 374 | + ValueError: If an unknown 'handle_missing' option is provided. |
| 375 | + """ |
| 376 | + is_single_attr = isinstance(attr_names, str) |
| 377 | + |
| 378 | + if handle_missing == "error": |
| 379 | + if is_single_attr: |
| 380 | + return [getattr(agent, attr_names) for agent in self._agents] |
| 381 | + else: |
| 382 | + return [ |
| 383 | + [getattr(agent, attr) for attr in attr_names] |
| 384 | + for agent in self._agents |
| 385 | + ] |
| 386 | + |
| 387 | + elif handle_missing == "default": |
| 388 | + if is_single_attr: |
| 389 | + return [ |
| 390 | + getattr(agent, attr_names, default_value) for agent in self._agents |
| 391 | + ] |
| 392 | + else: |
| 393 | + return [ |
| 394 | + [getattr(agent, attr, default_value) for attr in attr_names] |
| 395 | + for agent in self._agents |
| 396 | + ] |
366 | 397 |
|
367 | | - if isinstance(attr_names, str): |
368 | | - return [getattr(agent, attr_names) for agent in self._agents] |
369 | 398 | else: |
370 | | - return [ |
371 | | - [getattr(agent, attr_name) for attr_name in attr_names] |
372 | | - for agent in self._agents |
373 | | - ] |
| 399 | + raise ValueError( |
| 400 | + f"Unknown handle_missing option: {handle_missing}, " |
| 401 | + "should be one of 'error' or 'default'" |
| 402 | + ) |
374 | 403 |
|
375 | 404 | def set(self, attr_name: str, value: Any) -> AgentSet: |
376 | 405 | """ |
|
0 commit comments