Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,19 +269,12 @@ def SpaceRendererComponent(
for artist in itertools.chain.from_iterable(all_artists):
artist.remove()

# Draw the space structure if specified
if renderer.space_mesh:
renderer.draw_structure(**renderer.space_kwargs)

# Draw agents if specified
renderer.draw_structure()
if renderer.agent_mesh:
renderer.draw_agents(
agent_portrayal=renderer.agent_portrayal, **renderer.agent_kwargs
)

# Draw property layers if specified
renderer.draw_agents()
if renderer.propertylayer_mesh:
renderer.draw_propertylayer(renderer.propertylayer_portrayal)
renderer.draw_propertylayer()

# Update the fig every time frame
if dependencies:
Expand All @@ -306,15 +299,11 @@ def SpaceRendererComponent(
propertylayer = renderer.propertylayer_mesh or None

if renderer.space_mesh:
structure = renderer.draw_structure(**renderer.space_kwargs)
structure = renderer.draw_structure()
if renderer.agent_mesh:
agents = renderer.draw_agents(
renderer.agent_portrayal, **renderer.agent_kwargs
)
agents = renderer.draw_agents()
if renderer.propertylayer_mesh:
propertylayer = renderer.draw_propertylayer(
renderer.propertylayer_portrayal
)
propertylayer = renderer.draw_propertylayer()

spatial_charts_list = [
chart for chart in [structure, propertylayer, agents] if chart
Expand Down
108 changes: 54 additions & 54 deletions mesa/visualization/space_drawers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def __init__(self, space: OrthogonalGrid):
self.viz_ymin = -0.5
self.viz_ymax = self.space.height - 0.5

def draw_matplotlib(self, ax=None, **space_kwargs):
def draw_matplotlib(self, ax=None, **draw_space_kwargs):
"""Draw the orthogonal grid using matplotlib.

Args:
ax: Matplotlib axes object to draw on
**space_kwargs: Additional keyword arguments for styling.
**draw_space_kwargs: Additional keyword arguments for styling.

Examples:
figsize=(10, 10), color="blue", linewidth=2.
Expand All @@ -96,8 +96,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs):
The modified axes object
"""
fig_kwargs = {
"figsize": space_kwargs.pop("figsize", (8, 8)),
"dpi": space_kwargs.pop("dpi", 100),
"figsize": draw_space_kwargs.pop("figsize", (8, 8)),
"dpi": draw_space_kwargs.pop("dpi", 100),
}

if ax is None:
Expand All @@ -110,7 +110,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs):
"linewidth": 1,
"alpha": 1,
}
line_kwargs.update(space_kwargs)
line_kwargs.update(draw_space_kwargs)

ax.set_xlim(self.viz_xmin, self.viz_xmax)
ax.set_ylim(self.viz_ymin, self.viz_ymax)
Expand All @@ -123,13 +123,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs):

return ax

def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs):
"""Draw the orthogonal grid using Altair.

Args:
chart_width: Width for the shown chart
chart_height: Height for the shown chart
**chart_kwargs: Additional keyword arguments for styling the chart.
**draw_chart_kwargs: Additional keyword arguments for styling the chart.

Examples:
width=500, height=500, title="Grid".
Expand All @@ -139,20 +139,20 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
"""
# for axis and grid styling
axis_kwargs = {
"xlabel": chart_kwargs.pop("xlabel", "X"),
"ylabel": chart_kwargs.pop("ylabel", "Y"),
"grid_color": chart_kwargs.pop("grid_color", "lightgray"),
"grid_dash": chart_kwargs.pop("grid_dash", [2, 2]),
"grid_width": chart_kwargs.pop("grid_width", 1),
"grid_opacity": chart_kwargs.pop("grid_opacity", 1),
"xlabel": draw_chart_kwargs.pop("xlabel", "X"),
"ylabel": draw_chart_kwargs.pop("ylabel", "Y"),
"grid_color": draw_chart_kwargs.pop("grid_color", "lightgray"),
"grid_dash": draw_chart_kwargs.pop("grid_dash", [2, 2]),
"grid_width": draw_chart_kwargs.pop("grid_width", 1),
"grid_opacity": draw_chart_kwargs.pop("grid_opacity", 1),
}

# for chart properties
chart_props = {
"width": chart_width,
"height": chart_height,
}
chart_props.update(chart_kwargs)
chart_props.update(draw_chart_kwargs)

chart = (
alt.Chart(pd.DataFrame([{}]))
Expand Down Expand Up @@ -263,12 +263,12 @@ def _get_unique_edges(self):
edges.add(edge)
return edges

def draw_matplotlib(self, ax=None, **space_kwargs):
def draw_matplotlib(self, ax=None, **draw_space_kwargs):
"""Draw the hexagonal grid using matplotlib.

Args:
ax: Matplotlib axes object to draw on
**space_kwargs: Additional keyword arguments for styling.
**draw_space_kwargs: Additional keyword arguments for styling.

Examples:
figsize=(8, 8), color="red", alpha=0.5.
Expand All @@ -277,8 +277,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs):
The modified axes object
"""
fig_kwargs = {
"figsize": space_kwargs.pop("figsize", (8, 8)),
"dpi": space_kwargs.pop("dpi", 100),
"figsize": draw_space_kwargs.pop("figsize", (8, 8)),
"dpi": draw_space_kwargs.pop("dpi", 100),
}

if ax is None:
Expand All @@ -290,7 +290,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs):
"linewidth": 1,
"alpha": 0.8,
}
line_kwargs.update(space_kwargs)
line_kwargs.update(draw_space_kwargs)

ax.set_xlim(self.viz_xmin, self.viz_xmax)
ax.set_ylim(self.viz_ymin, self.viz_ymax)
Expand All @@ -300,13 +300,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs):
ax.add_collection(LineCollection(list(edges), **line_kwargs))
return ax

def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs):
"""Draw the hexagonal grid using Altair.

Args:
chart_width: Width for the shown chart
chart_height: Height for the shown chart
**chart_kwargs: Additional keyword arguments for styling the chart.
**draw_chart_kwargs: Additional keyword arguments for styling the chart.

Examples:
* Line properties like color, strokeDash, strokeWidth, opacity.
Expand All @@ -316,17 +316,17 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
Altair chart object representing the hexagonal grid.
"""
mark_kwargs = {
"color": chart_kwargs.pop("color", "black"),
"strokeDash": chart_kwargs.pop("strokeDash", [2, 2]),
"strokeWidth": chart_kwargs.pop("strokeWidth", 1),
"opacity": chart_kwargs.pop("opacity", 0.8),
"color": draw_chart_kwargs.pop("color", "black"),
"strokeDash": draw_chart_kwargs.pop("strokeDash", [2, 2]),
"strokeWidth": draw_chart_kwargs.pop("strokeWidth", 1),
"opacity": draw_chart_kwargs.pop("opacity", 0.8),
}

chart_props = {
"width": chart_width,
"height": chart_height,
}
chart_props.update(chart_kwargs)
chart_props.update(draw_chart_kwargs)

edge_data = []
edges = self._get_unique_edges()
Expand Down Expand Up @@ -400,12 +400,12 @@ def __init__(
self.viz_ymin = ymin - height / 20
self.viz_ymax = ymax + height / 20

def draw_matplotlib(self, ax=None, **space_kwargs):
def draw_matplotlib(self, ax=None, **draw_space_kwargs):
"""Draw the network using matplotlib.

Args:
ax: Matplotlib axes object to draw on.
**space_kwargs: Dictionaries of keyword arguments for styling.
**draw_space_kwargs: Dictionaries of keyword arguments for styling.
Can also handle zorder for both nodes and edges if passed.
* ``node_kwargs``: A dict passed to nx.draw_networkx_nodes.
* ``edge_kwargs``: A dict passed to nx.draw_networkx_edges.
Expand All @@ -423,8 +423,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs):
node_kwargs = {"alpha": 0.5}
edge_kwargs = {"alpha": 0.5, "style": "--"}

node_kwargs.update(space_kwargs.get("node_kwargs", {}))
edge_kwargs.update(space_kwargs.get("edge_kwargs", {}))
node_kwargs.update(draw_space_kwargs.get("node_kwargs", {}))
edge_kwargs.update(draw_space_kwargs.get("edge_kwargs", {}))

node_zorder = node_kwargs.pop("zorder", 1)
edge_zorder = edge_kwargs.pop("zorder", 0)
Expand All @@ -443,13 +443,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs):

return ax

def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs):
"""Draw the network using Altair.

Args:
chart_width: Width for the shown chart
chart_height: Height for the shown chart
**chart_kwargs: Dictionaries for styling the chart.
**draw_chart_kwargs: Dictionaries for styling the chart.
* ``node_kwargs``: A dict of properties for the node's mark_point.
* ``edge_kwargs``: A dict of properties for the edge's mark_rule.
* Other kwargs (e.g., title, width) are passed to chart.properties().
Expand All @@ -474,14 +474,14 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
node_mark_kwargs = {"filled": True, "opacity": 0.5, "size": 500}
edge_mark_kwargs = {"opacity": 0.5, "strokeDash": [5, 3]}

node_mark_kwargs.update(chart_kwargs.pop("node_kwargs", {}))
edge_mark_kwargs.update(chart_kwargs.pop("edge_kwargs", {}))
node_mark_kwargs.update(draw_chart_kwargs.pop("node_kwargs", {}))
edge_mark_kwargs.update(draw_chart_kwargs.pop("edge_kwargs", {}))

chart_kwargs = {
draw_chart_kwargs = {
"width": chart_width,
"height": chart_height,
}
chart_kwargs.update(chart_kwargs)
draw_chart_kwargs.update(draw_chart_kwargs)

edge_plot = (
alt.Chart(edge_positions)
Expand Down Expand Up @@ -510,8 +510,8 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):

chart = edge_plot + node_plot

if chart_kwargs:
chart = chart.properties(**chart_kwargs)
if draw_chart_kwargs:
chart = chart.properties(**draw_chart_kwargs)

return chart

Expand Down Expand Up @@ -540,12 +540,12 @@ def __init__(self, space: ContinuousSpace):
self.viz_ymin = self.space.y_min - y_padding
self.viz_ymax = self.space.y_max + y_padding

def draw_matplotlib(self, ax=None, **space_kwargs):
def draw_matplotlib(self, ax=None, **draw_space_kwargs):
"""Draw the continuous space using matplotlib.

Args:
ax: Matplotlib axes object to draw on
**space_kwargs: Keyword arguments for styling the axis frame.
**draw_space_kwargs: Keyword arguments for styling the axis frame.

Examples:
linewidth=3, color="green"
Expand All @@ -558,7 +558,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs):

border_style = "solid" if not self.space.torus else (0, (5, 10))
spine_kwargs = {"linewidth": 1.5, "color": "black", "linestyle": border_style}
spine_kwargs.update(space_kwargs)
spine_kwargs.update(draw_space_kwargs)

for spine in ax.spines.values():
spine.set(**spine_kwargs)
Expand All @@ -568,20 +568,20 @@ def draw_matplotlib(self, ax=None, **space_kwargs):

return ax

def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs):
"""Draw the continuous space using Altair.

Args:
chart_width: Width for the shown chart
chart_height: Height for the shown chart
**chart_kwargs: Keyword arguments for styling the chart's view properties.
**draw_chart_kwargs: Keyword arguments for styling the chart's view properties.
See Altair's documentation for `configure_view`.

Returns:
An Altair Chart object representing the space.
"""
chart_props = {"width": chart_width, "height": chart_height}
chart_props.update(chart_kwargs)
chart_props.update(draw_chart_kwargs)

chart = (
alt.Chart(pd.DataFrame([{}]))
Expand Down Expand Up @@ -712,12 +712,12 @@ def _get_clipped_segments(self):

return final_segments, clip_box

def draw_matplotlib(self, ax=None, **space_kwargs):
def draw_matplotlib(self, ax=None, **draw_space_kwargs):
"""Draw the Voronoi diagram using matplotlib.

Args:
ax: Matplotlib axes object to draw on
**space_kwargs: Keyword arguments passed to matplotlib's LineCollection.
**draw_space_kwargs: Keyword arguments passed to matplotlib's LineCollection.

Examples:
lw=2, alpha=0.5, colors='red'
Expand All @@ -736,21 +736,21 @@ def draw_matplotlib(self, ax=None, **space_kwargs):
if final_segments:
# Define default styles for the plot
style_args = {"colors": "k", "linestyle": "dotted", "lw": 1}
style_args.update(space_kwargs)
style_args.update(draw_space_kwargs)

# Create the LineCollection with the final styles
lc = LineCollection(final_segments, **style_args)
ax.add_collection(lc)

return ax

def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):
def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs):
"""Draw the Voronoi diagram using Altair.

Args:
chart_width: Width for the shown chart
chart_height: Height for the shown chart
**chart_kwargs: Additional keyword arguments for styling the chart.
**draw_chart_kwargs: Additional keyword arguments for styling the chart.

Examples:
* Line properties like color, strokeDash, strokeWidth, opacity.
Expand All @@ -771,14 +771,14 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs):

# Define default properties for the mark
mark_kwargs = {
"color": chart_kwargs.pop("color", "black"),
"strokeDash": chart_kwargs.pop("strokeDash", [2, 2]),
"strokeWidth": chart_kwargs.pop("strokeWidth", 1),
"opacity": chart_kwargs.pop("opacity", 0.8),
"color": draw_chart_kwargs.pop("color", "black"),
"strokeDash": draw_chart_kwargs.pop("strokeDash", [2, 2]),
"strokeWidth": draw_chart_kwargs.pop("strokeWidth", 1),
"opacity": draw_chart_kwargs.pop("opacity", 0.8),
}

chart_props = {"width": chart_width, "height": chart_height}
chart_props.update(chart_kwargs)
chart_props.update(draw_chart_kwargs)

chart = (
alt.Chart(df)
Expand Down
Loading
Loading