Skip to content

Commit 48065fd

Browse files
Corvincepre-commit-ci[bot]EwoutH
authored
Solaraviz api (#2263)
* add new solara viz API --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Ewout ter Hoeven <[email protected]>
1 parent d01d15d commit 48065fd

File tree

7 files changed

+872
-40
lines changed

7 files changed

+872
-40
lines changed

mesa/visualization/UserParam.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
class UserParam:
2+
_ERROR_MESSAGE = "Missing or malformed inputs for '{}' Option '{}'"
3+
4+
def maybe_raise_error(self, param_type, valid):
5+
if valid:
6+
return
7+
msg = self._ERROR_MESSAGE.format(param_type, self.label)
8+
raise ValueError(msg)
9+
10+
11+
class Slider(UserParam):
12+
"""
13+
A number-based slider input with settable increment.
14+
15+
Example:
16+
17+
slider_option = Slider("My Slider", value=123, min=10, max=200, step=0.1)
18+
19+
Args:
20+
label: The displayed label in the UI
21+
value: The initial value of the slider
22+
min: The minimum possible value of the slider
23+
max: The maximum possible value of the slider
24+
step: The step between min and max for a range of possible values
25+
dtype: either int or float
26+
"""
27+
28+
def __init__(
29+
self,
30+
label="",
31+
value=None,
32+
min=None,
33+
max=None,
34+
step=1,
35+
dtype=None,
36+
):
37+
self.label = label
38+
self.value = value
39+
self.min = min
40+
self.max = max
41+
self.step = step
42+
43+
# Validate option type to make sure values are supplied properly
44+
valid = not (self.value is None or self.min is None or self.max is None)
45+
self.maybe_raise_error("slider", valid)
46+
47+
if dtype is None:
48+
self.is_float_slider = self._check_values_are_float(value, min, max, step)
49+
else:
50+
self.is_float_slider = dtype is float
51+
52+
def _check_values_are_float(self, value, min, max, step):
53+
return any(isinstance(n, float) for n in (value, min, max, step))
54+
55+
def get(self, attr):
56+
return getattr(self, attr)

mesa/visualization/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from .components.altair import make_space_altair
2+
from .components.matplotlib import make_plot_measure, make_space_matplotlib
3+
from .solara_viz import JupyterViz, SolaraViz, make_text
4+
from .UserParam import Slider
5+
6+
__all__ = [
7+
"JupyterViz",
8+
"SolaraViz",
9+
"make_text",
10+
"Slider",
11+
"make_space_altair",
12+
"make_space_matplotlib",
13+
"make_plot_measure",
14+
]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import contextlib
2+
3+
import solara
4+
5+
with contextlib.suppress(ImportError):
6+
import altair as alt
7+
8+
from mesa.visualization.utils import update_counter
9+
10+
11+
def make_space_altair(agent_portrayal=None):
12+
if agent_portrayal is None:
13+
14+
def agent_portrayal(a):
15+
return {"id": a.unique_id}
16+
17+
def MakeSpaceAltair(model):
18+
return SpaceAltair(model, agent_portrayal)
19+
20+
return MakeSpaceAltair
21+
22+
23+
@solara.component
24+
def SpaceAltair(model, agent_portrayal, dependencies: list[any] | None = None):
25+
update_counter.get()
26+
space = getattr(model, "grid", None)
27+
if space is None:
28+
# Sometimes the space is defined as model.space instead of model.grid
29+
space = model.space
30+
chart = _draw_grid(space, agent_portrayal)
31+
solara.FigureAltair(chart)
32+
33+
34+
def _draw_grid(space, agent_portrayal):
35+
def portray(g):
36+
all_agent_data = []
37+
for content, (x, y) in g.coord_iter():
38+
if not content:
39+
continue
40+
if not hasattr(content, "__iter__"):
41+
# Is a single grid
42+
content = [content] # noqa: PLW2901
43+
for agent in content:
44+
# use all data from agent portrayal, and add x,y coordinates
45+
agent_data = agent_portrayal(agent)
46+
agent_data["x"] = x
47+
agent_data["y"] = y
48+
all_agent_data.append(agent_data)
49+
return all_agent_data
50+
51+
all_agent_data = portray(space)
52+
invalid_tooltips = ["color", "size", "x", "y"]
53+
54+
encoding_dict = {
55+
# no x-axis label
56+
"x": alt.X("x", axis=None, type="ordinal"),
57+
# no y-axis label
58+
"y": alt.Y("y", axis=None, type="ordinal"),
59+
"tooltip": [
60+
alt.Tooltip(key, type=alt.utils.infer_vegalite_type([value]))
61+
for key, value in all_agent_data[0].items()
62+
if key not in invalid_tooltips
63+
],
64+
}
65+
has_color = "color" in all_agent_data[0]
66+
if has_color:
67+
encoding_dict["color"] = alt.Color("color", type="nominal")
68+
has_size = "size" in all_agent_data[0]
69+
if has_size:
70+
encoding_dict["size"] = alt.Size("size", type="quantitative")
71+
72+
chart = (
73+
alt.Chart(
74+
alt.Data(values=all_agent_data), encoding=alt.Encoding(**encoding_dict)
75+
)
76+
.mark_point(filled=True)
77+
.properties(width=280, height=280)
78+
# .configure_view(strokeOpacity=0) # hide grid/chart lines
79+
)
80+
# This is the default value for the marker size, which auto-scales
81+
# according to the grid area.
82+
if not has_size:
83+
length = min(space.width, space.height)
84+
chart = chart.mark_point(size=30000 / length**2, filled=True)
85+
86+
return chart

0 commit comments

Comments
 (0)