Skip to content

Commit b9088dd

Browse files
authored
Add better docstrings and improve layout of solara viz
1 parent 696f123 commit b9088dd

File tree

3 files changed

+149
-75
lines changed

3 files changed

+149
-75
lines changed

mesa/visualization/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,12 @@
22

33
from .components.altair import make_space_altair
44
from .components.matplotlib import make_plot_measure, make_space_matplotlib
5-
from .solara_viz import JupyterViz, SolaraViz, make_text
5+
from .solara_viz import JupyterViz, SolaraViz
66
from .UserParam import Slider
77

88
__all__ = [
99
"JupyterViz",
1010
"SolaraViz",
11-
"make_text",
1211
"Slider",
1312
"make_space_altair",
1413
"make_space_matplotlib",

mesa/visualization/solara_viz.py

Lines changed: 133 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
- SolaraViz: Main component for creating visualizations, supporting grid displays and plots
88
- ModelController: Handles model execution controls (step, play, pause, reset)
99
- UserInputs: Generates UI elements for adjusting model parameters
10-
- Card: Renders individual visualization elements (space, measures)
1110
1211
The module uses Solara for rendering in Jupyter notebooks or as standalone web applications.
1312
It supports various types of visualizations including matplotlib plots, agent grids, and
@@ -22,10 +21,14 @@
2221
See the Visualization Tutorial and example models for more details.
2322
"""
2423

24+
from __future__ import annotations
25+
2526
import copy
2627
import time
28+
from collections.abc import Callable
2729
from typing import TYPE_CHECKING, Literal
2830

31+
import reacton.core
2932
import solara
3033
from solara.alias import rv
3134

@@ -89,31 +92,57 @@ def Card(
8992

9093
@solara.component
9194
def SolaraViz(
92-
model: "Model" | solara.Reactive["Model"],
93-
components: list[solara.component] | Literal["default"] = "default",
94-
play_interval=100,
95+
model: Model | solara.Reactive[Model],
96+
components: list[reacton.core.Component]
97+
| list[Callable[[Model], reacton.core.Component]]
98+
| Literal["default"] = "default",
99+
play_interval: int = 100,
95100
model_params=None,
96-
seed=0,
101+
seed: float = 0,
97102
name: str | None = None,
98103
):
99104
"""Solara visualization component.
100105
106+
This component provides a visualization interface for a given model using Solara.
107+
It supports various visualization components and allows for interactive model
108+
stepping and parameter adjustments.
109+
101110
Args:
102-
model: a Model instance
103-
components: list of solara components
104-
play_interval: int
105-
model_params: parameters for instantiating a model
106-
seed: the seed for the rng
107-
name: str
111+
model (Model | solara.Reactive[Model]): A Model instance or a reactive Model.
112+
This is the main model to be visualized. If a non-reactive model is provided,
113+
it will be converted to a reactive model.
114+
components (list[solara.component] | Literal["default"], optional): List of solara
115+
components or functions that return a solara component.
116+
These components are used to render different parts of the model visualization.
117+
Defaults to "default", which uses the default Altair space visualization.
118+
play_interval (int, optional): Interval for playing the model steps in milliseconds.
119+
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
120+
model_params (dict, optional): Parameters for (re-)instantiating a model.
121+
Can include user-adjustable parameters and fixed parameters. Defaults to None.
122+
seed (int, optional): Seed for the random number generator. This ensures reproducibility
123+
of the model's behavior. Defaults to 0.
124+
name (str | None, optional): Name of the visualization. Defaults to the models class name.
108125
126+
Returns:
127+
solara.component: A Solara component that renders the visualization interface for the model.
128+
129+
Example:
130+
>>> model = MyModel()
131+
>>> page = SolaraViz(model)
132+
>>> page
133+
134+
Notes:
135+
- The `model` argument can be either a direct model instance or a reactive model. If a direct
136+
model instance is provided, it will be converted to a reactive model using `solara.use_reactive`.
137+
- The `play_interval` argument controls the speed of the model's automatic stepping. A lower
138+
value results in faster stepping, while a higher value results in slower stepping.
109139
"""
110-
update_counter.get()
111140
if components == "default":
112141
components = [components_altair.make_space_altair()]
113142

114143
# Convert model to reactive
115144
if not isinstance(model, solara.Reactive):
116-
model = solara.use_reactive(model)
145+
model = solara.use_reactive(model) # noqa: SH102, RUF100
117146

118147
def connect_to_model():
119148
# Patch the step function to force updates
@@ -133,39 +162,68 @@ def step():
133162
with solara.AppBar():
134163
solara.AppBarTitle(name if name else model.value.__class__.__name__)
135164

136-
with solara.Sidebar():
137-
with solara.Card("Controls", margin=1, elevation=2):
138-
if model_params is not None:
165+
with solara.Sidebar(), solara.Column():
166+
with solara.Card("Controls"):
167+
ModelController(model, play_interval)
168+
169+
if model_params is not None:
170+
with solara.Card("Model Parameters"):
139171
ModelCreator(
140172
model,
141173
model_params,
142174
seed=seed,
143175
)
144-
ModelController(model, play_interval)
145-
with solara.Card("Information", margin=1, elevation=2):
176+
with solara.Card("Information"):
146177
ShowSteps(model.value)
147178

148-
solara.Column(
149-
[
150-
*(component(model.value) for component in components),
151-
]
152-
)
179+
ComponentsView(components, model.value)
180+
181+
182+
def _wrap_component(
183+
component: reacton.core.Component | Callable[[Model], reacton.core.Component],
184+
) -> reacton.core.Component:
185+
"""Wrap a component in an auto-updated Solara component if needed."""
186+
if isinstance(component, reacton.core.Component):
187+
return component
188+
189+
@solara.component
190+
def WrappedComponent(model):
191+
update_counter.get()
192+
return component(model)
193+
194+
return WrappedComponent
195+
196+
197+
@solara.component
198+
def ComponentsView(
199+
components: list[reacton.core.Component]
200+
| list[Callable[[Model], reacton.core.Component]],
201+
model: Model,
202+
):
203+
"""Display a list of components.
204+
205+
Args:
206+
components: List of components to display
207+
model: Model instance to pass to each component
208+
"""
209+
wrapped_components = [_wrap_component(component) for component in components]
210+
211+
with solara.Column():
212+
for component in wrapped_components:
213+
component(model)
153214

154215

155216
JupyterViz = SolaraViz
156217

157218

158219
@solara.component
159-
def ModelController(model: solara.Reactive["Model"], play_interval=100):
220+
def ModelController(model: solara.Reactive[Model], play_interval=100):
160221
"""Create controls for model execution (step, play, pause, reset).
161222
162223
Args:
163-
model: The reactive model being visualized
164-
play_interval: Interval between steps during play
224+
model (solara.Reactive[Model]): Reactive model instance
225+
play_interval (int, optional): Interval for playing the model steps in milliseconds.
165226
"""
166-
if not isinstance(model, solara.Reactive):
167-
model = solara.use_reactive(model)
168-
169227
playing = solara.use_reactive(False)
170228
original_model = solara.use_reactive(None)
171229

@@ -188,24 +246,25 @@ def do_step():
188246
"""Advance the model by one step."""
189247
model.value.step()
190248

191-
def do_play():
192-
"""Run the model continuously."""
193-
playing.value = True
194-
195-
def do_pause():
196-
"""Pause the model execution."""
197-
playing.value = False
198-
199249
def do_reset():
200250
"""Reset the model to its initial state."""
201251
playing.value = False
202252
model.value = copy.deepcopy(original_model.value)
203253

254+
def do_play_pause():
255+
"""Toggle play/pause."""
256+
playing.value = not playing.value
257+
204258
with solara.Row(justify="space-between"):
205259
solara.Button(label="Reset", color="primary", on_click=do_reset)
206-
solara.Button(label="Step", color="primary", on_click=do_step)
207-
solara.Button(label="▶", color="primary", on_click=do_play)
208-
solara.Button(label="⏸︎", color="primary", on_click=do_pause)
260+
solara.Button(
261+
label="▶" if not playing.value else "❚❚",
262+
color="primary",
263+
on_click=do_play_pause,
264+
)
265+
solara.Button(
266+
label="Step", color="primary", on_click=do_step, disabled=playing.value
267+
)
209268

210269

211270
def split_model_params(model_params):
@@ -246,13 +305,34 @@ def check_param_is_fixed(param):
246305

247306
@solara.component
248307
def ModelCreator(model, model_params, seed=1):
249-
"""Helper class to create a new Model instance.
308+
"""Solara component for creating and managing a model instance with user-defined parameters.
309+
310+
This component allows users to create a model instance with specified parameters and seed.
311+
It provides an interface for adjusting model parameters and reseeding the model's random
312+
number generator.
250313
251314
Args:
252-
model: model instance
253-
model_params: model parameters
254-
seed: the seed to use for the random number generator
315+
model (solara.Reactive[Model]): A reactive model instance. This is the main model to be created and managed.
316+
model_params (dict): Dictionary of model parameters. This includes both user-adjustable parameters and fixed parameters.
317+
seed (int, optional): Initial seed for the random number generator. Defaults to 1.
255318
319+
Returns:
320+
solara.component: A Solara component that renders the model creation and management interface.
321+
322+
Example:
323+
>>> model = solara.reactive(MyModel())
324+
>>> model_params = {
325+
>>> "param1": {"type": "slider", "value": 10, "min": 0, "max": 100},
326+
>>> "param2": {"type": "slider", "value": 5, "min": 1, "max": 10},
327+
>>> }
328+
>>> creator = ModelCreator(model, model_params)
329+
>>> creator
330+
331+
Notes:
332+
- The `model_params` argument should be a dictionary where keys are parameter names and values either fixed values
333+
or are dictionaries containing parameter details such as type, value, min, and max.
334+
- The `seed` argument ensures reproducibility by setting the initial seed for the model's random number generator.
335+
- The component provides an interface for adjusting user-defined parameters and reseeding the model.
256336
257337
"""
258338
user_params, fixed_params = split_model_params(model_params)
@@ -279,13 +359,14 @@ def create_model():
279359

280360
solara.use_effect(create_model, [model_parameters, reactive_seed.value])
281361

282-
solara.InputText(
283-
label="Seed",
284-
value=reactive_seed,
285-
continuous_update=True,
286-
)
362+
with solara.Row(justify="space-between"):
363+
solara.InputText(
364+
label="Seed",
365+
value=reactive_seed,
366+
continuous_update=True,
367+
)
287368

288-
solara.Button(label="Reseed", color="primary", on_click=do_reseed)
369+
solara.Button(label="Reseed", color="primary", on_click=do_reseed)
289370

290371
UserInputs(user_params, on_change=on_change)
291372

@@ -358,22 +439,6 @@ def change_handler(value, name=name):
358439
raise ValueError(f"{input_type} is not a supported input type")
359440

360441

361-
def make_text(renderer):
362-
"""Create a function that renders text using Markdown.
363-
364-
Args:
365-
renderer: Function that takes a model and returns a string
366-
367-
Returns:
368-
function: A function that renders the text as Markdown
369-
"""
370-
371-
def function(model):
372-
solara.Markdown(renderer(model))
373-
374-
return function
375-
376-
377442
def make_initial_grid_layout(layout_types):
378443
"""Create an initial grid layout for visualization components.
379444
@@ -397,6 +462,7 @@ def make_initial_grid_layout(layout_types):
397462

398463

399464
@solara.component
400-
def ShowSteps(model): # noqa: D103
465+
def ShowSteps(model):
466+
"""Display the current step of the model."""
401467
update_counter.get()
402468
return solara.Text(f"Step: {model.steps}")

tests/test_solara_viz.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
"""Test Solara visualizations."""
22

33
import unittest
4-
from unittest.mock import Mock
54

65
import ipyvuetify as vw
76
import solara
87

98
import mesa
9+
import mesa.visualization.components.altair
10+
import mesa.visualization.components.matplotlib
1011
from mesa.visualization.components.matplotlib import make_space_matplotlib
1112
from mesa.visualization.solara_viz import Slider, SolaraViz, UserInputs
1213

@@ -86,10 +87,12 @@ def Test(user_params):
8687

8788

8889
def test_call_space_drawer(mocker): # noqa: D103
89-
mock_space_matplotlib = mocker.patch(
90-
"mesa.visualization.components.matplotlib.SpaceMatplotlib"
90+
mock_space_matplotlib = mocker.spy(
91+
mesa.visualization.components.matplotlib, "SpaceMatplotlib"
9192
)
9293

94+
mock_space_altair = mocker.spy(mesa.visualization.components.altair, "SpaceAltair")
95+
9396
model = mesa.Model()
9497
mocker.patch.object(mesa.Model, "__init__", return_value=None)
9598

@@ -105,13 +108,19 @@ def test_call_space_drawer(mocker): # noqa: D103
105108

106109
# specify no space should be drawn
107110
mock_space_matplotlib.reset_mock()
108-
solara.render(SolaraViz(model, components=[]))
111+
solara.render(SolaraViz(model))
109112
# should call default method with class instance and agent portrayal
110113
assert mock_space_matplotlib.call_count == 0
114+
assert mock_space_altair.call_count > 0
111115

112116
# specify a custom space method
113-
altspace_drawer = Mock()
114-
solara.render(SolaraViz(model, components=[altspace_drawer]))
117+
class AltSpace:
118+
@staticmethod
119+
def drawer(model):
120+
return
121+
122+
altspace_drawer = mocker.spy(AltSpace, "drawer")
123+
solara.render(SolaraViz(model, components=[AltSpace.drawer]))
115124
altspace_drawer.assert_called_with(model)
116125

117126
# check voronoi space drawer

0 commit comments

Comments
 (0)