diff --git a/mesa/visualization/solara_viz.py b/mesa/visualization/solara_viz.py index 3bd33e5b432..87b33ea2a9f 100644 --- a/mesa/visualization/solara_viz.py +++ b/mesa/visualization/solara_viz.py @@ -269,19 +269,12 @@ def SpaceRendererComponent( for artist in itertools.chain.from_iterable(all_artists): artist.remove() - # Draw the space structure if specified if renderer.space_mesh: - renderer.draw_structure(**renderer.space_kwargs) - - # Draw agents if specified + renderer.draw_structure() if renderer.agent_mesh: - renderer.draw_agents( - agent_portrayal=renderer.agent_portrayal, **renderer.agent_kwargs - ) - - # Draw property layers if specified + renderer.draw_agents() if renderer.propertylayer_mesh: - renderer.draw_propertylayer(renderer.propertylayer_portrayal) + renderer.draw_propertylayer() # Update the fig every time frame if dependencies: @@ -306,15 +299,11 @@ def SpaceRendererComponent( propertylayer = renderer.propertylayer_mesh or None if renderer.space_mesh: - structure = renderer.draw_structure(**renderer.space_kwargs) + structure = renderer.draw_structure() if renderer.agent_mesh: - agents = renderer.draw_agents( - renderer.agent_portrayal, **renderer.agent_kwargs - ) + agents = renderer.draw_agents() if renderer.propertylayer_mesh: - propertylayer = renderer.draw_propertylayer( - renderer.propertylayer_portrayal - ) + propertylayer = renderer.draw_propertylayer() spatial_charts_list = [ chart for chart in [structure, propertylayer, agents] if chart diff --git a/mesa/visualization/space_drawers.py b/mesa/visualization/space_drawers.py index da5a4079533..e17a420f2b5 100644 --- a/mesa/visualization/space_drawers.py +++ b/mesa/visualization/space_drawers.py @@ -82,12 +82,12 @@ def __init__(self, space: OrthogonalGrid): self.viz_ymin = -0.5 self.viz_ymax = self.space.height - 0.5 - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the orthogonal grid using matplotlib. Args: ax: Matplotlib axes object to draw on - **space_kwargs: Additional keyword arguments for styling. + **draw_space_kwargs: Additional keyword arguments for styling. Examples: figsize=(10, 10), color="blue", linewidth=2. @@ -96,8 +96,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs): The modified axes object """ fig_kwargs = { - "figsize": space_kwargs.pop("figsize", (8, 8)), - "dpi": space_kwargs.pop("dpi", 100), + "figsize": draw_space_kwargs.pop("figsize", (8, 8)), + "dpi": draw_space_kwargs.pop("dpi", 100), } if ax is None: @@ -110,7 +110,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs): "linewidth": 1, "alpha": 1, } - line_kwargs.update(space_kwargs) + line_kwargs.update(draw_space_kwargs) ax.set_xlim(self.viz_xmin, self.viz_xmax) ax.set_ylim(self.viz_ymin, self.viz_ymax) @@ -123,13 +123,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs): return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the orthogonal grid using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Additional keyword arguments for styling the chart. + **draw_chart_kwargs: Additional keyword arguments for styling the chart. Examples: width=500, height=500, title="Grid". @@ -139,12 +139,12 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): """ # for axis and grid styling axis_kwargs = { - "xlabel": chart_kwargs.pop("xlabel", "X"), - "ylabel": chart_kwargs.pop("ylabel", "Y"), - "grid_color": chart_kwargs.pop("grid_color", "lightgray"), - "grid_dash": chart_kwargs.pop("grid_dash", [2, 2]), - "grid_width": chart_kwargs.pop("grid_width", 1), - "grid_opacity": chart_kwargs.pop("grid_opacity", 1), + "xlabel": draw_chart_kwargs.pop("xlabel", "X"), + "ylabel": draw_chart_kwargs.pop("ylabel", "Y"), + "grid_color": draw_chart_kwargs.pop("grid_color", "lightgray"), + "grid_dash": draw_chart_kwargs.pop("grid_dash", [2, 2]), + "grid_width": draw_chart_kwargs.pop("grid_width", 1), + "grid_opacity": draw_chart_kwargs.pop("grid_opacity", 1), } # for chart properties @@ -152,7 +152,7 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): "width": chart_width, "height": chart_height, } - chart_props.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) chart = ( alt.Chart(pd.DataFrame([{}])) @@ -263,12 +263,12 @@ def _get_unique_edges(self): edges.add(edge) return edges - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the hexagonal grid using matplotlib. Args: ax: Matplotlib axes object to draw on - **space_kwargs: Additional keyword arguments for styling. + **draw_space_kwargs: Additional keyword arguments for styling. Examples: figsize=(8, 8), color="red", alpha=0.5. @@ -277,8 +277,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs): The modified axes object """ fig_kwargs = { - "figsize": space_kwargs.pop("figsize", (8, 8)), - "dpi": space_kwargs.pop("dpi", 100), + "figsize": draw_space_kwargs.pop("figsize", (8, 8)), + "dpi": draw_space_kwargs.pop("dpi", 100), } if ax is None: @@ -290,7 +290,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs): "linewidth": 1, "alpha": 0.8, } - line_kwargs.update(space_kwargs) + line_kwargs.update(draw_space_kwargs) ax.set_xlim(self.viz_xmin, self.viz_xmax) ax.set_ylim(self.viz_ymin, self.viz_ymax) @@ -300,13 +300,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs): ax.add_collection(LineCollection(list(edges), **line_kwargs)) return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the hexagonal grid using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Additional keyword arguments for styling the chart. + **draw_chart_kwargs: Additional keyword arguments for styling the chart. Examples: * Line properties like color, strokeDash, strokeWidth, opacity. @@ -316,17 +316,17 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): Altair chart object representing the hexagonal grid. """ mark_kwargs = { - "color": chart_kwargs.pop("color", "black"), - "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]), - "strokeWidth": chart_kwargs.pop("strokeWidth", 1), - "opacity": chart_kwargs.pop("opacity", 0.8), + "color": draw_chart_kwargs.pop("color", "black"), + "strokeDash": draw_chart_kwargs.pop("strokeDash", [2, 2]), + "strokeWidth": draw_chart_kwargs.pop("strokeWidth", 1), + "opacity": draw_chart_kwargs.pop("opacity", 0.8), } chart_props = { "width": chart_width, "height": chart_height, } - chart_props.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) edge_data = [] edges = self._get_unique_edges() @@ -400,12 +400,12 @@ def __init__( self.viz_ymin = ymin - height / 20 self.viz_ymax = ymax + height / 20 - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the network using matplotlib. Args: ax: Matplotlib axes object to draw on. - **space_kwargs: Dictionaries of keyword arguments for styling. + **draw_space_kwargs: Dictionaries of keyword arguments for styling. Can also handle zorder for both nodes and edges if passed. * ``node_kwargs``: A dict passed to nx.draw_networkx_nodes. * ``edge_kwargs``: A dict passed to nx.draw_networkx_edges. @@ -423,8 +423,8 @@ def draw_matplotlib(self, ax=None, **space_kwargs): node_kwargs = {"alpha": 0.5} edge_kwargs = {"alpha": 0.5, "style": "--"} - node_kwargs.update(space_kwargs.get("node_kwargs", {})) - edge_kwargs.update(space_kwargs.get("edge_kwargs", {})) + node_kwargs.update(draw_space_kwargs.get("node_kwargs", {})) + edge_kwargs.update(draw_space_kwargs.get("edge_kwargs", {})) node_zorder = node_kwargs.pop("zorder", 1) edge_zorder = edge_kwargs.pop("zorder", 0) @@ -443,13 +443,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs): return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the network using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Dictionaries for styling the chart. + **draw_chart_kwargs: Dictionaries for styling the chart. * ``node_kwargs``: A dict of properties for the node's mark_point. * ``edge_kwargs``: A dict of properties for the edge's mark_rule. * Other kwargs (e.g., title, width) are passed to chart.properties(). @@ -474,14 +474,14 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): node_mark_kwargs = {"filled": True, "opacity": 0.5, "size": 500} edge_mark_kwargs = {"opacity": 0.5, "strokeDash": [5, 3]} - node_mark_kwargs.update(chart_kwargs.pop("node_kwargs", {})) - edge_mark_kwargs.update(chart_kwargs.pop("edge_kwargs", {})) + node_mark_kwargs.update(draw_chart_kwargs.pop("node_kwargs", {})) + edge_mark_kwargs.update(draw_chart_kwargs.pop("edge_kwargs", {})) - chart_kwargs = { + chart_props = { "width": chart_width, "height": chart_height, } - chart_kwargs.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) edge_plot = ( alt.Chart(edge_positions) @@ -510,8 +510,8 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): chart = edge_plot + node_plot - if chart_kwargs: - chart = chart.properties(**chart_kwargs) + if chart_props: + chart = chart.properties(**chart_props) return chart @@ -540,12 +540,12 @@ def __init__(self, space: ContinuousSpace): self.viz_ymin = self.space.y_min - y_padding self.viz_ymax = self.space.y_max + y_padding - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the continuous space using matplotlib. Args: ax: Matplotlib axes object to draw on - **space_kwargs: Keyword arguments for styling the axis frame. + **draw_space_kwargs: Keyword arguments for styling the axis frame. Examples: linewidth=3, color="green" @@ -558,7 +558,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs): border_style = "solid" if not self.space.torus else (0, (5, 10)) spine_kwargs = {"linewidth": 1.5, "color": "black", "linestyle": border_style} - spine_kwargs.update(space_kwargs) + spine_kwargs.update(draw_space_kwargs) for spine in ax.spines.values(): spine.set(**spine_kwargs) @@ -568,20 +568,20 @@ def draw_matplotlib(self, ax=None, **space_kwargs): return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the continuous space using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Keyword arguments for styling the chart's view properties. + **draw_chart_kwargs: Keyword arguments for styling the chart's view properties. See Altair's documentation for `configure_view`. Returns: An Altair Chart object representing the space. """ chart_props = {"width": chart_width, "height": chart_height} - chart_props.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) chart = ( alt.Chart(pd.DataFrame([{}])) @@ -712,12 +712,12 @@ def _get_clipped_segments(self): return final_segments, clip_box - def draw_matplotlib(self, ax=None, **space_kwargs): + def draw_matplotlib(self, ax=None, **draw_space_kwargs): """Draw the Voronoi diagram using matplotlib. Args: ax: Matplotlib axes object to draw on - **space_kwargs: Keyword arguments passed to matplotlib's LineCollection. + **draw_space_kwargs: Keyword arguments passed to matplotlib's LineCollection. Examples: lw=2, alpha=0.5, colors='red' @@ -736,7 +736,7 @@ def draw_matplotlib(self, ax=None, **space_kwargs): if final_segments: # Define default styles for the plot style_args = {"colors": "k", "linestyle": "dotted", "lw": 1} - style_args.update(space_kwargs) + style_args.update(draw_space_kwargs) # Create the LineCollection with the final styles lc = LineCollection(final_segments, **style_args) @@ -744,13 +744,13 @@ def draw_matplotlib(self, ax=None, **space_kwargs): return ax - def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): + def draw_altair(self, chart_width=450, chart_height=350, **draw_chart_kwargs): """Draw the Voronoi diagram using Altair. Args: chart_width: Width for the shown chart chart_height: Height for the shown chart - **chart_kwargs: Additional keyword arguments for styling the chart. + **draw_chart_kwargs: Additional keyword arguments for styling the chart. Examples: * Line properties like color, strokeDash, strokeWidth, opacity. @@ -771,14 +771,14 @@ def draw_altair(self, chart_width=450, chart_height=350, **chart_kwargs): # Define default properties for the mark mark_kwargs = { - "color": chart_kwargs.pop("color", "black"), - "strokeDash": chart_kwargs.pop("strokeDash", [2, 2]), - "strokeWidth": chart_kwargs.pop("strokeWidth", 1), - "opacity": chart_kwargs.pop("opacity", 0.8), + "color": draw_chart_kwargs.pop("color", "black"), + "strokeDash": draw_chart_kwargs.pop("strokeDash", [2, 2]), + "strokeWidth": draw_chart_kwargs.pop("strokeWidth", 1), + "opacity": draw_chart_kwargs.pop("opacity", 0.8), } chart_props = {"width": chart_width, "height": chart_height} - chart_props.update(chart_kwargs) + chart_props.update(draw_chart_kwargs) chart = ( alt.Chart(df) diff --git a/mesa/visualization/space_renderer.py b/mesa/visualization/space_renderer.py index a517a70b9da..090c3fd6627 100644 --- a/mesa/visualization/space_renderer.py +++ b/mesa/visualization/space_renderer.py @@ -4,6 +4,8 @@ backends, supporting various space types and visualization components. """ +from __future__ import annotations + import contextlib import warnings from collections.abc import Callable @@ -63,10 +65,17 @@ def __init__( self.space = getattr(model, "grid", getattr(model, "space", None)) self.space_drawer = self._get_space_drawer() + self.space_mesh = None self.agent_mesh = None self.propertylayer_mesh = None + self.draw_agent_kwargs = {} + self.draw_space_kwargs = {} + + self.agent_portrayal = None + self.propertylayer_portrayal = None + self.post_process_func = None # Keep track of whether post-processing has been applied # to avoid multiple applications on the same axis. @@ -161,55 +170,81 @@ def _map_coordinates(self, arguments): return mapped_arguments - def draw_structure(self, **kwargs): - """Draw the space structure. + def setup_structure(self, **kwargs) -> SpaceRenderer: + """Setup the space structure without drawing. Args: - **kwargs: Additional keyword arguments for the drawing function. + **kwargs: Additional keyword arguments for the setup function. Checkout respective `SpaceDrawer` class on details how to pass **kwargs. Returns: - The visual representation of the space structure. + SpaceRenderer: The current instance for method chaining. """ - # Store space_kwargs for internal use - self.space_kwargs = kwargs + self.draw_space_kwargs = kwargs - self.space_mesh = self.backend_renderer.draw_structure(**self.space_kwargs) - return self.space_mesh + return self - def draw_agents(self, agent_portrayal: Callable, **kwargs): - """Draw agents on the space. + def setup_agents(self, agent_portrayal: Callable, **kwargs) -> SpaceRenderer: + """Setup agents on the space without drawing. Args: agent_portrayal (Callable): Function that takes an agent and returns AgentPortrayalStyle. - **kwargs: Additional keyword arguments for the drawing function. + **kwargs: Additional keyword arguments for the setup function. Checkout respective `SpaceDrawer` class on details how to pass **kwargs. Returns: - The visual representation of the agents. + SpaceRenderer: The current instance for method chaining. """ - # Store data for internal use self.agent_portrayal = agent_portrayal - self.agent_kwargs = kwargs + self.draw_agent_kwargs = kwargs + + return self + def setup_propertylayer( + self, propertylayer_portrayal: Callable | dict + ) -> SpaceRenderer: + """Setup property layers on the space without drawing. + + Args: + propertylayer_portrayal (Callable | dict): Function that returns PropertyLayerStyle + or dict with portrayal parameters. + + Returns: + SpaceRenderer: The current instance for method chaining. + """ + self.propertylayer_portrayal = propertylayer_portrayal + + return self + + def draw_structure(self): + """Draw the space structure. + + Returns: + The visual representation of the space structure. + """ + self.space_mesh = self.backend_renderer.draw_structure(**self.draw_space_kwargs) + return self.space_mesh + + def draw_agents(self): + """Draw agents on the space. + + Returns: + The visual representation of the agents. + """ # Prepare data for agent plotting arguments = self.backend_renderer.collect_agent_data( - self.space, agent_portrayal, default_size=self.space_drawer.s_default + self.space, self.agent_portrayal, default_size=self.space_drawer.s_default ) arguments = self._map_coordinates(arguments) self.agent_mesh = self.backend_renderer.draw_agents( - arguments, **self.agent_kwargs + arguments, **self.draw_agent_kwargs ) return self.agent_mesh - def draw_propertylayer(self, propertylayer_portrayal: Callable | dict): + def draw_propertylayer(self): """Draw property layers on the space. - Args: - propertylayer_portrayal (Callable | dict): Function that returns PropertyLayerStyle - or dict with portrayal parameters. - Returns: The visual representation of the property layers. @@ -267,10 +302,10 @@ def style_callable(layer_object): property_layers = self.space._mesa_property_layers # Convert portrayal to callable if needed - if isinstance(propertylayer_portrayal, dict): - self.propertylayer_portrayal = _dict_to_callable(propertylayer_portrayal) - else: - self.propertylayer_portrayal = propertylayer_portrayal + if isinstance(self.propertylayer_portrayal, dict): + self.propertylayer_portrayal = _dict_to_callable( + self.propertylayer_portrayal + ) number_of_propertylayers = sum( [1 for layer in property_layers if layer != "empty"] @@ -283,41 +318,19 @@ def style_callable(layer_object): ) return self.propertylayer_mesh - def render( - self, - agent_portrayal: Callable | None = None, - propertylayer_portrayal: Callable | dict | None = None, - post_process: Callable | None = None, - **kwargs, - ): + def render(self): """Render the complete space with structure, agents, and property layers. It is an all-in-one method that draws everything required therefore eliminates - the need of calling each method separately, but has a drawback, if want to pass - kwargs to customize the drawing, they have to be broken into - space_kwargs and agent_kwargs. - - Args: - agent_portrayal (Callable | None): Function that returns AgentPortrayalStyle. - If None, agents won't be drawn. - propertylayer_portrayal (Callable | dict | None): Function that returns - PropertyLayerStyle or dict with portrayal parameters. If None, - property layers won't be drawn. - post_process (Callable | None): Function to apply post-processing to the canvas. - **kwargs: Additional keyword arguments for drawing functions. - * ``space_kwargs`` (dict): Arguments for ``draw_structure()``. - * ``agent_kwargs`` (dict): Arguments for ``draw_agents()``. + the need of calling each method separately. """ - space_kwargs = kwargs.pop("space_kwargs", {}) - agent_kwargs = kwargs.pop("agent_kwargs", {}) if self.space_mesh is None: - self.draw_structure(**space_kwargs) - if self.agent_mesh is None and agent_portrayal is not None: - self.draw_agents(agent_portrayal, **agent_kwargs) - if self.propertylayer_mesh is None and propertylayer_portrayal is not None: - self.draw_propertylayer(propertylayer_portrayal) + self.draw_structure() + if self.agent_mesh is None and self.agent_portrayal is not None: + self.draw_agents() + if self.propertylayer_mesh is None and self.propertylayer_portrayal is not None: + self.draw_propertylayer() - self.post_process_func = post_process return self @property @@ -339,13 +352,11 @@ def canvas(self): prop_base, prop_cbar = self.propertylayer_mesh or (None, None) if self.space_mesh: - structure = self.draw_structure(**self.space_kwargs) + structure = self.draw_structure() if self.agent_mesh: - agents = self.draw_agents(self.agent_portrayal, **self.agent_kwargs) + agents = self.draw_agents() if self.propertylayer_mesh: - prop_base, prop_cbar = self.draw_propertylayer( - self.propertylayer_portrayal - ) + prop_base, prop_cbar = self.draw_propertylayer() spatial_charts_list = [ chart for chart in [structure, prop_base, agents] if chart