Skip to content

Commit 1a29aa4

Browse files
authored
expand ax.scatter kwargs that can be used (#2445)
1 parent 217cb58 commit 1a29aa4

File tree

3 files changed

+45
-29
lines changed

3 files changed

+45
-29
lines changed

docs/overview.md

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -177,20 +177,20 @@ def agent_portrayal(agent):
177177

178178
model_params = {
179179
"N": {
180-
"type": "SliderInt",
181-
"value": 50,
182-
"label": "Number of agents:",
183-
"min": 10,
184-
"max": 100,
185-
"step": 1,
180+
"type": "SliderInt",
181+
"value": 50,
182+
"label": "Number of agents:",
183+
"min": 10,
184+
"max": 100,
185+
"step": 1,
186186
}
187187
}
188188

189189
page = SolaraViz(
190190
MyModel,
191191
[
192-
make_space_component(agent_portrayal),
193-
make_plot_component("mean_age")
192+
make_space_component(agent_portrayal),
193+
make_plot_component("mean_age")
194194
],
195195
model_params=model_params
196196
)

mesa/examples/basic/boltzmann_wealth_model/app.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,10 @@
11
from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealthModel
2-
from mesa.visualization import (
3-
SolaraViz,
4-
make_plot_component,
5-
make_space_component,
6-
)
2+
from mesa.visualization import SolaraViz, make_plot_component, make_space_component
73

84

95
def agent_portrayal(agent):
10-
size = 10
11-
color = "tab:red"
12-
if agent.wealth > 0:
13-
size = 50
14-
color = "tab:blue"
15-
return {"size": size, "color": color}
6+
color = agent.wealth # we are using a colormap to translate wealth to color
7+
return {"color": color}
168

179

1810
model_params = {
@@ -28,6 +20,11 @@ def agent_portrayal(agent):
2820
"height": 10,
2921
}
3022

23+
24+
def post_process(ax):
25+
ax.get_figure().colorbar(ax.collections[0], label="wealth", ax=ax)
26+
27+
3128
# Create initial model instance
3229
model1 = BoltzmannWealthModel(50, 10, 10)
3330

@@ -36,7 +33,10 @@ def agent_portrayal(agent):
3633
# Under the hood these are just classes that receive the model instance.
3734
# You can also author your own visualization elements, which can also be functions
3835
# that receive the model instance and return a valid solara component.
39-
SpaceGraph = make_space_component(agent_portrayal)
36+
37+
SpaceGraph = make_space_component(
38+
agent_portrayal, cmap="viridis", vmin=0, vmax=10, post_process=post_process
39+
)
4040
GiniPlot = make_plot_component("Gini")
4141

4242
# Create the SolaraViz page. This will automatically create a server and display the

mesa/visualization/components/matplotlib.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def draw_orthogonal_grid(
309309
agent_portrayal: Callable,
310310
ax: Axes | None = None,
311311
draw_grid: bool = True,
312+
**kwargs,
312313
):
313314
"""Visualize a orthogonal grid.
314315
@@ -317,6 +318,7 @@ def draw_orthogonal_grid(
317318
agent_portrayal: a callable that is called with the agent and returns a dict
318319
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
319320
draw_grid: whether to draw the grid
321+
kwargs: additional keyword arguments passed to ax.scatter
320322
321323
Returns:
322324
Returns the Axes object with the plot drawn onto it.
@@ -333,7 +335,7 @@ def draw_orthogonal_grid(
333335
arguments = collect_agent_data(space, agent_portrayal, size=s_default)
334336

335337
# plot the agents
336-
_scatter(ax, arguments)
338+
_scatter(ax, arguments, **kwargs)
337339

338340
# further styling
339341
ax.set_xlim(-0.5, space.width - 0.5)
@@ -354,6 +356,7 @@ def draw_hex_grid(
354356
agent_portrayal: Callable,
355357
ax: Axes | None = None,
356358
draw_grid: bool = True,
359+
**kwargs,
357360
):
358361
"""Visualize a hex grid.
359362
@@ -362,6 +365,7 @@ def draw_hex_grid(
362365
agent_portrayal: a callable that is called with the agent and returns a dict
363366
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
364367
draw_grid: whether to draw the grid
368+
kwargs: additional keyword arguments passed to ax.scatter
365369
366370
Returns:
367371
Returns the Axes object with the plot drawn onto it.
@@ -394,7 +398,7 @@ def draw_hex_grid(
394398
arguments["loc"] = loc
395399

396400
# plot the agents
397-
_scatter(ax, arguments)
401+
_scatter(ax, arguments, **kwargs)
398402

399403
# further styling and adding of grid
400404
ax.set_xlim(-1, space.width + 0.5)
@@ -443,6 +447,7 @@ def draw_network(
443447
draw_grid: bool = True,
444448
layout_alg=nx.spring_layout,
445449
layout_kwargs=None,
450+
**kwargs,
446451
):
447452
"""Visualize a network space.
448453
@@ -453,6 +458,7 @@ def draw_network(
453458
draw_grid: whether to draw the grid
454459
layout_alg: a networkx layout algorithm or other callable with the same behavior
455460
layout_kwargs: a dictionary of keyword arguments for the layout algorithm
461+
kwargs: additional keyword arguments passed to ax.scatter
456462
457463
Returns:
458464
Returns the Axes object with the plot drawn onto it.
@@ -488,7 +494,7 @@ def draw_network(
488494
arguments["loc"] = pos[arguments["loc"]]
489495

490496
# plot the agents
491-
_scatter(ax, arguments)
497+
_scatter(ax, arguments, **kwargs)
492498

493499
# further styling
494500
ax.set_axis_off()
@@ -506,14 +512,15 @@ def draw_network(
506512

507513

508514
def draw_continuous_space(
509-
space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None
515+
space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
510516
):
511517
"""Visualize a continuous space.
512518
513519
Args:
514520
space: the space to visualize
515521
agent_portrayal: a callable that is called with the agent and returns a dict
516522
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
523+
kwargs: additional keyword arguments passed to ax.scatter
517524
518525
Returns:
519526
Returns the Axes object with the plot drawn onto it.
@@ -536,7 +543,7 @@ def draw_continuous_space(
536543
arguments = collect_agent_data(space, agent_portrayal, size=s_default)
537544

538545
# plot the agents
539-
_scatter(ax, arguments)
546+
_scatter(ax, arguments, **kwargs)
540547

541548
# further visual styling
542549
border_style = "solid" if not space.torus else (0, (5, 10))
@@ -552,14 +559,15 @@ def draw_continuous_space(
552559

553560

554561
def draw_voroinoi_grid(
555-
space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None
562+
space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
556563
):
557564
"""Visualize a voronoi grid.
558565
559566
Args:
560567
space: the space to visualize
561568
agent_portrayal: a callable that is called with the agent and returns a dict
562569
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
570+
kwargs: additional keyword arguments passed to ax.scatter
563571
564572
Returns:
565573
Returns the Axes object with the plot drawn onto it.
@@ -589,7 +597,7 @@ def draw_voroinoi_grid(
589597
ax.set_xlim(x_min - x_padding, x_max + x_padding)
590598
ax.set_ylim(y_min - y_padding, y_max + y_padding)
591599

592-
_scatter(ax, arguments)
600+
_scatter(ax, arguments, **kwargs)
593601

594602
for cell in space.all_cells:
595603
polygon = cell.properties["polygon"]
@@ -604,8 +612,15 @@ def draw_voroinoi_grid(
604612
return ax
605613

606614

607-
def _scatter(ax: Axes, arguments):
608-
"""Helper function for plotting the agents."""
615+
def _scatter(ax: Axes, arguments, **kwargs):
616+
"""Helper function for plotting the agents.
617+
618+
Args:
619+
ax: a Matplotlib Axes instance
620+
arguments: the agents specific arguments for platting
621+
kwargs: additional keyword arguments for ax.scatter
622+
623+
"""
609624
loc = arguments.pop("loc")
610625

611626
x = loc[:, 0]
@@ -624,6 +639,7 @@ def _scatter(ax: Axes, arguments):
624639
marker=mark,
625640
zorder=z_order,
626641
**{k: v[logical] for k, v in arguments.items()},
642+
**kwargs,
627643
)
628644

629645

0 commit comments

Comments
 (0)