@@ -138,6 +138,11 @@ def __init__( # noqa: PLR0915
138
138
else :
139
139
self .X = self .bart .X
140
140
141
+ if isinstance (self .bart .Y , Variable ):
142
+ self .Y = self .bart .Y .eval ()
143
+ else :
144
+ self .Y = self .bart .Y
145
+
141
146
self .missing_data = np .any (np .isnan (self .X ))
142
147
self .m = self .bart .m
143
148
self .response = self .bart .response
@@ -166,26 +171,26 @@ def __init__( # noqa: PLR0915
166
171
if rule is ContinuousSplitRule :
167
172
self .X [:, idx ] = jitter_duplicated (self .X [:, idx ], np .nanstd (self .X [:, idx ]))
168
173
169
- init_mean = self .bart . Y .mean ()
174
+ init_mean = self .Y .mean ()
170
175
self .num_observations = self .X .shape [0 ]
171
176
self .num_variates = self .X .shape [1 ]
172
177
self .available_predictors = list (range (self .num_variates ))
173
178
174
179
# if data is binary
175
180
self .leaf_sd = np .ones ((self .trees_shape , self .leaves_shape ))
176
181
177
- y_unique = np .unique (self .bart . Y )
182
+ y_unique = np .unique (self .Y )
178
183
if y_unique .size == 2 and np .all (y_unique == [0 , 1 ]):
179
184
self .leaf_sd *= 3 / self .m ** 0.5
180
185
else :
181
- self .leaf_sd *= self .bart . Y .std () / self .m ** 0.5
186
+ self .leaf_sd *= self .Y .std () / self .m ** 0.5
182
187
183
188
self .running_sd = [
184
189
RunningSd ((self .leaves_shape , self .num_observations )) for _ in range (self .trees_shape )
185
190
]
186
191
187
192
self .sum_trees = np .full (
188
- (self .trees_shape , self .leaves_shape , self .bart . Y .shape [0 ]), init_mean
193
+ (self .trees_shape , self .leaves_shape , self .Y .shape [0 ]), init_mean
189
194
).astype (config .floatX )
190
195
self .sum_trees_noi = self .sum_trees - init_mean
191
196
self .a_tree = Tree .new_tree (
0 commit comments