Skip to content

Commit 94c0435

Browse files
authored
Merge pull request #358 from c-bata/feat/plot-beeswarm
Follow-up #357: Update example.py for `plot_beeswarm`.
2 parents 7118930 + c737a94 commit 94c0435

File tree

5 files changed

+407
-0
lines changed

5 files changed

+407
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2026 Yasunori Morishima
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
---
2+
author: Yasunori Morishima
3+
title: SHAP-like Beeswarm Plot
4+
description: A SHAP-style beeswarm plot that visualizes the relationship between hyperparameter values and objective function values across trials. Density-based jitter reveals the distribution of trials.
5+
tags: [visualization, beeswarm, hyperparameter importance]
6+
optuna_versions: [4.7.0]
7+
license: MIT License
8+
---
9+
10+
## Class or Function Names
11+
12+
- `plot_beeswarm(study, *, params=None, target=None, target_name="Objective Value", color_map="RdBu_r", ax=None)`
13+
14+
- `study`: An Optuna study with completed trials.
15+
- `params`: A list of parameter names to include. If `None`, all parameters across completed trials are used.
16+
- `target`: A callable that extracts a scalar value from a `FrozenTrial`. Defaults to `trial.value`.
17+
- `target_name`: Label for the x-axis. Defaults to `"Objective Value"`.
18+
- `color_map`: Matplotlib colormap name. Defaults to `"RdBu_r"` (blue for low, red for high).
19+
- `ax`: Matplotlib axes to draw on. If `None`, a new figure is created.
20+
- **Returns**: A tuple of `(figure, axes, colorbar)`.
21+
22+
## Example
23+
24+
```python
25+
import optuna
26+
import optunahub
27+
28+
mod = optunahub.load_module(package="visualization/plot_beeswarm")
29+
plot_beeswarm = mod.plot_beeswarm
30+
31+
32+
def objective(trial: optuna.trial.Trial) -> float:
33+
x = trial.suggest_float("x", 0.0, 10.0)
34+
y = trial.suggest_float("y", 0.0, 10.0)
35+
z = trial.suggest_float("z", 0.0, 10.0)
36+
w = trial.suggest_float("w", 0.0, 10.0)
37+
return 1.0 * x + 0.5 * y + 0.1 * z + 0.01 * w
38+
39+
40+
study = optuna.create_study()
41+
study.optimize(objective, n_trials=500)
42+
43+
fig, ax, cbar = plot_beeswarm(study)
44+
```
45+
46+
![Beeswarm Plot Example](images/beeswarm.png)
47+
48+
## How It Works
49+
50+
Each row in the plot represents a hyperparameter. Each dot is one trial:
51+
52+
- **X-axis**: Objective function value
53+
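- **Y-axis**: Parameter (rows), sorted by importance (most important at top). Importance comes from Optuna's parameter-importance evaluation when available; otherwise the absolute Spearman rank correlation with the objective is used.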
54+
- **Color**: Normalized parameter value (blue = low, red = high)
55+
- **Vertical spread**: Density-based jitter — wider where trials are concentrated
56+
57+
This makes it easy to spot monotonic relationships (e.g., "higher x leads to higher objective") at a glance.
58+
59+
## References
60+
61+
- Inspired by [SHAP beeswarm plots](https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/beeswarm.html)
62+
- Original feature request: [optuna/optuna#4987](https://github.com/optuna/optuna/issues/4987)
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Callable
4+
from collections.abc import Sequence
5+
from typing import TYPE_CHECKING
6+
7+
import matplotlib.cm as cm
8+
import matplotlib.colors as mcolors
9+
import matplotlib.pyplot as plt
10+
import numpy as np
11+
import optuna
12+
from optuna.trial import FrozenTrial
13+
14+
15+
if TYPE_CHECKING:
16+
from matplotlib.axes import Axes
17+
from matplotlib.colorbar import Colorbar
18+
from matplotlib.colors import Colormap
19+
from matplotlib.figure import Figure
20+
21+
22+
def _get_param_values_and_objectives(
    study: optuna.Study,
    params: Sequence[str] | None,
    target: Callable[[FrozenTrial], float] | None,
) -> tuple[list[str], dict[str, np.ndarray], np.ndarray]:
    """Extract parameter values and objective values from completed trials.

    Args:
        study: Study whose ``COMPLETE`` trials are read.
        params: Parameter names to extract. If ``None``, the sorted union of
            the parameter names found in the completed trials is used.
        target: Callable mapping a trial to the value to plot. If ``None``,
            ``trial.value`` is used.

    Returns:
        A tuple ``(param_names, param_values, objectives)``:

        * ``param_names``: the requested names that have at least one usable
          (numeric or string) value, in their original order.
        * ``param_values``: one float array per name, aligned with the
          completed trials. String values are label-encoded; entries where
          the parameter is absent or has an unsupported type hold ``0.0``.
        * ``objectives``: float array of target values, aligned the same way.

    Raises:
        ValueError: If the study has no completed trials.
    """
    trials = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
    if not trials:
        raise ValueError("The study has no completed trials.")

    if target is None:

        def target(t: FrozenTrial) -> float:
            return t.value  # type: ignore[return-value]

    # Force a float array so that a ``None`` target value becomes NaN instead
    # of silently producing an object-dtype array.
    objectives = np.array([target(t) for t in trials], dtype=float)

    # Determine parameters to plot.
    if params is None:
        all_param_names: set[str] = set()
        for t in trials:
            all_param_names.update(t.params)
        param_names = sorted(all_param_names)
    else:
        param_names = list(params)

    # Collect parameter values (numeric as-is; string categoricals label-encoded).
    param_values: dict[str, np.ndarray] = {}
    for name in param_names:
        # A missing parameter and an explicit ``None`` value are both unusable.
        raw = [t.params.get(name) for t in trials]

        # Deterministic label encoding: index into the sorted distinct labels.
        # ``hash(str)`` must not be used here because string hashes are salted
        # per interpreter process, which made colors vary between runs and
        # allowed distinct categories to collide.
        labels = sorted({v for v in raw if isinstance(v, str)})
        codes = {label: float(i) for i, label in enumerate(labels)}

        vals = np.zeros(len(trials))  # 0.0 is the placeholder for unusable entries.
        has_usable_value = False
        for i, v in enumerate(raw):
            if isinstance(v, str):
                vals[i] = codes[v]
            elif isinstance(v, (int, float)):
                vals[i] = float(v)
            else:
                continue
            has_usable_value = True

        if has_usable_value:
            param_values[name] = vals

    # Filter param_names to those with valid values.
    param_names = [n for n in param_names if n in param_values]
    return param_names, param_values, objectives
77+
78+
79+
def _compute_density_jitter(
    x: np.ndarray,
    *,
    nbins: int = 50,
    jitter_scale: float = 0.4,
    seed: int = 0,
) -> np.ndarray:
    """Compute density-based jitter for beeswarm layout.

    Points in dense regions of the x-axis get larger y-jitter, creating the
    characteristic "swarm" shape similar to SHAP beeswarm plots.

    The approach uses histogram-based density estimation to avoid a scipy
    dependency, then applies clipped Gaussian jitter scaled by the local
    density.

    Args:
        x: 1-D array of x-positions (e.g. objective values). Non-finite
            entries (NaN / inf) are ignored for density estimation and
            receive zero jitter.
        nbins: Number of bins for density estimation.
        jitter_scale: Maximum jitter magnitude (in row-index units).
        seed: Random seed for reproducibility.

    Returns:
        1-D array of y-offsets with the same length as ``x``; every offset
        lies in ``[-jitter_scale, jitter_scale]``.
    """
    x = np.asarray(x, dtype=float)
    n = len(x)
    if n <= 1:
        return np.zeros(n)

    # ``np.histogram`` rejects a non-finite data range, so only finite points
    # take part in the density estimate.
    finite = np.isfinite(x)
    if not finite.any():
        return np.zeros(n)
    x_finite = x[finite]

    rng = np.random.default_rng(seed)

    if x_finite.max() - x_finite.min() < 1e-12:
        # All finite x values are identical; uniform jitter.
        jitter = rng.uniform(-jitter_scale * 0.5, jitter_scale * 0.5, n)
        jitter[~finite] = 0.0
        return jitter

    # Histogram-based density estimation: each point gets the count of its bin.
    hist, bin_edges = np.histogram(x_finite, bins=nbins)
    bin_indices = np.digitize(x_finite, bin_edges[:-1]) - 1
    # The maximum falls on the last edge; clip it back into the last bin.
    bin_indices = np.clip(bin_indices, 0, len(hist) - 1)
    density = np.zeros(n)
    density[finite] = hist[bin_indices]

    # Normalize density to [0, 1].
    d_max = density.max()
    if d_max > 0:
        density /= d_max

    # Gaussian jitter, clipped at 2.5 sigma and rescaled to [-1, 1] so that a
    # rare extreme draw cannot push a dot into a neighbouring row.
    raw_jitter = np.clip(rng.standard_normal(n), -2.5, 2.5) / 2.5
    return raw_jitter * density * jitter_scale
131+
132+
133+
def _average_ranks(values: np.ndarray) -> np.ndarray:
    """Return 0-based ranks of ``values``; tied entries share their mean rank."""
    _, inverse, counts = np.unique(values, return_inverse=True, return_counts=True)
    last = np.cumsum(counts) - 1
    first = last - counts + 1
    return ((first + last) / 2.0)[inverse.ravel()]


def _sort_params_by_importance(
    study: optuna.Study,
    param_names: list[str],
    param_values: dict[str, np.ndarray],
    objectives: np.ndarray,
) -> list[str]:
    """Sort parameters by importance (most important at top of plot).

    Tries Optuna's built-in importance evaluation first (handles non-monotonic
    relationships). Falls back to the absolute Spearman rank correlation
    between each parameter and the objective when that evaluation is not
    possible, e.g. because an optional dependency is missing or the study is
    rejected by the evaluator.

    Args:
        study: Study passed to ``optuna.importance.get_param_importances``.
        param_names: Names of the parameters to order.
        param_values: Per-parameter value arrays aligned with ``objectives``.
        objectives: Objective values of the plotted trials.

    Returns:
        ``param_names`` in ascending order of importance, so that the most
        important parameter ends up in the highest (top) row of the plot.
    """
    try:
        importances = optuna.importance.get_param_importances(study)
    except (ImportError, RuntimeError, ValueError):
        # ValueError is included because the evaluator is understood to
        # reject some studies that way (e.g. multi-objective without a
        # target); the plot can still be ordered by the fallback below.
        importances = None

    if importances is not None:
        return sorted(param_names, key=lambda n: importances.get(n, 0.0))

    # Fallback: absolute Spearman rank correlation.
    correlations: dict[str, float] = {}
    for name in param_names:
        vals = param_values[name]
        valid = np.isfinite(vals) & np.isfinite(objectives)
        if valid.sum() < 3:
            # Too few points for a meaningful correlation.
            correlations[name] = 0.0
            continue

        # Tie-averaged ranks: integer and categorical parameters contain many
        # repeated values, and ordinal ranks would order those arbitrarily.
        v_rank = _average_ranks(vals[valid])
        o_rank = _average_ranks(objectives[valid])
        v_rank = v_rank - v_rank.mean()
        o_rank = o_rank - o_rank.mean()
        denom = np.sqrt((v_rank**2).sum() * (o_rank**2).sum())
        if denom < 1e-12:
            # One side is constant; correlation is undefined, treat as none.
            correlations[name] = 0.0
        else:
            correlations[name] = abs(float(np.dot(v_rank, o_rank) / denom))

    return sorted(param_names, key=lambda n: correlations.get(n, 0.0))
173+
174+
175+
def plot_beeswarm(
    study: optuna.Study,
    *,
    params: Sequence[str] | None = None,
    target: Callable[[FrozenTrial], float] | None = None,
    target_name: str = "Objective Value",
    color_map: str = "RdBu_r",
    ax: Axes | None = None,
) -> tuple[Figure, Axes, Colorbar]:
    """Plot a SHAP-style beeswarm plot for an Optuna study.

    Each row represents a hyperparameter. Each dot is one trial, positioned
    on the x-axis by its objective value. The dot color represents the
    parameter value normalized per row (with the default colormap, blue = low
    and red = high). In dense regions, dots are spread vertically to reveal
    the distribution (beeswarm layout).

    This visualization is useful for understanding monotonic relationships
    between hyperparameter values and the objective function.

    Args:
        study:
            An Optuna study with completed trials.
        params:
            A list of parameter names to include. If ``None``, all parameters
            across completed trials are used.
        target:
            A callable that extracts a scalar value from a
            :class:`~optuna.trial.FrozenTrial`. Defaults to
            ``trial.value``.
        target_name:
            Label for the x-axis. Defaults to ``"Objective Value"``.
        color_map:
            Matplotlib colormap name for parameter value coloring.
            Defaults to ``"RdBu_r"`` (blue for low, red for high values).
        ax:
            Matplotlib axes to draw on. If ``None``, a new figure is created.

    Returns:
        A tuple of ``(figure, axes, colorbar)``.

    Raises:
        ValueError: If the study has no completed trials, or if none of the
            requested parameters has a usable value.
    """
    param_names, param_values, objectives = _get_param_values_and_objectives(study, params, target)
    if len(param_names) == 0:
        raise ValueError("No valid parameters found in completed trials.")
    objectives = np.asarray(objectives, dtype=float)

    # Sorted ascending by importance: row 0 (bottom) is the least important.
    sorted_params = _sort_params_by_importance(study, param_names, param_values, objectives)

    # Resolve colormap. ``matplotlib.cm.get_cmap`` was removed in Matplotlib
    # 3.9; ``pyplot.get_cmap`` is available across versions.
    cmap: Colormap = plt.get_cmap(color_map)

    # Create figure if needed; height grows with the number of rows.
    if ax is None:
        n_params = len(sorted_params)
        fig_height = max(3.0, 0.5 * n_params + 1.5)
        fig, ax = plt.subplots(figsize=(10, fig_height))
    else:
        fig = ax.get_figure()

    # Plot each parameter as a row.
    norm = mcolors.Normalize(vmin=0.0, vmax=1.0)
    for row_idx, param_name in enumerate(sorted_params):
        vals = param_values[param_name]

        # Non-finite points can neither be placed on the x-axis nor colored,
        # and would otherwise poison the min/max normalization below.
        plottable = np.isfinite(vals) & np.isfinite(objectives)
        if not plottable.any():
            continue
        row_vals = vals[plottable]
        row_objectives = objectives[plottable]

        # Normalize parameter values to [0, 1] for coloring.
        v_min, v_max = row_vals.min(), row_vals.max()
        if v_max - v_min < 1e-12:
            # Constant parameter: use the middle of the colormap.
            norm_vals = np.full_like(row_vals, 0.5)
        else:
            norm_vals = (row_vals - v_min) / (v_max - v_min)

        # Density-based jitter; seeding by row keeps the layout reproducible
        # while giving each row a different pattern.
        jitter = _compute_density_jitter(row_objectives, seed=row_idx)
        y_positions = row_idx + jitter

        # Map normalized values to colors.
        colors = cmap(norm_vals)

        ax.scatter(
            row_objectives,
            y_positions,
            c=colors,
            s=8,
            alpha=0.75,
            edgecolors="none",
            rasterized=True,
        )

    # Configure axes.
    ax.set_yticks(range(len(sorted_params)))
    ax.set_yticklabels(sorted_params, fontsize=14)
    ax.set_xlabel(target_name, fontsize=14)
    ax.set_title("Beeswarm Plot", fontsize=16)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.tick_params(axis="x", labelsize=12)

    # Add colorbar. The mappable carries no data; it only exposes the
    # colormap over the normalized [0, 1] range used for every row.
    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, pad=0.02)
    cbar.set_label("Parameter value (normalized)", fontsize=12)
    cbar.set_ticks([0.0, 0.5, 1.0])
    cbar.set_ticklabels(["Low", "Mid", "High"])
    cbar.ax.tick_params(labelsize=12)

    fig.tight_layout()
    return fig, ax, cbar
285+
286+
287+
# Public API of this module: only the plotting entry point is exported; the
# underscore-prefixed helpers above are left out.
__all__ = ["plot_beeswarm"]

0 commit comments

Comments
 (0)