@@ -137,22 +137,6 @@ def plot_convergence(
137137 return ax
138138
139139
140- def plot_dependence (* args , kind = "pdp" , ** kwargs ): # pylint: disable=unused-argument
141- """
142- Partial dependence or individual conditional expectation plot.
143- """
144- if kind == "pdp" :
145- warnings .warn (
146- "This function has been deprecated. Use plot_pdp instead." ,
147- FutureWarning ,
148- )
149- elif kind == "ice" :
150- warnings .warn (
151- "This function has been deprecated. Use plot_ice instead." ,
152- FutureWarning ,
153- )
154-
155-
156140def plot_ice (
157141 bartrv : Variable ,
158142 X : npt .NDArray [np .float64 ],
@@ -307,6 +291,7 @@ def plot_pdp(
307291 var_discrete : Optional [list [int ]] = None ,
308292 func : Optional [Callable ] = None ,
309293 samples : int = 200 ,
294+ ref_line : bool = True ,
310295 random_seed : Optional [int ] = None ,
311296 sharey : bool = True ,
312297 smooth : bool = True ,
@@ -347,6 +332,8 @@ def plot_pdp(
347332 Arbitrary function to apply to the predictions. Defaults to the identity function.
348333 samples : int
349334 Number of posterior samples used in the predictions. Defaults to 200
335+ ref_line : bool
336+ If True a reference line is plotted at the mean of the partial dependence. Defaults to True.
350337 random_seed : Optional[int], by default None.
351338 Seed used to sample from the posterior. Defaults to None.
352339 sharey : bool
@@ -402,6 +389,7 @@ def identity(x):
402389
403390 count = 0
404391 fake_X = _create_pdp_data (X , xs_interval , xs_values )
392+ null_pd = []
405393 for var in range (len (var_idx )):
406394 excluded = indices [:]
407395 excluded .remove (var )
@@ -413,6 +401,7 @@ def identity(x):
413401 new_x = fake_X [:, var ]
414402 for s_i in range (shape ):
415403 p_di = func (p_d [:, :, s_i ])
404+ null_pd .append (p_di .mean ())
416405 if var in var_discrete :
417406 _ , idx_uni = np .unique (new_x , return_index = True )
418407 y_means = p_di .mean (0 )[idx_uni ]
@@ -442,6 +431,11 @@ def identity(x):
442431
443432 count += 1
444433
434+ if ref_line :
435+ ref_val = sum (null_pd ) / len (null_pd )
436+ for ax_ in np .ravel (axes ):
437+ ax_ .axhline (ref_val , color = "0.7" , linestyle = "--" )
438+
445439 fig .text (- 0.05 , 0.5 , y_label , va = "center" , rotation = "vertical" , fontsize = 15 )
446440
447441 return axes
@@ -949,11 +943,13 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
949943
950944 indices = least_important_vars [::- 1 ]
951945
952- labels = np .array (["+ " + ele if index != 0 else ele for index , ele in enumerate (labels )])
946+ labels = np .array (
947+ ["+ " + ele if index != 0 else ele for index , ele in enumerate (labels [indices ])]
948+ )
953949
954950 vi_results = {
955951 "indices" : np .asarray (indices ),
956- "labels" : labels [ indices ] ,
952+ "labels" : labels ,
957953 "r2_mean" : r2_mean ,
958954 "r2_hdi" : r2_hdi ,
959955 "preds" : preds ,
0 commit comments