1111import contextlib
1212import copy
1313import operator
14+ import warnings
1415import weakref
16+ from collections import defaultdict
1517from collections .abc import Callable , Iterable , Iterator , MutableSet , Sequence
1618from random import Random
1719
@@ -216,25 +218,64 @@ def _update(self, agents: Iterable[Agent]):
216218 self ._agents = weakref .WeakKeyDictionary ({agent : None for agent in agents })
217219 return self
218220
219- def do (
220- self , method : str | Callable , * args , return_results : bool = False , ** kwargs
221- ) -> AgentSet | list [Any ]:
221+ def do (self , method : str | Callable , * args , ** kwargs ) -> AgentSet :
222222 """
223223 Invoke a method or function on each agent in the AgentSet.
224224
225225 Args:
226- method (str, callable): the callable to do on each agents
226+ method (str, callable): the callable to do on each agent
227227
228228 * in case of str, the name of the method to call on each agent.
229229 * in case of callable, the function to be called with each agent as first argument
230230
231- return_results (bool, optional): If True, returns the results of the method calls; otherwise, returns the AgentSet itself. Defaults to False, so you can chain method calls.
232231 *args: Variable length argument list passed to the callable being called.
233232 **kwargs: Arbitrary keyword arguments passed to the callable being called.
234233
235234 Returns:
236235 AgentSet | list[Any]: The results of the callable calls if return_results is True, otherwise the AgentSet itself.
237236 """
237+ try :
238+ return_results = kwargs .pop ("return_results" )
239+ except KeyError :
240+ return_results = False
241+ else :
242+ warnings .warn (
243+ "Using return_results is deprecated. Use AgenSet.do in case of return_results=False, and "
244+ "AgentSet.map in case of return_results=True" ,
245+ stacklevel = 2 ,
246+ )
247+
248+ if return_results :
249+ return self .map (method , * args , ** kwargs )
250+
251+ # we iterate over the actual weakref keys and check if weakref is alive before calling the method
252+ if isinstance (method , str ):
253+ for agentref in self ._agents .keyrefs ():
254+ if (agent := agentref ()) is not None :
255+ getattr (agent , method )(* args , ** kwargs )
256+ else :
257+ for agentref in self ._agents .keyrefs ():
258+ if (agent := agentref ()) is not None :
259+ method (agent , * args , ** kwargs )
260+
261+ return self
262+
263+ def map (self , method : str | Callable , * args , ** kwargs ) -> list [Any ]:
264+ """
265+ Invoke a method or function on each agent in the AgentSet and return the results.
266+
267+ Args:
268+ method (str, callable): the callable to apply on each agent
269+
270+ * in case of str, the name of the method to call on each agent.
271+ * in case of callable, the function to be called with each agent as first argument
272+
273+ *args: Variable length argument list passed to the callable being called.
274+ **kwargs: Arbitrary keyword arguments passed to the callable being called.
275+
276+ Returns:
277+ list[Any]: The results of the callable calls
278+ """
238279 # we iterate over the actual weakref keys and check if weakref is alive before calling the method
239280 if isinstance (method , str ):
240281 res = [
@@ -249,7 +290,7 @@ def do(
249290 if (agent := agentref ()) is not None
250291 ]
251292
252- return res if return_results else self
293+ return res
253294
254295 def get (self , attr_names : str | list [str ]) -> list [Any ]:
255296 """
@@ -357,7 +398,116 @@ def random(self) -> Random:
357398 """
358399 return self .model .random
359400
401+ def groupby (self , by : Callable | str , result_type : str = "agentset" ) -> GroupBy :
402+ """
403+ Group agents by the specified attribute or return from the callable
404+
405+ Args:
406+ by (Callable, str): used to determine what to group agents by
407+
408+ * if ``by`` is a callable, it will be called for each agent and the return is used
409+ for grouping
410+ * if ``by`` is a str, it should refer to an attribute on the agent and the value
411+ of this attribute will be used for grouping
412+ result_type (str, optional): The datatype for the resulting groups {"agentset", "list"}
413+ Returns:
414+ GroupBy
415+
416+
417+ Notes:
418+ There might be performance benefits to using `result_type='list'` if you don't need the advanced functionality
419+ of an AgentSet.
420+
421+ """
422+ groups = defaultdict (list )
423+
424+ if isinstance (by , Callable ):
425+ for agent in self :
426+ groups [by (agent )].append (agent )
427+ else :
428+ for agent in self :
429+ groups [getattr (agent , by )].append (agent )
430+
431+ if result_type == "agentset" :
432+ return GroupBy (
433+ {k : AgentSet (v , model = self .model ) for k , v in groups .items ()}
434+ )
435+ else :
436+ return GroupBy (groups )
437+
438+ # consider adding for performance reasons
439+ # for Sequence: __reversed__, index, and count
440+ # for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
441+
442+
443+ class GroupBy :
444+ """Helper class for AgentSet.groupby
445+
446+
447+ Attributes:
448+ groups (dict): A dictionary with the group_name as key and group as values
449+
450+ """
451+
452+ def __init__ (self , groups : dict [Any , list | AgentSet ]):
453+ self .groups : dict [Any , list | AgentSet ] = groups
454+
455+ def map (self , method : Callable | str , * args , ** kwargs ) -> dict [Any , Any ]:
456+ """Apply the specified callable to each group and return the results.
457+
458+ Args:
459+ method (Callable, str): The callable to apply to each group,
460+
461+ * if ``method`` is a callable, it will be called it will be called with the group as first argument
462+ * if ``method`` is a str, it should refer to a method on the group
463+
464+ Additional arguments and keyword arguments will be passed on to the callable.
465+
466+ Returns:
467+ dict with group_name as key and the return of the method as value
468+
469+ Notes:
470+ this method is useful for methods or functions that do return something. It
471+ will break method chaining. For that, use ``do`` instead.
472+
473+ """
474+ if isinstance (method , str ):
475+ return {
476+ k : getattr (v , method )(* args , ** kwargs ) for k , v in self .groups .items ()
477+ }
478+ else :
479+ return {k : method (v , * args , ** kwargs ) for k , v in self .groups .items ()}
480+
481+ def do (self , method : Callable | str , * args , ** kwargs ) -> GroupBy :
482+ """Apply the specified callable to each group
483+
484+ Args:
485+ method (Callable, str): The callable to apply to each group,
486+
487+ * if ``method`` is a callable, it will be called it will be called with the group as first argument
488+ * if ``method`` is a str, it should refer to a method on the group
489+
490+ Additional arguments and keyword arguments will be passed on to the callable.
491+
492+ Returns:
493+ the original GroupBy instance
494+
495+ Notes:
496+ this method is useful for methods or functions that don't return anything and/or
497+ if you want to chain multiple do calls
498+
499+ """
500+ if isinstance (method , str ):
501+ for v in self .groups .values ():
502+ getattr (v , method )(* args , ** kwargs )
503+ else :
504+ for v in self .groups .values ():
505+ method (v , * args , ** kwargs )
506+
507+ return self
508+
509+ def __iter__ (self ):
510+ return iter (self .groups .items ())
360511
361- # consider adding for performance reasons
362- # for Sequence: __reversed__, index, and count
363- # for MutableSet clear, pop, remove, __ior__, __iand__, __ixor__, and __isub__
512+ def __len__ (self ):
513+ return len (self .groups )
0 commit comments