@@ -131,30 +131,36 @@ def _compute_density_jitter(
131131
132132
133133def _sort_params_by_importance (
134+ study : optuna .Study ,
134135 param_names : list [str ],
135136 param_values : dict [str , np .ndarray ],
136137 objectives : np .ndarray ,
137138) -> list [str ]:
138- """Sort parameters by absolute Spearman-like correlation with objective .
139+ """Sort parameters by importance (most important at top of plot) .
139140
140- Parameters with stronger monotonic relationships with the objective
141- are placed at the top (higher index = plotted higher), matching the
142- convention of SHAP beeswarm plots .
141+ Tries Optuna's built-in fANOVA-based importance first (handles non-monotonic
142+ relationships). Falls back to Spearman rank correlation if sklearn is not
143+ installed .
143144 """
145+ # Try Optuna's fANOVA (requires sklearn).
146+ try :
147+ importances = optuna .importance .get_param_importances (study )
148+ return sorted (param_names , key = lambda n : importances .get (n , 0.0 ))
149+ except (ImportError , RuntimeError ):
150+ pass
151+
152+ # Fallback: absolute Spearman rank correlation.
144153 correlations : dict [str , float ] = {}
145154 for name in param_names :
146155 vals = param_values [name ]
147- # Use rank correlation (Spearman) without scipy.
148156 valid = np .isfinite (vals ) & np .isfinite (objectives )
149157 if valid .sum () < 3 :
150158 correlations [name ] = 0.0
151159 continue
152160 v = vals [valid ]
153161 o = objectives [valid ]
154- # Rank the arrays.
155162 v_rank = np .argsort (np .argsort (v )).astype (float )
156163 o_rank = np .argsort (np .argsort (o )).astype (float )
157- # Pearson correlation of ranks = Spearman correlation.
158164 v_rank -= v_rank .mean ()
159165 o_rank -= o_rank .mean ()
160166 denom = np .sqrt ((v_rank ** 2 ).sum () * (o_rank ** 2 ).sum ())
@@ -163,7 +169,6 @@ def _sort_params_by_importance(
163169 else :
164170 correlations [name ] = abs (float (np .dot (v_rank , o_rank ) / denom ))
165171
166- # Sort ascending so that the most important param is at the top of the plot.
167172 return sorted (param_names , key = lambda n : correlations .get (n , 0.0 ))
168173
169174
@@ -215,7 +220,7 @@ def plot_beeswarm(
215220 raise ValueError ("No valid parameters found in completed trials." )
216221
217222 # Sort parameters by importance (least important at bottom).
218- sorted_params = _sort_params_by_importance (param_names , param_values , objectives )
223+ sorted_params = _sort_params_by_importance (study , param_names , param_values , objectives )
219224
220225 # Resolve colormap.
221226 cmap : Colormap = cm .get_cmap (color_map )
0 commit comments