|
1 | 1 | import math |
2 | 2 |
|
3 | 3 | import solara |
4 | | -from matplotlib.figure import Figure |
5 | | -from matplotlib.ticker import MaxNLocator |
6 | 4 |
|
7 | 5 | from mesa.examples.basic.virus_on_network.model import ( |
8 | 6 | State, |
9 | 7 | VirusOnNetwork, |
10 | 8 | number_infected, |
11 | 9 | ) |
12 | | -from mesa.visualization import Slider, SolaraViz, make_space_component |
13 | | - |
| 10 | +from mesa.visualization import ( |
| 11 | + Slider, |
| 12 | + SolaraViz, |
| 13 | + make_plot_measure, |
| 14 | + make_space_component, |
| 15 | +) |
14 | 16 |
|
15 | | -def agent_portrayal(graph): |
16 | | - def get_agent(node): |
17 | | - return graph.nodes[node]["agent"][0] |
18 | 17 |
|
19 | | - edge_width = [] |
20 | | - edge_color = [] |
21 | | - for u, v in graph.edges(): |
22 | | - agent1 = get_agent(u) |
23 | | - agent2 = get_agent(v) |
24 | | - w = 2 |
25 | | - ec = "#e8e8e8" |
26 | | - if State.RESISTANT in (agent1.state, agent2.state): |
27 | | - w = 3 |
28 | | - ec = "black" |
29 | | - edge_width.append(w) |
30 | | - edge_color.append(ec) |
| 18 | +def agent_portrayal(agent): |
31 | 19 | node_color_dict = { |
32 | 20 | State.INFECTED: "tab:red", |
33 | 21 | State.SUSCEPTIBLE: "tab:green", |
34 | 22 | State.RESISTANT: "tab:gray", |
35 | 23 | } |
36 | | - node_color = [node_color_dict[get_agent(node).state] for node in graph.nodes()] |
37 | | - return { |
38 | | - "width": edge_width, |
39 | | - "edge_color": edge_color, |
40 | | - "node_color": node_color, |
41 | | - } |
| 24 | + return {"color": node_color_dict[agent.state], "size": 10} |
42 | 25 |
|
43 | 26 |
|
44 | 27 | def get_resistant_susceptible_ratio(model): |
45 | 28 | ratio = model.resistant_susceptible_ratio() |
46 | 29 | ratio_text = r"$\infty$" if ratio is math.inf else f"{ratio:.2f}" |
47 | 30 | infected_text = str(number_infected(model)) |
48 | 31 |
|
49 | | - return f"Resistant/Susceptible Ratio: {ratio_text}<br>Infected Remaining: {infected_text}" |
50 | | - |
51 | | - |
52 | | -def make_plot(model): |
53 | | - # This is for the case when we want to plot multiple measures in 1 figure. |
54 | | - fig = Figure() |
55 | | - ax = fig.subplots() |
56 | | - measures = ["Infected", "Susceptible", "Resistant"] |
57 | | - colors = ["tab:red", "tab:green", "tab:gray"] |
58 | | - for i, m in enumerate(measures): |
59 | | - color = colors[i] |
60 | | - df = model.datacollector.get_model_vars_dataframe() |
61 | | - ax.plot(df.loc[:, m], label=m, color=color) |
62 | | - fig.legend() |
63 | | - # Set integer x axis |
64 | | - ax.xaxis.set_major_locator(MaxNLocator(integer=True)) |
65 | | - ax.set_xlabel("Step") |
66 | | - ax.set_ylabel("Number of Agents") |
67 | | - return solara.FigureMatplotlib(fig) |
| 32 | + return solara.Markdown( |
| 33 | + f"Resistant/Susceptible Ratio: {ratio_text}<br>Infected Remaining: {infected_text}" |
| 34 | + ) |
68 | 35 |
|
69 | 36 |
|
70 | 37 | model_params = { |
@@ -120,15 +87,18 @@ def make_plot(model): |
120 | 87 | } |
121 | 88 |
|
122 | 89 | SpacePlot = make_space_component(agent_portrayal) |
| 90 | +StatePlot = make_plot_measure( |
| 91 | + {"Infected": "tab:red", "Susceptible": "tab:green", "Resistant": "tab:gray"} |
| 92 | +) |
123 | 93 |
|
124 | 94 | model1 = VirusOnNetwork() |
125 | 95 |
|
126 | 96 | page = SolaraViz( |
127 | 97 | model1, |
128 | 98 | [ |
129 | 99 | SpacePlot, |
130 | | - make_plot, |
131 | | - # get_resistant_susceptible_ratio, # TODO: Fix and uncomment |
| 100 | + StatePlot, |
| 101 | + get_resistant_susceptible_ratio, |
132 | 102 | ], |
133 | 103 | model_params=model_params, |
134 | 104 | name="Virus Model", |
|
0 commit comments