@@ -157,7 +157,7 @@ def plot_ice(
157
157
bartrv : Variable ,
158
158
X : npt .NDArray [np .float_ ],
159
159
Y : Optional [npt .NDArray [np .float_ ]] = None ,
160
- xs_interval : str = "linear " ,
160
+ xs_interval : str = "quantiles " ,
161
161
xs_values : Optional [Union [int , List [float ]]] = None ,
162
162
var_idx : Optional [List [int ]] = None ,
163
163
var_discrete : Optional [List [int ]] = None ,
@@ -303,7 +303,7 @@ def identity(x):
303
303
idx = np .argsort (new_x )
304
304
axes [count ].plot (new_x [idx ], p_di .mean (0 )[idx ], color = color_mean )
305
305
axes [count ].plot (new_x [idx ], p_di .T [idx ], color = color , alpha = alpha )
306
- axes [count ].set_xlabel (x_labels [var ])
306
+ axes [count ].set_xlabel (x_labels [var ])
307
307
308
308
count += 1
309
309
@@ -316,7 +316,7 @@ def plot_pdp(
316
316
bartrv : Variable ,
317
317
X : npt .NDArray [np .float_ ],
318
318
Y : Optional [npt .NDArray [np .float_ ]] = None ,
319
- xs_interval : str = "linear " ,
319
+ xs_interval : str = "quantiles " ,
320
320
xs_values : Optional [Union [int , List [float ]]] = None ,
321
321
var_idx : Optional [List [int ]] = None ,
322
322
var_discrete : Optional [List [int ]] = None ,
@@ -423,35 +423,39 @@ def identity(x):
423
423
p_d = _sample_posterior (
424
424
all_trees , X = fake_X , rng = rng , size = samples , excluded = excluded , shape = shape
425
425
)
426
- new_x = fake_X [:, var ]
427
- for s_i in range (shape ):
428
- p_di = func (p_d [:, :, s_i ])
429
- if var in var_discrete :
430
- y_means = p_di .mean (0 )
431
- hdi = az .hdi (p_di )
432
- axes [count ].errorbar (
433
- new_x ,
434
- y_means ,
435
- (y_means - hdi [:, 0 ], hdi [:, 1 ] - y_means ),
436
- fmt = "." ,
437
- color = color ,
438
- )
439
- else :
440
- az .plot_hdi (
441
- new_x ,
442
- p_di ,
443
- smooth = smooth ,
444
- fill_kwargs = {"alpha" : alpha , "color" : color },
445
- ax = axes [count ],
446
- )
447
- if smooth :
448
- x_data , y_data = _smooth_mean (new_x , p_di , "pdp" , smooth_kwargs )
449
- axes [count ].plot (x_data , y_data , color = color_mean )
426
+ with warnings .catch_warnings ():
427
+ warnings .filterwarnings ("ignore" , message = "hdi currently interprets 2d data" )
428
+ new_x = fake_X [:, var ]
429
+ for s_i in range (shape ):
430
+ p_di = func (p_d [:, :, s_i ])
431
+ if var in var_discrete :
432
+ _ , idx_uni = np .unique (new_x , return_index = True )
433
+ y_means = p_di .mean (0 )[idx_uni ]
434
+ hdi = az .hdi (p_di )[idx_uni ]
435
+ axes [count ].errorbar (
436
+ new_x [idx_uni ],
437
+ y_means ,
438
+ (y_means - hdi [:, 0 ], hdi [:, 1 ] - y_means ),
439
+ fmt = "." ,
440
+ color = color ,
441
+ )
442
+ axes [count ].set_xticks (new_x [idx_uni ])
450
443
else :
451
- axes [count ].plot (new_x , p_di .mean (0 ), color = color_mean )
444
+ az .plot_hdi (
445
+ new_x ,
446
+ p_di ,
447
+ smooth = smooth ,
448
+ fill_kwargs = {"alpha" : alpha , "color" : color },
449
+ ax = axes [count ],
450
+ )
451
+ if smooth :
452
+ x_data , y_data = _smooth_mean (new_x , p_di , "pdp" , smooth_kwargs )
453
+ axes [count ].plot (x_data , y_data , color = color_mean )
454
+ else :
455
+ axes [count ].plot (new_x , p_di .mean (0 ), color = color_mean )
452
456
axes [count ].set_xlabel (x_labels [var ])
453
457
454
- count += 1
458
+ count += 1
455
459
456
460
fig .text (- 0.05 , 0.5 , y_label , va = "center" , rotation = "vertical" , fontsize = 15 )
457
461
@@ -527,16 +531,20 @@ def _get_axes(
527
531
fig .delaxes (axes [i ])
528
532
axes = axes [:n_plots ]
529
533
else :
530
- axes = [ax ]
531
- fig = ax .get_figure ()
534
+ if isinstance (ax , np .ndarray ):
535
+ axes = ax
536
+ fig = ax [0 ].get_figure ()
537
+ else :
538
+ axes = [ax ]
539
+ fig = ax .get_figure () # type: ignore
532
540
533
541
return fig , axes , shape
534
542
535
543
536
544
def _prepare_plot_data (
537
545
X : npt .NDArray [np .float_ ],
538
546
Y : Optional [npt .NDArray [np .float_ ]] = None ,
539
- xs_interval : str = "linear " ,
547
+ xs_interval : str = "quantiles " ,
540
548
xs_values : Optional [Union [int , List [float ]]] = None ,
541
549
var_idx : Optional [List [int ]] = None ,
542
550
var_discrete : Optional [List [int ]] = None ,
@@ -710,7 +718,7 @@ def plot_variable_importance(
710
718
figsize : Optional [Tuple [float , float ]] = None ,
711
719
samples : int = 100 ,
712
720
random_seed : Optional [int ] = None ,
713
- ) -> Tuple [npt .NDArray [np .int_ ], List [plt .axes ]]:
721
+ ) -> Tuple [npt .NDArray [np .int_ ], List [plt .Axes ]]:
714
722
"""
715
723
Estimates variable importance from the BART-posterior.
716
724
0 commit comments