@@ -21,7 +21,6 @@
 def _sample_posterior(
     all_trees: List[List[Tree]],
     X: TensorLike,
-    m: int,
     rng: np.random.Generator,
     size: Optional[Union[int, Tuple[int, ...]]] = None,
     excluded: Optional[List[int]] = None,
@@ -37,8 +36,6 @@ def _sample_posterior(
     X : tensor-like
         A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
         out-of-sample predictions.
-    m : int
-        Number of trees
     rng : NumPy RandomGenerator
     size : int or tuple
         Number of samples.
@@ -66,7 +63,7 @@ def _sample_posterior(
 
     for ind, p in enumerate(pred):
         for tree in stacked_trees[idx[ind]]:
-            p += np.vstack([tree.predict(x=x, m=m, excluded=excluded, shape=shape) for x in X])
+            p += np.vstack([tree.predict(x=x, excluded=excluded, shape=shape) for x in X])
     pred.reshape((*size_iter, shape, -1))
     return pred
 
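The loop in the hunk above accumulates predictions tree by tree: for each posterior draw it stacks the per-row predictions of every tree and sums them. Below is a small, self-contained sketch of that accumulation pattern. `FakeTree` is an illustrative stand-in, not the pymc-bart `Tree` class, and it assumes (as the `np.vstack` over rows of `X` implies) that `predict` returns an array of length `shape` for a single input row.

```python
import numpy as np

class FakeTree:
    """Illustrative stand-in for a pymc-bart Tree: predicts a constant vector."""
    def __init__(self, value):
        self.value = value

    def predict(self, x, excluded=None, shape=1):
        # One prediction per output dimension for a single input row.
        return np.full(shape, self.value)

X = np.random.normal(size=(5, 2))  # 5 rows, 2 covariates
trees = [FakeTree(0.1), FakeTree(-0.2), FakeTree(0.05)]
shape = 1

p = np.zeros((X.shape[0], shape))
for tree in trees:
    # Same pattern as the diff: stack per-row predictions, add each tree's contribution.
    p += np.vstack([tree.predict(x=x, excluded=None, shape=shape) for x in X])

print(p.squeeze())  # every row is approximately 0.1 - 0.2 + 0.05 = -0.05
```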
@@ -239,7 +236,6 @@ def plot_ice(
     axes: matplotlib axes
     """
     all_trees = bartrv.owner.op.all_trees
-    m: int = bartrv.owner.op.m
     rng = np.random.default_rng(random_seed)
 
     if func is None:
@@ -271,7 +267,7 @@ def plot_ice(
             fake_X[:, indices_mi] = X[:, indices_mi][instance]
             y_pred.append(
                 np.mean(
-                    _sample_posterior(all_trees, X=fake_X, m=m, rng=rng, size=samples, shape=shape),
+                    _sample_posterior(all_trees, X=fake_X, rng=rng, size=samples, shape=shape),
                     0,
                 )
             )
@@ -386,7 +382,6 @@ def plot_pdp(
     axes: matplotlib axes
     """
     all_trees: list = bartrv.owner.op.all_trees
-    m: int = bartrv.owner.op.m
     rng = np.random.default_rng(random_seed)
 
     if func is None:
@@ -411,7 +406,7 @@ def plot_pdp(
         excluded.remove(var)
         fake_X, new_x = _create_pdp_data(X, xs_interval, var, xs_values, var_discrete)
         p_d = _sample_posterior(
-            all_trees, X=fake_X, m=m, rng=rng, size=samples, excluded=excluded, shape=shape
+            all_trees, X=fake_X, rng=rng, size=samples, excluded=excluded, shape=shape
         )
 
         for s_i in range(shape):
@@ -738,8 +733,6 @@ def plot_variable_importance(
     """
     _, axes = plt.subplots(2, 1, figsize=figsize)
 
-    m: int = bartrv.owner.op.m
-
     if bartrv.ndim == 1:  # type: ignore
         shape = 1
     else:
@@ -775,7 +768,7 @@ def plot_variable_importance(
     all_trees = bartrv.owner.op.all_trees
 
     predicted_all = _sample_posterior(
-        all_trees, X=X, m=m, rng=rng, size=samples, excluded=None, shape=shape
+        all_trees, X=X, rng=rng, size=samples, excluded=None, shape=shape
     )
 
     ev_mean = np.zeros(len(var_imp))
@@ -784,7 +777,6 @@ def plot_variable_importance(
         predicted_subset = _sample_posterior(
             all_trees=all_trees,
             X=X,
-            m=m,
             rng=rng,
             size=samples,
             excluded=subset,
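After this change, neither `_sample_posterior` nor `Tree.predict` takes the number of trees `m` as an argument. Below is a minimal, hedged sketch of the updated call, mirroring the call sites in the hunks above. `_sample_posterior` is an internal helper (assumed here to live in `pymc_bart.utils`, alongside the plotting functions in this diff), and the model, data, and variable names (`mu`, `X_new`, `y_obs`) are illustrative assumptions rather than part of the diff.

```python
import numpy as np
import pymc as pm
import pymc_bart as pmb
from pymc_bart.utils import _sample_posterior

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = X[:, 0] + rng.normal(scale=0.1, size=50)

with pm.Model():
    mu = pmb.BART("mu", X, y, m=20)
    pm.Normal("y_obs", mu=mu, sigma=0.1, observed=y)
    pm.sample(tune=50, draws=50, chains=1, random_seed=0)

all_trees = mu.owner.op.all_trees  # trees stored on the BART Op, as in the diff

# In-sample prediction: reuse the covariate matrix used to fit BART.
in_sample = _sample_posterior(all_trees, X=X, rng=rng, size=100, shape=1)

# Out-of-sample prediction: a new covariate matrix with the same number of columns.
X_new = rng.normal(size=(10, 3))
out_of_sample = _sample_posterior(all_trees, X=X_new, rng=rng, size=100, shape=1)
```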