@@ -80,47 +80,43 @@ def __init__(
80
80
else :
81
81
self .X = self .bart .X
82
82
83
- self .Y = self .bart .Y
84
83
self .missing_data = np .any (np .isnan (self .X ))
85
84
self .m = self .bart .m
86
- self .alpha = self .bart .alpha
87
85
shape = initial_values [value_bart .name ].shape
88
86
if len (shape ) == 1 :
89
87
self .shape = 1
90
88
else :
91
89
self .shape = shape [0 ]
92
90
93
- # self.alpha_vec = self.bart.split_prior
94
91
if self .bart .split_prior :
95
92
self .alpha_vec = self .bart .split_prior
96
93
else :
97
94
self .alpha_vec = np .ones (self .X .shape [1 ])
98
- self . init_mean = self .Y .mean ()
95
+ init_mean = self . bart .Y .mean ()
99
96
# if data is binary
100
- y_unique = np .unique (self .Y )
97
+ y_unique = np .unique (self .bart . Y )
101
98
if y_unique .size == 2 and np .all (y_unique == [0 , 1 ]):
102
99
mu_std = 3 / self .m ** 0.5
103
- # maybe we need to check for count data
104
100
else :
105
- mu_std = self .Y .std () / self .m ** 0.5
101
+ mu_std = self .bart . Y .std () / self .m ** 0.5
106
102
107
103
self .num_observations = self .X .shape [0 ]
108
104
self .num_variates = self .X .shape [1 ]
109
105
self .available_predictors = list (range (self .num_variates ))
110
106
111
- self .sum_trees = np .full ((self .shape , self .Y .shape [0 ]), self . init_mean ).astype (
107
+ self .sum_trees = np .full ((self .shape , self .bart . Y .shape [0 ]), init_mean ).astype (
112
108
config .floatX
113
109
)
114
- self .sum_trees_noi = self .sum_trees - (self . init_mean / self .m )
110
+ self .sum_trees_noi = self .sum_trees - (init_mean / self .m )
115
111
self .a_tree = Tree (
116
- leaf_node_value = self . init_mean / self .m ,
112
+ leaf_node_value = init_mean / self .m ,
117
113
idx_data_points = np .arange (self .num_observations , dtype = "int32" ),
118
114
num_observations = self .num_observations ,
119
115
shape = self .shape ,
120
116
)
121
117
self .normal = NormalSampler (mu_std , self .shape )
122
118
self .uniform = UniformSampler (0.33 , 0.75 , self .shape )
123
- self .prior_prob_leaf_node = compute_prior_probability (self .alpha )
119
+ self .prior_prob_leaf_node = compute_prior_probability (self .bart . alpha )
124
120
self .ssv = SampleSplittingVariable (self .alpha_vec )
125
121
126
122
self .tune = True
@@ -143,7 +139,7 @@ def __init__(
143
139
self .likelihood_logp = logp (initial_values , [model .datalogp ], vars , shared )
144
140
self .all_particles = []
145
141
for _ in range (self .m ):
146
- self .a_tree .leaf_node_value = self . init_mean / self .m
142
+ self .a_tree .leaf_node_value = init_mean / self .m
147
143
p = ParticleTree (self .a_tree )
148
144
self .all_particles .append (p )
149
145
self .all_trees = np .array ([p .tree for p in self .all_particles ])
0 commit comments