|
13 | 13 | import operator |
14 | 14 | import warnings |
15 | 15 | import weakref |
| 16 | +from collections import defaultdict |
16 | 17 | from collections.abc import Callable, Iterable, Iterator, MutableSet, Sequence |
17 | 18 | from random import Random |
18 | 19 |
|
@@ -397,7 +398,116 @@ def random(self) -> Random: |
397 | 398 | """ |
398 | 399 | return self.model.random |
399 | 400 |
|
| 401 | + def groupby(self, by: Callable | str, result_type: str = "agentset") -> GroupBy: |
| 402 | + """ |
| 403 | + Group agents by the specified attribute or return from the callable |
| 404 | +
|
| 405 | + Args: |
| 406 | + by (Callable, str): used to determine what to group agents by |
| 407 | +
|
| 408 | + * if ``by`` is a callable, it will be called for each agent and the return is used |
| 409 | + for grouping |
| 410 | + * if ``by`` is a str, it should refer to an attribute on the agent and the value |
| 411 | + of this attribute will be used for grouping |
| 412 | + result_type (str, optional): The datatype for the resulting groups {"agentset", "list"} |
| 413 | + Returns: |
| 414 | + GroupBy |
| 415 | +
|
| 416 | +
|
| 417 | + Notes: |
| 418 | + There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality |
| 419 | + of an AgentSet. |
| 420 | +
|
| 421 | + """ |
| 422 | + groups = defaultdict(list) |
| 423 | + |
| 424 | + if isinstance(by, Callable): |
| 425 | + for agent in self: |
| 426 | + groups[by(agent)].append(agent) |
| 427 | + else: |
| 428 | + for agent in self: |
| 429 | + groups[getattr(agent, by)].append(agent) |
| 430 | + |
| 431 | + if result_type == "agentset": |
| 432 | + return GroupBy( |
| 433 | + {k: AgentSet(v, model=self.model) for k, v in groups.items()} |
| 434 | + ) |
| 435 | + else: |
| 436 | + return GroupBy(groups) |
| 437 | + |
| 438 | + # consider adding for performance reasons |
| 439 | + # for Sequence: __reversed__, index, and count |
| 440 | + # for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__ |
| 441 | + |
| 442 | + |
| 443 | +class GroupBy: |
| 444 | + """Helper class for AgentSet.groupby |
| 445 | +
|
| 446 | +
|
| 447 | + Attributes: |
| 448 | + groups (dict): A dictionary with the group_name as key and group as values |
| 449 | +
|
| 450 | + """ |
| 451 | + |
| 452 | + def __init__(self, groups: dict[Any, list | AgentSet]): |
| 453 | + self.groups: dict[Any, list | AgentSet] = groups |
| 454 | + |
| 455 | + def map(self, method: Callable | str, *args, **kwargs) -> dict[Any, Any]: |
| 456 | + """Apply the specified callable to each group and return the results. |
| 457 | +
|
| 458 | + Args: |
| 459 | + method (Callable, str): The callable to apply to each group, |
| 460 | +
|
| 461 | + * if ``method`` is a callable, it will be called it will be called with the group as first argument |
| 462 | + * if ``method`` is a str, it should refer to a method on the group |
| 463 | +
|
| 464 | + Additional arguments and keyword arguments will be passed on to the callable. |
| 465 | +
|
| 466 | + Returns: |
| 467 | + dict with group_name as key and the return of the method as value |
| 468 | +
|
| 469 | + Notes: |
| 470 | + this method is useful for methods or functions that do return something. It |
| 471 | + will break method chaining. For that, use ``do`` instead. |
| 472 | +
|
| 473 | + """ |
| 474 | + if isinstance(method, str): |
| 475 | + return { |
| 476 | + k: getattr(v, method)(*args, **kwargs) for k, v in self.groups.items() |
| 477 | + } |
| 478 | + else: |
| 479 | + return {k: method(v, *args, **kwargs) for k, v in self.groups.items()} |
| 480 | + |
| 481 | + def do(self, method: Callable | str, *args, **kwargs) -> GroupBy: |
| 482 | + """Apply the specified callable to each group |
| 483 | +
|
| 484 | + Args: |
| 485 | + method (Callable, str): The callable to apply to each group, |
| 486 | +
|
| 487 | + * if ``method`` is a callable, it will be called it will be called with the group as first argument |
| 488 | + * if ``method`` is a str, it should refer to a method on the group |
| 489 | +
|
| 490 | + Additional arguments and keyword arguments will be passed on to the callable. |
| 491 | +
|
| 492 | + Returns: |
| 493 | + the original GroupBy instance |
| 494 | +
|
| 495 | + Notes: |
| 496 | + this method is useful for methods or functions that don't return anything and/or |
| 497 | + if you want to chain multiple do calls |
| 498 | +
|
| 499 | + """ |
| 500 | + if isinstance(method, str): |
| 501 | + for v in self.groups.values(): |
| 502 | + getattr(v, method)(*args, **kwargs) |
| 503 | + else: |
| 504 | + for v in self.groups.values(): |
| 505 | + method(v, *args, **kwargs) |
| 506 | + |
| 507 | + return self |
| 508 | + |
| 509 | + def __iter__(self): |
| 510 | + return iter(self.groups.items()) |
400 | 511 |
|
401 | | -# consider adding for performance reasons |
402 | | -# for Sequence: __reversed__, index, and count |
403 | | -# for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__ |
| 512 | + def __len__(self): |
| 513 | + return len(self.groups) |
0 commit comments