Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions mesa/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,11 @@
import numpy as np

if TYPE_CHECKING:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you check if we still need this? (honestly I don't know)

Copy link
Author

@SiddharthBansal007 SiddharthBansal007 Nov 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I went through the codebase, I think numpy is needed
Example usecase (Line 146)

def rng(self) -> np.random.Generator:
    """Return a seeded np.random rng."""
    return self.model.rng
  1. for the TYPE_CHECKING, i have used it to prevent circular imports, do let me know if i am wrong in any of the both cases

# We ensure that these are not imported during runtime to prevent cyclic
# dependency.
from mesa.model import Model
from mesa.space import Position


class Agent:
class Agent[M: Model]:
"""Base class for a model agent in Mesa.

Attributes:
Expand All @@ -48,7 +46,7 @@ class Agent:
# so, unique_id is unique relative to a model, and counting starts from 1
_ids = defaultdict(functools.partial(itertools.count, 1))

def __init__(self, model: Model, *args, **kwargs) -> None:
def __init__(self, model: M, *args, **kwargs) -> None:
"""Create a new agent.

Args:
Expand All @@ -62,7 +60,9 @@ def __init__(self, model: Model, *args, **kwargs) -> None:
"""
super().__init__(*args, **kwargs)

self.model: Model = model
# Preserve the more specific model type for static type checkers.
# At runtime this remains the Model instance passed in.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments are not needed

Suggested change
# Preserve the more specific model type for static type checkers.
# At runtime this remains the Model instance passed in.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I will make the neccessary changes

self.model: M = model
self.unique_id: int = next(self._ids[model])
self.pos: Position | None = None
self.model.register_agent(self)
Expand All @@ -85,7 +85,9 @@ def advance(self) -> None: # noqa: D102
pass

@classmethod
def create_agents(cls, model: Model, n: int, *args, **kwargs) -> AgentSet[Agent]:
def create_agents[T: Agent](
cls: type[T], model: Model, n: int, *args, **kwargs
) -> AgentSet[T]:
"""Create N agents.

Args:
Expand Down Expand Up @@ -146,7 +148,7 @@ def rng(self) -> np.random.Generator:
return self.model.rng


class AgentSet(MutableSet, Sequence):
class AgentSet[A: Agent](MutableSet[A], Sequence[A]):
"""A collection class that represents an ordered set of agents within an agent-based model (ABM).

This class extends both MutableSet and Sequence, providing set-like functionality with order preservation and
Expand All @@ -171,7 +173,7 @@ class AgentSet(MutableSet, Sequence):

def __init__(
self,
agents: Iterable[Agent],
agents: Iterable[A],
random: Random | None = None,
):
"""Initializes the AgentSet with a collection of agents and a reference to the model.
Expand Down Expand Up @@ -200,11 +202,11 @@ def __len__(self) -> int:
"""Return the number of agents in the AgentSet."""
return len(self._agents)

def __iter__(self) -> Iterator[Agent]:
def __iter__(self) -> Iterator[A]:
"""Provide an iterator over the agents in the AgentSet."""
return self._agents.keys()

def __contains__(self, agent: Agent) -> bool:
def __contains__(self, agent: A) -> bool:
"""Check if an agent is in the AgentSet. Can be used like `agent in agentset`."""
return agent in self._agents

Expand All @@ -213,7 +215,7 @@ def select(
filter_func: Callable[[Agent], bool] | None = None,
at_most: int | float = float("inf"),
inplace: bool = False,
agent_type: type[Agent] | None = None,
agent_type: type[A] | None = None,
) -> AgentSet:
"""Select a subset of agents from the AgentSet based on a filter function and/or quantity limit.

Expand Down Expand Up @@ -256,7 +258,7 @@ def agent_generator(filter_func, agent_type, at_most):

return AgentSet(agents, self.random) if not inplace else self._update(agents)

def shuffle(self, inplace: bool = False) -> AgentSet:
def shuffle(self, inplace: bool = False) -> AgentSet[A]:
"""Randomly shuffle the order of agents in the AgentSet.

Args:
Expand Down Expand Up @@ -307,15 +309,15 @@ def sort(
else self._update(sorted_agents)
)

def _update(self, agents: Iterable[Agent]):
def _update(self, agents: Iterable[A]):
"""Update the AgentSet with a new set of agents.

This is a private method primarily used internally by other methods like select, shuffle, and sort.
"""
self._agents = weakref.WeakKeyDictionary(dict.fromkeys(agents))
return self

def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
def do(self, method: str | Callable, *args, **kwargs) -> AgentSet[A]:
"""Invoke a method or function on each agent in the AgentSet.

Args:
Expand All @@ -342,7 +344,7 @@ def do(self, method: str | Callable, *args, **kwargs) -> AgentSet:

return self

def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet:
def shuffle_do(self, method: str | Callable, *args, **kwargs) -> AgentSet[A]:
"""Shuffle the agents in the AgentSet and then invoke a method or function on each agent.

It's a fast, optimized version of calling shuffle() followed by do().
Expand Down Expand Up @@ -488,7 +490,7 @@ def get(
"should be one of 'error' or 'default'"
)

def set(self, attr_name: str, value: Any) -> AgentSet:
def set(self, attr_name: str, value: Any) -> AgentSet[A]:
"""Set a specified attribute to a given value for all agents in the AgentSet.

Args:
Expand All @@ -502,7 +504,7 @@ def set(self, attr_name: str, value: Any) -> AgentSet:
setattr(agent, attr_name, value)
return self

def __getitem__(self, item: int | slice) -> Agent:
def __getitem__(self, item: int | slice) -> A:
"""Retrieve an agent or a slice of agents from the AgentSet.

Args:
Expand All @@ -513,7 +515,7 @@ def __getitem__(self, item: int | slice) -> Agent:
"""
return list(self._agents.keys())[item]

def add(self, agent: Agent):
def add(self, agent: A):
"""Add an agent to the AgentSet.

Args:
Expand All @@ -524,7 +526,7 @@ def add(self, agent: Agent):
"""
self._agents[agent] = None

def discard(self, agent: Agent):
def discard(self, agent: A):
"""Remove an agent from the AgentSet if it exists.

This method does not raise an error if the agent is not present.
Expand All @@ -538,7 +540,7 @@ def discard(self, agent: Agent):
with contextlib.suppress(KeyError):
del self._agents[agent]

def remove(self, agent: Agent):
def remove(self, agent: A):
"""Remove an agent from the AgentSet.

This method raises an error if the agent is not present.
Expand Down
23 changes: 12 additions & 11 deletions mesa/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,24 @@
from collections.abc import Sequence

# mypy
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np

from mesa.agent import Agent, AgentSet
from mesa.mesa_logging import create_module_logger, method_logger

if TYPE_CHECKING:
pass

SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence
RNGLike = np.random.Generator | np.random.BitGenerator


_mesa_logger = create_module_logger()


class Model:
class Model[A: Agent]:
"""Base class for models in the Mesa ABM library.

This class serves as a foundational structure for creating agent-based models.
Expand Down Expand Up @@ -107,11 +110,9 @@ def __init__(
# setup agent registration data structures
self._agents = {} # the hard references to all agents in the model
self._agents_by_type: dict[
type[Agent], AgentSet
] = {} # a dict with an agentset for each class of agents
self._all_agents = AgentSet(
[], random=self.random
) # an agenset with all agents
Comment on lines 110 to 114
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please restore the original comments

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the changes, please review the new push

type[A], AgentSet[A]
] = {} # agentset per agent class
self._all_agents: AgentSet[A] = AgentSet([], random=self.random)

def _wrapped_step(self, *args: Any, **kwargs: Any) -> None:
"""Automatically increments time and steps after calling the user's step method."""
Expand All @@ -122,7 +123,7 @@ def _wrapped_step(self, *args: Any, **kwargs: Any) -> None:
self._user_step(*args, **kwargs)

@property
def agents(self) -> AgentSet:
def agents(self) -> AgentSet[A]:
"""Provides an AgentSet of all agents in the model, combining agents from all types."""
return self._all_agents

Expand All @@ -140,11 +141,11 @@ def agent_types(self) -> list[type]:
return list(self._agents_by_type.keys())

@property
def agents_by_type(self) -> dict[type[Agent], AgentSet]:
def agents_by_type(self) -> dict[type[A], AgentSet[A]]:
"""A dictionary where the keys are agent types and the values are the corresponding AgentSets."""
return self._agents_by_type

def register_agent(self, agent):
def register_agent(self, agent: A):
"""Register the agent with the model.

Args:
Expand Down Expand Up @@ -174,7 +175,7 @@ def register_agent(self, agent):
f"registered {agent.__class__.__name__} with agent_id {agent.unique_id}"
)

def deregister_agent(self, agent):
def deregister_agent(self, agent: A):
"""Deregister the agent with the model.

Args:
Expand Down
4 changes: 1 addition & 3 deletions mesa/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import warnings
from collections.abc import Callable, Iterable, Iterator, Sequence
from numbers import Real
from typing import Any, TypeVar, cast, overload
from typing import Any, cast, overload
from warnings import warn

with contextlib.suppress(ImportError):
Expand All @@ -58,8 +58,6 @@
GridContent = Agent | None
MultiGridContent = list[Agent]

F = TypeVar("F", bound=Callable[..., Any])


def accept_tuple_argument[F: Callable[..., Any]](wrapped_function: F) -> F:
"""Decorator to allow grid methods that take a list of (x, y) coord tuples to also handle a single position.
Expand Down