Skip to content

Commit fb87d44

Browse files
committed
Use fANOVA with Spearman fallback for parameter sorting, add sample image
1 parent 547d32b commit fb87d44

File tree

4 files changed

+25
-20
lines changed

4 files changed

+25
-20
lines changed

package/visualization/plot_beeswarm/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ plot_beeswarm = mod.plot_beeswarm
3030

3131

3232
def objective(trial: optuna.trial.Trial) -> float:
33-
x = trial.suggest_float("x", -5.0, 5.0)
34-
y = trial.suggest_float("y", -5.0, 5.0)
35-
z = trial.suggest_float("z", -5.0, 5.0)
36-
w = trial.suggest_float("w", -5.0, 5.0)
37-
return x**2 + 0.5 * y**2 + 0.1 * z**2 + 0.01 * w**2
33+
x = trial.suggest_float("x", 0.0, 10.0)
34+
y = trial.suggest_float("y", 0.0, 10.0)
35+
z = trial.suggest_float("z", 0.0, 10.0)
36+
w = trial.suggest_float("w", 0.0, 10.0)
37+
return 1.0 * x + 0.5 * y + 0.1 * z + 0.01 * w
3838

3939

4040
study = optuna.create_study()

package/visualization/plot_beeswarm/__init__.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,30 +131,36 @@ def _compute_density_jitter(
131131

132132

133133
def _sort_params_by_importance(
134+
study: optuna.Study,
134135
param_names: list[str],
135136
param_values: dict[str, np.ndarray],
136137
objectives: np.ndarray,
137138
) -> list[str]:
138-
"""Sort parameters by absolute Spearman-like correlation with objective.
139+
"""Sort parameters by importance (most important at top of plot).
139140
140-
Parameters with stronger monotonic relationships with the objective
141-
are placed at the top (higher index = plotted higher), matching the
142-
convention of SHAP beeswarm plots.
141+
Tries Optuna's built-in fANOVA-based importance first (handles non-monotonic
142+
relationships). Falls back to Spearman rank correlation if sklearn is not
143+
installed.
143144
"""
145+
# Try Optuna's fANOVA (requires sklearn).
146+
try:
147+
importances = optuna.importance.get_param_importances(study)
148+
return sorted(param_names, key=lambda n: importances.get(n, 0.0))
149+
except (ImportError, RuntimeError):
150+
pass
151+
152+
# Fallback: absolute Spearman rank correlation.
144153
correlations: dict[str, float] = {}
145154
for name in param_names:
146155
vals = param_values[name]
147-
# Use rank correlation (Spearman) without scipy.
148156
valid = np.isfinite(vals) & np.isfinite(objectives)
149157
if valid.sum() < 3:
150158
correlations[name] = 0.0
151159
continue
152160
v = vals[valid]
153161
o = objectives[valid]
154-
# Rank the arrays.
155162
v_rank = np.argsort(np.argsort(v)).astype(float)
156163
o_rank = np.argsort(np.argsort(o)).astype(float)
157-
# Pearson correlation of ranks = Spearman correlation.
158164
v_rank -= v_rank.mean()
159165
o_rank -= o_rank.mean()
160166
denom = np.sqrt((v_rank**2).sum() * (o_rank**2).sum())
@@ -163,7 +169,6 @@ def _sort_params_by_importance(
163169
else:
164170
correlations[name] = abs(float(np.dot(v_rank, o_rank) / denom))
165171

166-
# Sort ascending so that the most important param is at the top of the plot.
167172
return sorted(param_names, key=lambda n: correlations.get(n, 0.0))
168173

169174

@@ -215,7 +220,7 @@ def plot_beeswarm(
215220
raise ValueError("No valid parameters found in completed trials.")
216221

217222
# Sort parameters by importance (least important at bottom).
218-
sorted_params = _sort_params_by_importance(param_names, param_values, objectives)
223+
sorted_params = _sort_params_by_importance(study, param_names, param_values, objectives)
219224

220225
# Resolve colormap.
221226
cmap: Colormap = cm.get_cmap(color_map)

package/visualization/plot_beeswarm/example.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121

2222

2323
def objective(trial: optuna.trial.Trial) -> float:
24-
"""Multi-parameter objective with clear parameter-objective relationships."""
25-
x = trial.suggest_float("x", -5.0, 5.0)
26-
y = trial.suggest_float("y", -5.0, 5.0)
27-
z = trial.suggest_float("z", -5.0, 5.0)
28-
w = trial.suggest_float("w", -5.0, 5.0)
24+
"""Multi-parameter objective with clear monotonic relationships."""
25+
x = trial.suggest_float("x", 0.0, 10.0)
26+
y = trial.suggest_float("y", 0.0, 10.0)
27+
z = trial.suggest_float("z", 0.0, 10.0)
28+
w = trial.suggest_float("w", 0.0, 10.0)
2929
# x has the strongest influence, w the weakest.
30-
return x**2 + 0.5 * y**2 + 0.1 * z**2 + 0.01 * w**2
30+
return 1.0 * x + 0.5 * y + 0.1 * z + 0.01 * w
3131

3232

3333
if __name__ == "__main__":
102 KB
Loading

0 commit comments

Comments
 (0)