4
4
import matplotlib .pyplot as plt
5
5
import numpy as np
6
6
7
+ from aesara .tensor .var import Variable
7
8
from numpy .random import RandomState
8
9
from scipy .interpolate import griddata
9
10
from scipy .signal import savgol_filter
10
11
from scipy .stats import pearsonr
11
12
12
13
13
- def predict (idata , rng , X , size = None , excluded = None ):
14
+ def predict (bartrv , rng , X , size = None , excluded = None ):
14
15
"""
15
16
Generate samples from the BART-posterior.
16
17
17
18
Parameters
18
19
----------
19
- idata : InferenceData
20
- InferenceData containing a collection of BART_trees in sample_stats group
20
+ bartrv : BART Random Variable
21
+ BART variable once the model that includes it has been fitted.
21
22
rng: NumPy random generator
22
23
X : array-like
23
24
A covariate matrix. Use the same used to fit BART for in-sample predictions or a new one for
@@ -27,8 +28,10 @@ def predict(idata, rng, X, size=None, excluded=None):
27
28
excluded : list
28
29
indexes of the variables to exclude when computing predictions
29
30
"""
30
- bart_trees = idata .sample_stats .bart_trees
31
- stacked_trees = bart_trees .stack (trees = ["chain" , "draw" ])
31
+ stacked_trees = bartrv .owner .op .all_trees
32
+ if isinstance (X , Variable ):
33
+ X = X .eval ()
34
+
32
35
if size is None :
33
36
size = ()
34
37
elif isinstance (size , int ):
@@ -38,20 +41,49 @@ def predict(idata, rng, X, size=None, excluded=None):
38
41
for s in size :
39
42
flatten_size *= s
40
43
41
- idx = rng .randint (len (stacked_trees . trees ), size = flatten_size )
42
- shape = stacked_trees . isel ( trees = 0 ). values [0 ].predict (X [0 ]).size
44
+ idx = rng .randint (len (stacked_trees ), size = flatten_size )
45
+ shape = stacked_trees [ 0 ] [0 ].predict (X [0 ]).size
43
46
44
47
pred = np .zeros ((flatten_size , X .shape [0 ], shape ))
45
48
46
49
for ind , p in enumerate (pred ):
47
- for tree in stacked_trees . isel ( trees = idx [ind ]). values :
50
+ for tree in stacked_trees [ idx [ind ]] :
48
51
p += np .array ([tree .predict (x , excluded ) for x in X ])
49
52
pred .reshape ((* size , shape , - 1 ))
50
53
return pred
51
54
52
55
56
def sample_posterior(all_trees, X):
    """
    Generate one posterior-predictive sample from the BART-posterior.

    A single posterior draw (one list of trees) is chosen uniformly at random
    from ``all_trees`` and the predictions of all its trees are summed for
    every row of ``X``.

    Parameters
    ----------
    all_trees : list
        List of lists of trees sampled from a posterior; each inner list is
        one posterior draw.
    X : array-like
        A covariate matrix. Use the same used to fit BART for in-sample
        predictions or a new one for out-of-sample predictions. May also be
        an Aesara ``Variable``, in which case it is evaluated first.

    Returns
    -------
    numpy.ndarray
        Summed tree predictions with singleton dimensions squeezed out,
        i.e. shape ``(X.shape[0],)`` for scalar-output trees.
    """
    if isinstance(X, Variable):
        X = X.eval()

    # Pick one posterior draw (a list of trees) uniformly at random.
    idx = np.random.randint(len(all_trees))

    # Probe one prediction to learn the per-observation output size.
    shape = all_trees[0][0].predict(X[0]).size

    # Accumulate the contribution of every tree in the selected draw.
    pred = np.zeros((X.shape[0], shape))
    for tree in all_trees[idx]:
        pred += np.array([tree.predict(x) for x in X])
    return pred.squeeze()
83
+
84
+
53
85
def plot_dependence (
54
- idata ,
86
+ bartrv ,
55
87
X ,
56
88
Y = None ,
57
89
kind = "pdp" ,
@@ -79,8 +111,8 @@ def plot_dependence(
79
111
80
112
Parameters
81
113
----------
82
- idata: InferenceData
83
- InferenceData containing a collection of BART_trees in sample_stats group
114
+ bartrv : BART Random Variable
115
+ BART variable once the model that includes it has been fitted.
84
116
X : array-like
85
117
The covariate matrix.
86
118
Y : array-like
@@ -149,6 +181,9 @@ def plot_dependence(
149
181
150
182
rng = RandomState (seed = random_seed )
151
183
184
+ if isinstance (X , Variable ):
185
+ X = X .eval ()
186
+
152
187
if hasattr (X , "columns" ) and hasattr (X , "values" ):
153
188
x_names = list (X .columns )
154
189
X = X .values
@@ -207,13 +242,13 @@ def plot_dependence(
207
242
for x_i in new_x_i :
208
243
new_X [:, indices_mi ] = X [:, indices_mi ]
209
244
new_X [:, i ] = x_i
210
- y_pred .append (np .mean (predict (idata , rng , X = new_X , size = samples ), 1 ))
245
+ y_pred .append (np .mean (predict (bartrv , rng , X = new_X , size = samples ), 1 ))
211
246
new_x_target .append (new_x_i )
212
247
else :
213
248
for instance in instances :
214
249
new_X = X [idx_s ]
215
250
new_X [:, indices_mi ] = X [:, indices_mi ][instance ]
216
- y_pred .append (np .mean (predict (idata , rng , X = new_X , size = samples ), 0 ))
251
+ y_pred .append (np .mean (predict (bartrv , rng , X = new_X , size = samples ), 0 ))
217
252
new_x_target .append (new_X [:, i ])
218
253
y_mins .append (np .min (y_pred ))
219
254
new_y .append (np .array (y_pred ).T )
@@ -310,7 +345,7 @@ def plot_dependence(
310
345
311
346
312
347
def plot_variable_importance (
313
- idata , X , labels = None , sort_vars = True , figsize = None , samples = 100 , random_seed = None
348
+ idata , bartrv , X , labels = None , sort_vars = True , figsize = None , samples = 100 , random_seed = None
314
349
):
315
350
"""
316
351
Estimates variable importance from the BART-posterior.
@@ -319,6 +354,8 @@ def plot_variable_importance(
319
354
----------
320
355
idata: InferenceData
321
356
InferenceData containing a collection of BART_trees in sample_stats group
357
+ bartrv : BART Random Variable
358
+ BART variable once the model that include it has been fitted.
322
359
X : array-like
323
360
The covariate matrix.
324
361
labels : list
@@ -365,12 +402,12 @@ def plot_variable_importance(
365
402
axes [0 ].set_xlabel ("covariables" )
366
403
axes [0 ].set_ylabel ("importance" )
367
404
368
- predicted_all = predict (idata , rng , X = X , size = samples , excluded = None )
405
+ predicted_all = predict (bartrv , rng , X = X , size = samples , excluded = None )
369
406
370
407
ev_mean = np .zeros (len (var_imp ))
371
408
ev_hdi = np .zeros ((len (var_imp ), 2 ))
372
409
for idx , subset in enumerate (subsets ):
373
- predicted_subset = predict (idata , rng , X = X , size = samples , excluded = subset )
410
+ predicted_subset = predict (bartrv , rng , X = X , size = samples , excluded = subset )
374
411
pearson = np .zeros (samples )
375
412
for j in range (samples ):
376
413
pearson [j ] = (
0 commit comments