Skip to content

Commit 498eefc

Browse files
committed
review
1 parent edd84be commit 498eefc

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

doc/examples/other_model_types.ipynb

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,18 @@
100100
"## Model selection\n",
101101
"\n",
102102
"To perform model selection, we define three methods to:\n",
103+
"\n",
103104
"1. calibrate a single model\n",
104105
"\n",
105-
"This is where we convert a PEtab Select model into a statsmodels model, fit the model with statsmodels, then save the log-likelihood value in the PEtab Select model.\n",
106+
" This is where we convert a PEtab Select model into a statsmodels model, fit the model with statsmodels, then save the log-likelihood value in the PEtab Select model.\n",
106107
"\n",
107108
"2. perform a single iteration of a model selection method involving calibration of multiple models\n",
108109
"\n",
109-
"This is generic code that executes the required PEtab Select commands.\n",
110+
" This is generic code that executes the required PEtab Select commands.\n",
110111
"\n",
111112
"3. perform a full model selection run, involving all iterations of a model selection method\n",
112113
"\n",
113-
"This loads the PEtab Select problem from disk and performs all iterations of a model selection method.\n",
114+
" This loads the PEtab Select problem from disk and performs all iterations of a model selection method.\n",
114115
"\n",
115116
"In this case, the PEtab Select problem is setup to perform backward selection. See the files in `other_model_types_problem/petab_select`."
116117
]
@@ -259,15 +260,18 @@
259260
"outputs": [],
260261
"source": [
261262
"# Plot the history of model selection iterations.\n",
263+
"import matplotlib.pyplot as plt\n",
262264
"\n",
263265
"draw_networkx_kwargs = {\n",
264266
" \"arrowstyle\": \"-|>\",\n",
265267
" \"node_shape\": \"s\",\n",
266-
" \"node_size\": 1000,\n",
267268
" \"edgecolors\": \"k\",\n",
268269
"}\n",
270+
"fig, ax = plt.subplots(figsize=(20, 20))\n",
269271
"petab_select.plot.graph_iteration_layers(\n",
270-
" plot_data=plot_data, draw_networkx_kwargs=draw_networkx_kwargs\n",
272+
" plot_data=plot_data,\n",
273+
" draw_networkx_kwargs=draw_networkx_kwargs,\n",
274+
" ax=ax,\n",
271275
");"
272276
]
273277
}

petab_select/plot.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
"""The font size of axis tick labels."""
2929
DEFAULT_NODE_COLOR = "darkgrey"
3030
"""The default color of nodes in graph plots."""
31+
FONT_HEIGHT_WIDTH_RATIO = 2
32+
"""The ratio of the font height to font width. Used for graph node sizes."""
33+
NODE_SIZE_LABEL_SIZE_RATIO = 500
34+
"""The ratio of node size to label size. Used for graph node sizes."""
3135

3236

3337
__all__ = [
@@ -510,6 +514,22 @@ def scatter_criterion_vs_n_estimated(
510514
return ax
511515

512516

517+
def labels_to_node_sizes(G: nx.Graph) -> list[float]:
518+
"""Compute reasonable node sizes to scale with node label sizes."""
519+
node_sizes = []
520+
for label in G.nodes:
521+
height = 1
522+
width = len(label)
523+
if "\n" in label:
524+
height = len(label.split("\n"))
525+
width = len(sorted(label.split("\n"), key=lambda x: len(x))[-1])
526+
label_size = (
527+
width if width * FONT_HEIGHT_WIDTH_RATIO > height else height
528+
)
529+
node_sizes.append(label_size * NODE_SIZE_LABEL_SIZE_RATIO)
530+
return node_sizes
531+
532+
513533
def graph_iteration_layers(
514534
plot_data: PlotData,
515535
ax: plt.Axes = None,
@@ -541,7 +561,6 @@ def graph_iteration_layers(
541561
default_draw_networkx_kwargs = {
542562
"arrowstyle": "-|>",
543563
"node_shape": "s",
544-
"node_size": 250,
545564
"edgecolors": "k",
546565
}
547566
if draw_networkx_kwargs is None:
@@ -610,6 +629,11 @@ def graph_iteration_layers(
610629
# Apply custom labels. Need `copy=True` to preserve node ordering.
611630
G = nx.relabel_nodes(G, mapping=plot_data.labels, copy=True)
612631

632+
draw_networkx_kwargs["node_size"] = draw_networkx_kwargs.get(
633+
"node_size",
634+
labels_to_node_sizes(G=G),
635+
)
636+
613637
nx.draw_networkx(
614638
G, pos, ax=ax, node_color=node_colors, **draw_networkx_kwargs
615639
)

0 commit comments

Comments
 (0)