Skip to content

Commit 527c023

Browse files
committed
move solara component creations into solara_viz with backend option
1 parent 62aaa01 commit 527c023

File tree

5 files changed

+125
-108
lines changed

5 files changed

+125
-108
lines changed

mesa/visualization/__init__.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
"""Solara based visualization for Mesa models."""
22

3-
from .components.altair import make_space_altair
4-
from .components.matplotlib import make_plot_component, make_space_component
5-
from .solara_viz import JupyterViz, SolaraViz
3+
from .solara_viz import JupyterViz, SolaraViz, make_plot_component, make_space_component
64
from .UserParam import Slider
75

86
__all__ = [

mesa/visualization/components/altair.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,6 @@
1212
from mesa.visualization.utils import update_counter
1313

1414

15-
def make_space_altair(agent_portrayal=None): # noqa: D103
16-
if agent_portrayal is None:
17-
18-
def agent_portrayal(a):
19-
return {"id": a.unique_id}
20-
21-
def MakeSpaceAltair(model):
22-
return SpaceAltair(model, agent_portrayal)
23-
24-
return MakeSpaceAltair
25-
26-
2715
@solara.component
2816
def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None): # noqa: D103
2917
update_counter.get()

mesa/visualization/components/matplotlib.py

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Matplotlib based solara components for visualization MESA spaces and plots."""
22

3-
import warnings
43
from collections.abc import Callable
54

65
import matplotlib.pyplot as plt
@@ -31,54 +30,6 @@
3130
Network = NetworkGrid | mesa.experimental.cell_space.Network
3231

3332

34-
def make_space_matplotlib(*args, **kwargs): # noqa: D103
35-
warnings.warn(
36-
"make_space_matplotlib has been renamed to make_space_component",
37-
DeprecationWarning,
38-
stacklevel=2,
39-
)
40-
return make_space_component(*args, **kwargs)
41-
42-
43-
def make_space_component(
44-
agent_portrayal: Callable | None = None,
45-
propertylayer_portrayal: dict | None = None,
46-
post_process: Callable | None = None,
47-
**space_drawing_kwargs,
48-
):
49-
"""Create a Matplotlib-based space visualization component.
50-
51-
Args:
52-
agent_portrayal: Function to portray agents.
53-
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
54-
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
55-
space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See
56-
the functions for drawing the various spaces for further details.
57-
58-
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
59-
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
60-
61-
62-
Returns:
63-
function: A function that creates a SpaceMatplotlib component
64-
"""
65-
if agent_portrayal is None:
66-
67-
def agent_portrayal(a):
68-
return {}
69-
70-
def MakeSpaceMatplotlib(model):
71-
return SpaceMatplotlib(
72-
model,
73-
agent_portrayal,
74-
propertylayer_portrayal,
75-
post_process=post_process,
76-
**space_drawing_kwargs,
77-
)
78-
79-
return MakeSpaceMatplotlib
80-
81-
8233
@solara.component
8334
def SpaceMatplotlib(
8435
model,
@@ -107,39 +58,6 @@ def SpaceMatplotlib(
10758
)
10859

10960

110-
def make_plot_measure(*args, **kwargs): # noqa: D103
111-
warnings.warn(
112-
"make_plot_measure has been renamed to make_plot_component",
113-
DeprecationWarning,
114-
stacklevel=2,
115-
)
116-
return make_plot_component(*args, **kwargs)
117-
118-
119-
def make_plot_component(
120-
measure: str | dict[str, str] | list[str] | tuple[str],
121-
post_process: Callable | None = None,
122-
save_format="png",
123-
):
124-
"""Create a plotting function for a specified measure.
125-
126-
Args:
127-
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
128-
post_process: a user-specified callable to do post-processing called with the Axes instance.
129-
save_format: save format of figure in solara backend
130-
131-
Returns:
132-
function: A function that creates a PlotMatplotlib component.
133-
"""
134-
135-
def MakePlotMatplotlib(model):
136-
return PlotMatplotlib(
137-
model, measure, post_process=post_process, save_format=save_format
138-
)
139-
140-
return MakePlotMatplotlib
141-
142-
14361
@solara.component
14462
def PlotMatplotlib(
14563
model,

mesa/visualization/solara_viz.py

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,131 @@
2525

2626
import asyncio
2727
import copy
28+
import warnings
2829
from collections.abc import Callable
2930
from typing import TYPE_CHECKING, Literal
3031

3132
import reacton.core
3233
import solara
3334

34-
import mesa.visualization.components.altair as components_altair
35+
from mesa.visualization.components.altair import SpaceAltair
36+
from mesa.visualization.components.matplotlib import PlotMatplotlib, SpaceMatplotlib
3537
from mesa.visualization.UserParam import Slider
3638
from mesa.visualization.utils import force_update, update_counter
3739

3840
if TYPE_CHECKING:
3941
from mesa.model import Model
4042

4143

44+
def make_space_matplotlib(*args, **kwargs): # noqa: D103
45+
warnings.warn(
46+
"make_space_matplotlib has been renamed to make_space_component",
47+
DeprecationWarning,
48+
stacklevel=2,
49+
)
50+
return make_space_component(*args, **kwargs)
51+
52+
53+
def make_space_altair(agent_portrayal=None): # noqa: D103
54+
warnings.warn(
55+
"make_space_altair has been renamed to make_space_component with backend='altair'",
56+
DeprecationWarning,
57+
stacklevel=2,
58+
)
59+
if agent_portrayal is None:
60+
61+
def agent_portrayal(a):
62+
return {"id": a.unique_id}
63+
64+
def MakeSpaceAltair(model):
65+
return SpaceAltair(model, agent_portrayal)
66+
67+
return MakeSpaceAltair
68+
69+
70+
def make_space_component(
71+
agent_portrayal: Callable | None = None,
72+
propertylayer_portrayal: dict | None = None,
73+
post_process: Callable | None = None,
74+
backend: Literal["matplotlib", "altair"] = "matplotlib",
75+
**space_drawing_kwargs,
76+
):
77+
"""Create a Matplotlib-based space visualization component.
78+
79+
Args:
80+
agent_portrayal: Function to portray agents.
81+
propertylayer_portrayal: Dictionary of PropertyLayer portrayal specifications
82+
post_process : a callable that will be called with the Axes instance. Allows for fine tuning plots (e.g., control ticks)
83+
backend: The backend to use for rendering the space. Can be "matplotlib" or "altair".
84+
space_drawing_kwargs : additional keyword arguments to be passed on to the underlying space drawer function. See
85+
the functions for drawing the various spaces for further details.
86+
87+
``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
88+
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.
89+
90+
91+
Returns:
92+
function: A function that creates a SpaceMatplotlib component
93+
"""
94+
if agent_portrayal is None:
95+
96+
def agent_portrayal(a):
97+
return {}
98+
99+
def MakeSpaceMatplotlib(model):
100+
return SpaceMatplotlib(
101+
model,
102+
agent_portrayal,
103+
propertylayer_portrayal,
104+
post_process=post_process,
105+
**space_drawing_kwargs,
106+
)
107+
108+
def MakeSpaceAltair(model):
109+
return SpaceAltair(model, agent_portrayal)
110+
111+
match backend:
112+
case "matplotlib":
113+
return MakeSpaceMatplotlib
114+
case "altair":
115+
return MakeSpaceAltair
116+
case _:
117+
raise ValueError(f"Invalid backend: {backend}")
118+
119+
120+
def make_plot_measure(*args, **kwargs): # noqa: D103
121+
warnings.warn(
122+
"make_plot_measure has been renamed to make_plot_component",
123+
DeprecationWarning,
124+
stacklevel=2,
125+
)
126+
return make_plot_component(*args, **kwargs)
127+
128+
129+
def make_plot_component(
130+
measure: str | dict[str, str] | list[str] | tuple[str],
131+
post_process: Callable | None = None,
132+
save_format="png",
133+
):
134+
"""Create a plotting function for a specified measure.
135+
136+
Args:
137+
measure (str | dict[str, str] | list[str] | tuple[str]): Measure(s) to plot.
138+
post_process: a user-specified callable to do post-processing called with the Axes instance.
139+
save_format: save format of figure in solara backend
140+
141+
Returns:
142+
function: A function that creates a PlotMatplotlib component.
143+
"""
144+
145+
def MakePlotMatplotlib(model):
146+
return PlotMatplotlib(
147+
model, measure, post_process=post_process, save_format=save_format
148+
)
149+
150+
return MakePlotMatplotlib
151+
152+
42153
@solara.component
43154
def SolaraViz(
44155
model: Model | solara.Reactive[Model],
@@ -63,7 +174,7 @@ def SolaraViz(
63174
components (list[solara.component] | Literal["default"], optional): List of solara
64175
components or functions that return a solara component.
65176
These components are used to render different parts of the model visualization.
66-
Defaults to "default", which uses the default Altair space visualization.
177+
Defaults to "default", which uses the default Matplotlib space visualization.
67178
play_interval (int, optional): Interval for playing the model steps in milliseconds.
68179
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
69180
model_params (dict, optional): Parameters for (re-)instantiating a model.
@@ -87,7 +198,7 @@ def SolaraViz(
87198
value results in faster stepping, while a higher value results in slower stepping.
88199
"""
89200
if components == "default":
90-
components = [components_altair.make_space_altair()]
201+
components = [make_space_component()]
91202

92203
# Convert model to reactive
93204
if not isinstance(model, solara.Reactive):

tests/test_solara_viz.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@
88
import mesa
99
import mesa.visualization.components.altair
1010
import mesa.visualization.components.matplotlib
11-
from mesa.visualization.components.matplotlib import make_space_component
12-
from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs
11+
from mesa.visualization.solara_viz import (
12+
Slider,
13+
SolaraViz,
14+
UserInputs,
15+
make_space_component,
16+
)
1317

1418

1519
class TestMakeUserInput(unittest.TestCase): # noqa: D101
@@ -87,11 +91,9 @@ def Test(user_params):
8791

8892

8993
def test_call_space_drawer(mocker): # noqa: D103
90-
mock_space_matplotlib = mocker.spy(
91-
mesa.visualization.components.matplotlib, "SpaceMatplotlib"
92-
)
94+
mock_space_matplotlib = mocker.spy(mesa.visualization.solara_viz, "SpaceMatplotlib")
9395

94-
mock_space_altair = mocker.spy(mesa.visualization.components.altair, "SpaceAltair")
96+
mock_space_altair = mocker.spy(mesa.visualization.solara_viz, "SpaceAltair")
9597

9698
model = mesa.Model()
9799
mocker.patch.object(mesa.Model, "__init__", return_value=None)
@@ -113,8 +115,8 @@ def test_call_space_drawer(mocker): # noqa: D103
113115
mock_space_matplotlib.reset_mock()
114116
solara.render(SolaraViz(model))
115117
# should call default method with class instance and agent portrayal
116-
assert mock_space_matplotlib.call_count == 0
117-
assert mock_space_altair.call_count > 0
118+
assert mock_space_matplotlib.call_count > 0
119+
assert mock_space_altair.call_count == 0
118120

119121
# specify a custom space method
120122
class AltSpace:

0 commit comments

Comments
 (0)