diff --git a/mesa/agent.py b/mesa/agent.py index 93cd6287528..e8bf2bc8c71 100644 --- a/mesa/agent.py +++ b/mesa/agent.py @@ -24,13 +24,11 @@ import numpy as np if TYPE_CHECKING: - # We ensure that these are not imported during runtime to prevent cyclic - # dependency. from mesa.model import Model from mesa.space import Position -class Agent: +class Agent[M: Model]: """Base class for a model agent in Mesa. Attributes: @@ -48,7 +46,7 @@ class Agent: # so, unique_id is unique relative to a model, and counting starts from 1 _ids = defaultdict(functools.partial(itertools.count, 1)) - def __init__(self, model: Model, *args, **kwargs) -> None: + def __init__(self, model: M, *args, **kwargs) -> None: """Create a new agent. Args: @@ -62,7 +60,7 @@ def __init__(self, model: Model, *args, **kwargs) -> None: """ super().__init__(*args, **kwargs) - self.model: Model = model + self.model: M = model self.unique_id: int = next(self._ids[model]) self.pos: Position | None = None self.model.register_agent(self) @@ -85,7 +83,9 @@ def advance(self) -> None: # noqa: D102 pass @classmethod - def create_agents(cls, model: Model, n: int, *args, **kwargs) -> AgentSet[Agent]: + def create_agents[T: Agent]( + cls: type[T], model: Model, n: int, *args, **kwargs + ) -> AgentSet[T]: """Create N agents. Args: @@ -146,7 +146,7 @@ def rng(self) -> np.random.Generator: return self.model.rng -class AgentSet(MutableSet, Sequence): +class AgentSet[A: Agent](MutableSet[A], Sequence[A]): """A collection class that represents an ordered set of agents within an agent-based model (ABM). This class extends both MutableSet and Sequence, providing set-like functionality with order preservation and @@ -171,7 +171,7 @@ class AgentSet(MutableSet, Sequence): def __init__( self, - agents: Iterable[Agent], + agents: Iterable[A], random: Random | None = None, ): """Initializes the AgentSet with a collection of agents and a reference to the model. @@ -200,11 +200,11 @@ def __len__(self) -> int: """Return the number of agents in the AgentSet.""" return len(self._agents) - def __iter__(self) -> Iterator[Agent]: + def __iter__(self) -> Iterator[A]: """Provide an iterator over the agents in the AgentSet.""" return self._agents.keys() - def __contains__(self, agent: Agent) -> bool: + def __contains__(self, agent: A) -> bool: """Check if an agent is in the AgentSet. Can be used like `agent in agentset`.""" return agent in self._agents @@ -213,7 +213,7 @@ def select( filter_func: Callable[[Agent], bool] | None = None, at_most: int | float = float("inf"), inplace: bool = False, - agent_type: type[Agent] | None = None, + agent_type: type[A] | None = None, ) -> AgentSet: """Select a subset of agents from the AgentSet based on a filter function and/or quantity limit. @@ -256,7 +256,7 @@ def agent_generator(filter_func, agent_type, at_most): return AgentSet(agents, self.random) if not inplace else self._update(agents) - def shuffle(self, inplace: bool = False) -> AgentSet: + def shuffle(self, inplace: bool = False) -> AgentSet[A]: """Randomly shuffle the order of agents in the AgentSet. Args: @@ -307,7 +307,7 @@ def sort( else self._update(sorted_agents) ) - def _update(self, agents: Iterable[Agent]): + def _update(self, agents: Iterable[A]): """Update the AgentSet with a new set of agents. This is a private method primarily used internally by other methods like select, shuffle, and sort. @@ -315,7 +315,7 @@ def _update(self, agents: Iterable[Agent]): self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents)) return self - def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: + def do(self, method: str | Callable, *args, **kwargs) -> AgentSet[A]: """Invoke a method or function on each agent in the AgentSet. Args: @@ -342,7 +342,7 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet: return self - def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet: + def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet[A]: """Shuffle the agents in the AgentSet and then invoke a method or function on each agent. It's a fast, optimized version of calling shuffle() followed by do(). @@ -488,7 +488,7 @@ def get( "should be one of 'error' or 'default'" ) - def set(self, attr_name: str, value: Any) -> AgentSet: + def set(self, attr_name: str, value: Any) -> AgentSet[A]: """Set a specified attribute to a given value for all agents in the AgentSet. Args: @@ -502,7 +502,7 @@ def set(self, attr_name: str, value: Any) -> AgentSet: setattr(agent, attr_name, value) return self - def __getitem__(self, item: int | slice) -> Agent: + def __getitem__(self, item: int | slice) -> A: """Retrieve an agent or a slice of agents from the AgentSet. Args: @@ -513,7 +513,7 @@ def __getitem__(self, item: int | slice) -> Agent: """ return list(self._agents.keys())[item] - def add(self, agent: Agent): + def add(self, agent: A): """Add an agent to the AgentSet. Args: @@ -524,7 +524,7 @@ def add(self, agent: Agent): """ self._agents[agent] = None - def discard(self, agent: Agent): + def discard(self, agent: A): """Remove an agent from the AgentSet if it exists. This method does not raise an error if the agent is not present. @@ -538,7 +538,7 @@ def discard(self, agent: Agent): with contextlib.suppress(KeyError): del self._agents[agent] - def remove(self, agent: Agent): + def remove(self, agent: A): """Remove an agent from the AgentSet. This method raises an error if the agent is not present. diff --git a/mesa/model.py b/mesa/model.py index f53d92b1633..c49403e9f41 100644 --- a/mesa/model.py +++ b/mesa/model.py @@ -12,13 +12,16 @@ from collections.abc import Sequence # mypy -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np from mesa.agent import Agent, AgentSet from mesa.mesa_logging import create_module_logger, method_logger +if TYPE_CHECKING: + pass + SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence RNGLike = np.random.Generator | np.random.BitGenerator @@ -26,7 +29,7 @@ _mesa_logger = create_module_logger() -class Model: +class Model[A: Agent]: """Base class for models in the Mesa ABM library. This class serves as a foundational structure for creating agent-based models. @@ -107,9 +110,9 @@ def __init__( # setup agent registration data structures self._agents = {} # the hard references to all agents in the model self._agents_by_type: dict[ - type[Agent], AgentSet + type[A], AgentSet[A] ] = {} # a dict with an agentset for each class of agents - self._all_agents = AgentSet( + self._all_agents: AgentSet[A] = AgentSet( [], random=self.random ) # an agenset with all agents @@ -122,7 +125,7 @@ def _wrapped_step(self, *args: Any, **kwargs: Any) -> None: self._user_step(*args, **kwargs) @property - def agents(self) -> AgentSet: + def agents(self) -> AgentSet[A]: """Provides an AgentSet of all agents in the model, combining agents from all types.""" return self._all_agents @@ -140,11 +143,11 @@ def agent_types(self) -> list[type]: return list(self._agents_by_type.keys()) @property - def agents_by_type(self) -> dict[type[Agent], AgentSet]: + def agents_by_type(self) -> dict[type[A], AgentSet[A]]: """A dictionary where the keys are agent types and the values are the corresponding AgentSets.""" return self._agents_by_type - def register_agent(self, agent): + def register_agent(self, agent: A): """Register the agent with the model. Args: @@ -174,7 +177,7 @@ def register_agent(self, agent): f"registered {agent.__class__.__name__} with agent_id {agent.unique_id}" ) - def deregister_agent(self, agent): + def deregister_agent(self, agent: A): """Deregister the agent with the model. Args: diff --git a/mesa/space.py b/mesa/space.py index c9cd04f4223..5d7a899e826 100644 --- a/mesa/space.py +++ b/mesa/space.py @@ -33,7 +33,7 @@ import warnings from collections.abc import Callable, Iterable, Iterator, Sequence from numbers import Real -from typing import Any, TypeVar, cast, overload +from typing import Any, cast, overload from warnings import warn with contextlib.suppress(ImportError): @@ -58,8 +58,6 @@ GridContent = Agent | None MultiGridContent = list[Agent] -F = TypeVar("F", bound=Callable[..., Any]) - def accept_tuple_argument[F: Callable[..., Any]](wrapped_function: F) -> F: """Decorator to allow grid methods that take a list of (x, y) coord tuples to also handle a single position.