import numpy as np

from aesara.tensor.var import Variable
-from numpy.random import RandomState
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from scipy.stats import pearsonr


-def predict(bartrv, rng, X, size=None, excluded=None):
+def _sample_posterior(all_trees, X, rng, size=None, excluded=None):
    """
    Generate samples from the BART-posterior.

    Parameters
    ----------
-    bartrv : BART Random Variable
-        BART variable once the model that include it has been fitted.
-    rng: NumPy random generator
+    all_trees : list
+        List of all trees sampled from a posterior
    X : array-like
        A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
        out-of-sample predictions.
+    rng : NumPy RandomGenerator
    size : int or tuple
        Number of samples.
    excluded : list
-        indexes of the variables to exclude when computing predictions
+        Indexes of the variables to exclude when computing predictions
    """
-    stacked_trees = bartrv.owner.op.all_trees
+    stacked_trees = all_trees
    if isinstance(X, Variable):
        X = X.eval()

@@ -41,7 +40,7 @@ def predict(bartrv, rng, X, size=None, excluded=None):
    for s in size:
        flatten_size *= s

-    idx = rng.randint(len(stacked_trees), size=flatten_size)
+    idx = rng.integers(0, len(stacked_trees), size=flatten_size)
    shape = stacked_trees[0][0].predict(X[0]).size

    pred = np.zeros((flatten_size, X.shape[0], shape))
@@ -53,35 +52,6 @@ def predict(bartrv, rng, X, size=None, excluded=None):
    return pred


-def sample_posterior(all_trees, X):
-    """
-    Generate samples from the BART-posterior.
-
-    Parameters
-    ----------
-    all_trees : list
-        List of all trees sampled from a posterior
-    X : array-like
-        A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
-        out-of-sample predictions.
-    m : int
-        Number of trees
-    """
-    stacked_trees = all_trees
-    idx = np.random.randint(len(stacked_trees))
-    if isinstance(X, Variable):
-        X = X.eval()
-
-    shape = stacked_trees[0][0].predict(X[0]).size
-
-    pred = np.zeros((1, X.shape[0], shape))
-
-    for p in pred:
-        for tree in stacked_trees[idx]:
-            p += np.array([tree.predict(x) for x in X])
-    return pred.squeeze()
-
-
def plot_dependence(
    bartrv,
    X,
@@ -179,8 +149,6 @@ def plot_dependence(
            Available option are 'insample', 'linear' or 'quantiles'"""
        )

-    rng = RandomState(seed=random_seed)
-
    if isinstance(X, Variable):
        X = X.eval()

@@ -195,6 +163,8 @@ def plot_dependence(
    else:
        y_label = "Predicted Y"

+    rng = np.random.default_rng(random_seed)
+
    num_covariates = X.shape[1]

    indices = list(range(num_covariates))
@@ -216,14 +186,15 @@ def plot_dependence(
        xs_values = [0.05, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.95]

    if kind == "ice":
-        instances = np.random.choice(range(X.shape[0]), replace=False, size=instances)
+        instances = rng.choice(range(X.shape[0]), replace=False, size=instances)

    new_y = []
    new_x_target = []
    y_mins = []

    new_X = np.zeros_like(X)
    idx_s = list(range(X.shape[0]))
+    all_trees = bartrv.owner.op.all_trees
    for i in var_idx:
        indices_mi = indices[:]
        indices_mi.pop(i)
@@ -242,13 +213,17 @@ def plot_dependence(
            for x_i in new_x_i:
                new_X[:, indices_mi] = X[:, indices_mi]
                new_X[:, i] = x_i
-                y_pred.append(np.mean(predict(bartrv, rng, X=new_X, size=samples), 1))
+                y_pred.append(
+                    np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 1)
+                )
            new_x_target.append(new_x_i)
        else:
            for instance in instances:
                new_X = X[idx_s]
                new_X[:, indices_mi] = X[:, indices_mi][instance]
-                y_pred.append(np.mean(predict(bartrv, rng, X=new_X, size=samples), 0))
+                y_pred.append(
+                    np.mean(_sample_posterior(all_trees, X=new_X, rng=rng, size=samples), 0)
+                )
                new_x_target.append(new_X[:, i])
        y_mins.append(np.min(y_pred))
        new_y.append(np.array(y_pred).T)
@@ -328,7 +303,7 @@ def plot_dependence(
                        nxi,
                        nyi,
                        smooth=smooth,
-                        fill_kwargs={"alpha": alpha},
+                        fill_kwargs={"alpha": alpha, "color": color},
                        ax=ax,
                    )
                    ax.plot(nxi[idx], nyi[idx].mean(0), color=color)
@@ -374,7 +349,6 @@ def plot_variable_importance(
    idxs: indexes of the covariates from higher to lower relative importance
    axes: matplotlib axes
    """
-    rng = RandomState(seed=random_seed)
    _, axes = plt.subplots(2, 1, figsize=figsize)

    if hasattr(X, "columns") and hasattr(X, "values"):
@@ -387,6 +361,8 @@ def plot_variable_importance(
    else:
        labels = np.array(labels)

+    rng = np.random.default_rng(random_seed)
+
    ticks = np.arange(len(var_imp), dtype=int)
    idxs = np.argsort(var_imp)
    subsets = [idxs[:-i] for i in range(1, len(idxs))]
@@ -402,12 +378,14 @@ def plot_variable_importance(
    axes[0].set_xlabel("covariables")
    axes[0].set_ylabel("importance")

-    predicted_all = predict(bartrv, rng, X=X, size=samples, excluded=None)
+    all_trees = bartrv.owner.op.all_trees
+
+    predicted_all = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=None)

    ev_mean = np.zeros(len(var_imp))
    ev_hdi = np.zeros((len(var_imp), 2))
    for idx, subset in enumerate(subsets):
-        predicted_subset = predict(bartrv, rng, X=X, size=samples, excluded=subset)
+        predicted_subset = _sample_posterior(all_trees, X=X, rng=rng, size=samples, excluded=subset)
        pearson = np.zeros(samples)
        for j in range(samples):
            pearson[j] = (
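
For orientation, below is a minimal, self-contained sketch of the calling convention this diff moves to: a NumPy Generator (created with np.random.default_rng) replaces the legacy RandomState, posterior-draw indices are drawn with rng.integers instead of rng.randint, and the pre-extracted all_trees list is passed to _sample_posterior instead of the bartrv variable. The StubTree class, the two-draw posterior list, and the prediction loop are illustrative stand-ins (the loop is reconstructed from the removed sample_posterior shown above); only the Generator calls and the _sample_posterior(all_trees, X, rng, size=...) signature come from the diff itself.

    import numpy as np


    class StubTree:
        # Stand-in for a fitted BART tree; real trees come from
        # bartrv.owner.op.all_trees and expose the same predict() method.
        def __init__(self, offset):
            self.offset = offset

        def predict(self, x):
            # Return a 1-d array, matching tree.predict(x) in the diff.
            return np.array([x.sum() + self.offset])


    # Two posterior draws, each a list of trees (the shape all_trees has here).
    all_trees = [[StubTree(0.1), StubTree(-0.1)], [StubTree(0.2), StubTree(-0.2)]]
    X = np.linspace(0, 1, 15).reshape(5, 3)  # 5 observations, 3 covariates

    # The commit's key API change: a numpy Generator instead of RandomState,
    # so posterior-draw indices use rng.integers rather than rng.randint.
    rng = np.random.default_rng(42)
    flatten_size = 3
    idx = rng.integers(0, len(all_trees), size=flatten_size)

    shape = all_trees[0][0].predict(X[0]).size
    pred = np.zeros((flatten_size, X.shape[0], shape))
    for i, p in enumerate(pred):
        # Sum the predictions of every tree in the sampled posterior draw,
        # mirroring the loop of the removed sample_posterior above.
        for tree in all_trees[idx[i]]:
            p += np.array([tree.predict(x) for x in X])
    print(pred.shape)  # (3, 5, 1): (samples, observations, output size)

Seeding default_rng(random_seed) inside the plotting functions, as the diff does, keeps plot_dependence and plot_variable_importance reproducible for a fixed seed while all draws flow through a single Generator.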