@@ -155,14 +155,12 @@ def plot_ice(
155
155
bartrv : Variable ,
156
156
X : npt .NDArray [np .float_ ],
157
157
Y : Optional [npt .NDArray [np .float_ ]] = None ,
158
- xs_interval : str = "quantiles" ,
159
- xs_values : Optional [Union [int , List [float ]]] = None ,
160
158
var_idx : Optional [List [int ]] = None ,
161
159
var_discrete : Optional [List [int ]] = None ,
162
160
func : Optional [Callable ] = None ,
163
161
centered : Optional [bool ] = True ,
164
- samples : int = 50 ,
165
- instances : int = 10 ,
162
+ samples : int = 100 ,
163
+ instances : int = 30 ,
166
164
random_seed : Optional [int ] = None ,
167
165
sharey : bool = True ,
168
166
smooth : bool = True ,
@@ -185,16 +183,6 @@ def plot_ice(
185
183
The covariate matrix.
186
184
Y : Optional[npt.NDArray[np.float_]], by default None.
187
185
The response vector.
188
- xs_interval : str
189
- Method used to compute the values X used to evaluate the predicted function. "linear",
190
- evenly spaced values in the range of X. "quantiles", the evaluation is done at the specified
191
- quantiles of X. "insample", the evaluation is done at the values of X.
192
- For discrete variables these options are ommited.
193
- xs_values : Optional[Union[int, List[float]]], by default None.
194
- Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
195
- points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
196
- quantiles to compute, which must be between 0 and 1 inclusive.
197
- Ignored when ``xs_interval="insample"``.
198
186
var_idx : Optional[List[int]], by default None.
199
187
List of the indices of the covariate for which to compute the pdp or ice.
200
188
var_discrete : Optional[List[int]], by default None.
@@ -205,22 +193,20 @@ def plot_ice(
205
193
If True the result is centered around the partial response evaluated at the lowest value in
206
194
``xs_interval``. Defaults to True.
207
195
samples : int
208
- Number of posterior samples used in the predictions. Defaults to 50
196
+ Number of posterior samples used in the predictions. Defaults to 100
209
197
instances : int
210
- Number of instances of X to plot. Defaults to 10 .
198
+ Number of instances of X to plot. Defaults to 30 .
211
199
random_seed : Optional[int], by default None.
212
200
Seed used to sample from the posterior. Defaults to None.
213
201
sharey : bool
214
202
Controls sharing of properties among y-axes. Defaults to True.
215
- rug : bool
216
- Whether to include a rugplot. Defaults to True.
217
203
smooth : bool
218
204
If True the result will be smoothed by first computing a linear interpolation of the data
219
205
over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
220
206
Defaults to True.
221
207
grid : str or tuple
222
208
How to arrange the subplots. Defaults to "long", one subplot below the other.
223
- Other options are "wide", one subplot next to eachother or a tuple indicating the number of
209
+ Other options are "wide", one subplot next to each other or a tuple indicating the number of
224
210
rows and columns.
225
211
color : matplotlib valid color
226
212
Color used to plot the pdp or ice. Defaults to "C0"
@@ -257,17 +243,17 @@ def identity(x):
257
243
indices ,
258
244
var_idx ,
259
245
var_discrete ,
260
- xs_interval ,
261
- xs_values ,
262
- ) = _prepare_plot_data (X , Y , xs_interval , xs_values , var_idx , var_discrete )
246
+ _ ,
247
+ _ ,
248
+ ) = _prepare_plot_data (X , Y , "linear" , None , var_idx , var_discrete )
263
249
264
250
fig , axes , shape = _get_axes (bartrv , var_idx , grid , sharey , figsize , ax )
265
251
266
252
instances_ary = rng .choice (range (X .shape [0 ]), replace = False , size = instances )
267
253
idx_s = list (range (X .shape [0 ]))
268
254
269
255
count = 0
270
- for var in range ( len ( var_idx ) ):
256
+ for i_var , var in enumerate ( var_idx ):
271
257
indices_mi = indices [:]
272
258
indices_mi .remove (var )
273
259
y_pred = []
@@ -283,6 +269,7 @@ def identity(x):
283
269
284
270
new_x = fake_X [:, var ]
285
271
p_d = np .array (y_pred )
272
+ print (p_d .shape )
286
273
287
274
for s_i in range (shape ):
288
275
if centered :
@@ -301,7 +288,7 @@ def identity(x):
301
288
idx = np .argsort (new_x )
302
289
axes [count ].plot (new_x [idx ], p_di .mean (0 )[idx ], color = color_mean )
303
290
axes [count ].plot (new_x [idx ], p_di .T [idx ], color = color , alpha = alpha )
304
- axes [count ].set_xlabel (x_labels [var ])
291
+ axes [count ].set_xlabel (x_labels [i_var ])
305
292
306
293
count += 1
307
294
@@ -349,7 +336,7 @@ def plot_pdp(
349
336
For discrete variables these options are ommited.
350
337
xs_values : Optional[Union[int, List[float]]], by default None.
351
338
Values of X used to evaluate the predicted function. If ``xs_interval="linear"`` number of
352
- points in the evenly spaced grid. If ``xs_interval="quantiles"``quantile or sequence of
339
+ points in the evenly spaced grid. If ``xs_interval="quantiles"`` quantile or sequence of
353
340
quantiles to compute, which must be between 0 and 1 inclusive.
354
341
Ignored when ``xs_interval="insample"``.
355
342
var_idx : Optional[List[int]], by default None.
@@ -717,7 +704,8 @@ def plot_variable_importance(
717
704
xlabel_angle : float = 0 ,
718
705
samples : int = 100 ,
719
706
random_seed : Optional [int ] = None ,
720
- ) -> Tuple [List [int ], List [plt .Axes ]]:
707
+ ax : Optional [plt .Axes ] = None ,
708
+ ) -> Tuple [List [int ], Union [List [plt .Axes ], Any ]]:
721
709
"""
722
710
Estimates variable importance from the BART-posterior.
723
711
@@ -747,6 +735,8 @@ def plot_variable_importance(
747
735
Number of predictions used to compute correlation for subsets of variables. Defaults to 100
748
736
random_seed : Optional[int]
749
737
random_seed used to sample from the posterior. Defaults to None.
738
+ ax : axes
739
+ Matplotlib axes.
750
740
751
741
Returns
752
742
-------
@@ -771,7 +761,8 @@ def plot_variable_importance(
771
761
if figsize is None :
772
762
figsize = (8 , 3 )
773
763
774
- _ , ax = plt .subplots (1 , 1 , figsize = figsize )
764
+ if ax is None :
765
+ _ , ax = plt .subplots (1 , 1 , figsize = figsize )
775
766
776
767
if labels is None :
777
768
labels_ary = np .arange (n_vars ).astype (str )
0 commit comments