Skip to content

Commit b9909e6

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 6b49a3b + 8043bc0 commit b9909e6

File tree

7 files changed

+139
-97
lines changed

7 files changed

+139
-97
lines changed

mesa/agent.py

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from random import Random
2121

2222
# mypy
23-
from typing import TYPE_CHECKING, Any
23+
from typing import TYPE_CHECKING, Any, Literal
2424

2525
if TYPE_CHECKING:
2626
# We ensure that these are not imported during runtime to prevent cyclic
@@ -348,29 +348,58 @@ def agg(self, attribute: str, func: Callable) -> Any:
348348
values = self.get(attribute)
349349
return func(values)
350350

351-
def get(self, attr_names: str | list[str]) -> list[Any]:
351+
def get(
352+
self,
353+
attr_names: str | list[str],
354+
handle_missing: Literal["error", "default"] = "error",
355+
default_value: Any = None,
356+
) -> list[Any] | list[list[Any]]:
352357
"""
353358
Retrieve the specified attribute(s) from each agent in the AgentSet.
354359
355360
Args:
356361
attr_names (str | list[str]): The name(s) of the attribute(s) to retrieve from each agent.
362+
handle_missing (str, optional): How to handle missing attributes. Can be:
363+
- 'error' (default): raises an AttributeError if attribute is missing.
364+
- 'default': returns the specified default_value.
365+
default_value (Any, optional): The default value to return if 'handle_missing' is set to 'default'
366+
and the agent does not have the attribute.
357367
358368
Returns:
359-
list[Any]: A list with the attribute value for each agent in the set if attr_names is a str
360-
list[list[Any]]: A list with a list of attribute values for each agent in the set if attr_names is a list of str
369+
list[Any]: A list with the attribute value for each agent if attr_names is a str.
370+
list[list[Any]]: A list with a lists of attribute values for each agent if attr_names is a list of str.
361371
362372
Raises:
363-
AttributeError if an agent does not have the specified attribute(s)
364-
365-
"""
373+
AttributeError: If 'handle_missing' is 'error' and the agent does not have the specified attribute(s).
374+
ValueError: If an unknown 'handle_missing' option is provided.
375+
"""
376+
is_single_attr = isinstance(attr_names, str)
377+
378+
if handle_missing == "error":
379+
if is_single_attr:
380+
return [getattr(agent, attr_names) for agent in self._agents]
381+
else:
382+
return [
383+
[getattr(agent, attr) for attr in attr_names]
384+
for agent in self._agents
385+
]
386+
387+
elif handle_missing == "default":
388+
if is_single_attr:
389+
return [
390+
getattr(agent, attr_names, default_value) for agent in self._agents
391+
]
392+
else:
393+
return [
394+
[getattr(agent, attr, default_value) for attr in attr_names]
395+
for agent in self._agents
396+
]
366397

367-
if isinstance(attr_names, str):
368-
return [getattr(agent, attr_names) for agent in self._agents]
369398
else:
370-
return [
371-
[getattr(agent, attr_name) for attr_name in attr_names]
372-
for agent in self._agents
373-
]
399+
raise ValueError(
400+
f"Unknown handle_missing option: {handle_missing}, "
401+
"should be one of 'error' or 'default'"
402+
)
374403

375404
def set(self, attr_name: str, value: Any) -> AgentSet:
376405
"""

mesa/batchrunner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import multiprocessing
23
from collections.abc import Iterable, Mapping
34
from functools import partial
45
from multiprocessing import Pool
@@ -8,6 +9,8 @@
89

910
from mesa.model import Model
1011

12+
multiprocessing.set_start_method("spawn", force=True)
13+
1114

1215
def batch_run(
1316
model_cls: type[Model],

mesa/space.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -459,15 +459,21 @@ def move_agent_to_one_of(
459459
elif selection == "closest":
460460
current_pos = agent.pos
461461
# Find the closest position without sorting all positions
462-
closest_pos = None
462+
# TODO: See if this method can be optimized further
463+
closest_pos = []
463464
min_distance = float("inf")
464465
agent.random.shuffle(pos)
465466
for p in pos:
466467
distance = self._distance_squared(p, current_pos)
467468
if distance < min_distance:
468469
min_distance = distance
469-
closest_pos = p
470-
chosen_pos = closest_pos
470+
closest_pos.clear()
471+
closest_pos.append(p)
472+
elif distance == min_distance:
473+
closest_pos.append(p)
474+
475+
chosen_pos = agent.random.choice(closest_pos)
476+
471477
else:
472478
raise ValueError(
473479
f"Invalid selection method {selection}. Choose 'random' or 'closest'."

mesa/visualization/solara_viz.py

Lines changed: 22 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,9 @@
2424
"""
2525

2626
import copy
27-
import threading
27+
import time
2828
from typing import TYPE_CHECKING, Literal
2929

30-
import reacton.ipywidgets as widgets
3130
import solara
3231
from solara.alias import rv
3332

@@ -95,7 +94,7 @@ def SolaraViz(
9594
model: "Model" | solara.Reactive["Model"],
9695
components: list[solara.component] | Literal["default"] = "default",
9796
*args,
98-
play_interval=150,
97+
play_interval=100,
9998
model_params=None,
10099
seed=0,
101100
name: str | None = None,
@@ -149,115 +148,57 @@ def step():
149148

150149

151150
@solara.component
152-
def ModelController(model: solara.Reactive["Model"], play_interval):
151+
def ModelController(model: solara.Reactive["Model"], play_interval=100):
153152
"""
154153
Create controls for model execution (step, play, pause, reset).
155154
156155
Args:
157-
model: The model being visualized
156+
model: The reactive model being visualized
158157
play_interval: Interval between steps during play
159-
current_step: Reactive value for the current step
160-
reset_counter: Counter to trigger model reset
161158
"""
159+
if not isinstance(model, solara.Reactive):
160+
model = solara.use_reactive(model)
161+
162162
playing = solara.use_reactive(False)
163-
thread = solara.use_reactive(None)
164-
# We track the previous step to detect if user resets the model via
165-
# clicking the reset button or changing the parameters. If previous_step >
166-
# current_step, it means a model reset happens while the simulation is
167-
# still playing.
168-
previous_step = solara.use_reactive(0)
169163
original_model = solara.use_reactive(None)
170164

171165
def save_initial_model():
172166
"""Save the initial model for comparison."""
173167
original_model.set(copy.deepcopy(model.value))
168+
playing.value = False
169+
force_update()
174170

175171
solara.use_effect(save_initial_model, [model.value])
176172

177-
def on_value_play(change):
178-
"""Handle play/pause state changes."""
179-
if previous_step.value > model.value.steps and model.value.steps == 0:
180-
# We add extra checks for model.value.steps == 0, just to be sure.
181-
# We automatically stop the playing if a model is reset.
182-
playing.value = False
183-
elif model.value.running:
173+
def step():
174+
while playing.value:
175+
time.sleep(play_interval / 1000)
184176
do_step()
185-
else:
186-
playing.value = False
177+
178+
solara.use_thread(step, [playing.value])
187179

188180
def do_step():
189181
"""Advance the model by one step."""
190-
previous_step.value = model.value.steps
191182
model.value.step()
192183

193184
def do_play():
194185
"""Run the model continuously."""
195-
model.value.running = True
196-
while model.value.running:
197-
do_step()
198-
199-
def threaded_do_play():
200-
"""Start a new thread for continuous model execution."""
201-
if thread is not None and thread.is_alive():
202-
return
203-
thread.value = threading.Thread(target=do_play)
204-
thread.start()
186+
playing.value = True
205187

206188
def do_pause():
207189
"""Pause the model execution."""
208-
if (thread is None) or (not thread.is_alive()):
209-
return
210-
model.value.running = False
211-
thread.join()
190+
playing.value = False
212191

213192
def do_reset():
214-
"""Reset the model"""
193+
"""Reset the model to its initial state."""
194+
playing.value = False
215195
model.value = copy.deepcopy(original_model.value)
216-
previous_step.value = 0
217-
force_update()
218-
219-
def do_set_playing(value):
220-
"""Set the playing state."""
221-
if model.value.steps == 0:
222-
# This means the model has been recreated, and the step resets to
223-
# 0. We want to avoid triggering the playing.value = False in the
224-
# on_value_play function.
225-
previous_step.value = model.value.steps
226-
playing.set(value)
227196

228-
with solara.Row():
229-
solara.Button(label="Step", color="primary", on_click=do_step)
230-
# This style is necessary so that the play widget has almost the same
231-
# height as typical Solara buttons.
232-
solara.Style(
233-
"""
234-
.widget-play {
235-
height: 35px;
236-
}
237-
.widget-play button {
238-
color: white;
239-
background-color: #1976D2; // Solara blue color
240-
}
241-
"""
242-
)
243-
widgets.Play(
244-
value=0,
245-
interval=play_interval,
246-
repeat=True,
247-
show_repeat=False,
248-
on_value=on_value_play,
249-
playing=playing.value,
250-
on_playing=do_set_playing,
251-
)
197+
with solara.Row(justify="space-between"):
252198
solara.Button(label="Reset", color="primary", on_click=do_reset)
253-
# threaded_do_play is not used for now because it
254-
# doesn't work in Google colab. We use
255-
# ipywidgets.Play until it is fixed. The threading
256-
# version is definite a much better implementation,
257-
# if it works.
258-
# solara.Button(label="▶", color="primary", on_click=viz.threaded_do_play)
259-
# solara.Button(label="⏸︎", color="primary", on_click=viz.do_pause)
260-
# solara.Button(label="Reset", color="primary", on_click=do_reset)
199+
solara.Button(label="Step", color="primary", on_click=do_step)
200+
solara.Button(label="▶", color="primary", on_click=do_play)
201+
solara.Button(label="⏸︎", color="primary", on_click=do_pause)
261202

262203

263204
def split_model_params(model_params):

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ classifiers = [
2828
"Programming Language :: Python :: 3.10",
2929
"Programming Language :: Python :: 3.11",
3030
"Programming Language :: Python :: 3.12",
31+
"Programming Language :: Python :: 3.13",
3132
"License :: OSI Approved :: Apache Software License",
3233
"Operating System :: OS Independent",
3334
"Development Status :: 3 - Alpha",

tests/test_agent.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,51 @@ def remove_function(agent):
276276
assert len(agentset) == 0
277277

278278

279+
def test_agentset_get():
280+
model = Model()
281+
_ = [TestAgent(i, model) for i in range(10)]
282+
283+
agentset = model.agents
284+
285+
agentset.set("a", 5)
286+
agentset.set("b", 6)
287+
288+
# Case 1: Normal retrieval of existing attributes
289+
values = agentset.get(["a", "b"])
290+
assert all((a == 5) & (b == 6) for a, b in values)
291+
292+
# Case 2: Raise AttributeError when attribute doesn't exist
293+
with pytest.raises(AttributeError):
294+
agentset.get("unknown_attribute")
295+
296+
# Case 3: Use default value when attribute is missing
297+
results = agentset.get(
298+
"unknown_attribute", handle_missing="default", default_value=True
299+
)
300+
assert all(results) is True
301+
302+
# Case 4: Retrieve mixed attributes with default value for missing ones
303+
values = agentset.get(
304+
["a", "unknown_attribute"], handle_missing="default", default_value=True
305+
)
306+
assert all((a == 5) & (unknown is True) for a, unknown in values)
307+
308+
# Case 5: Invalid handle_missing value raises ValueError
309+
with pytest.raises(ValueError):
310+
agentset.get("unknown_attribute", handle_missing="some nonsense value")
311+
312+
# Case 6: Retrieve multiple attributes with mixed existence and 'default' handling
313+
values = agentset.get(
314+
["a", "b", "unknown_attribute"], handle_missing="default", default_value=0
315+
)
316+
assert all((a == 5) & (b == 6) & (unknown == 0) for a, b, unknown in values)
317+
318+
# Case 7: 'default' handling when one attribute is completely missing from some agents
319+
agentset.select(at_most=0.5).set("c", 8) # Only some agents have attribute 'c'
320+
values = agentset.get(["a", "c"], handle_missing="default", default_value=-1)
321+
assert all((a == 5) & (c in [8, -1]) for a, c in values)
322+
323+
279324
def test_agentset_agg():
280325
model = Model()
281326
agents = [TestAgent(model) for i in range(10)]

tests/test_space.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,23 @@ def test_move_agent_closest_selection(self):
509509
self.space.move_agent_to_one_of(agent, possible_positions, selection="closest")
510510
assert agent.pos == (6, 6)
511511

512+
def test_move_agent_closest_selection_multiple(self):
513+
random_locations = []
514+
agent = self.agents[0]
515+
agent.pos = (5, 5)
516+
repetititions = 10
517+
518+
for _ in range(repetititions):
519+
possible_positions = [(4, 4), (6, 6), (10, 10), (20, 20)]
520+
self.space.move_agent_to_one_of(
521+
agent, possible_positions, selection="closest"
522+
)
523+
random_locations.append(agent.pos)
524+
assert agent.pos in possible_positions
525+
self.space.move_agent_to_one_of(agent, [(5, 5)], selection="closest")
526+
non_random_locations = [random_locations[0]] * repetititions
527+
assert random_locations != non_random_locations
528+
512529
def test_move_agent_invalid_selection(self):
513530
agent = self.agents[0]
514531
possible_positions = [(10, 10), (20, 20), (30, 30)]

0 commit comments

Comments
 (0)