Skip to content

Commit 0182665

Browse files
Allow Y to be a tensor (#180)
* Allow Y to be a tensor * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update pgbart.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 0e22798 commit 0182665

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

pymc_bart/pgbart.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ def __init__( # noqa: PLR0915
138138
else:
139139
self.X = self.bart.X
140140

141+
if isinstance(self.bart.Y, Variable):
142+
self.Y = self.bart.Y.eval()
143+
else:
144+
self.Y = self.bart.Y
145+
141146
self.missing_data = np.any(np.isnan(self.X))
142147
self.m = self.bart.m
143148
self.response = self.bart.response
@@ -166,26 +171,26 @@ def __init__( # noqa: PLR0915
166171
if rule is ContinuousSplitRule:
167172
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx]))
168173

169-
init_mean = self.bart.Y.mean()
174+
init_mean = self.Y.mean()
170175
self.num_observations = self.X.shape[0]
171176
self.num_variates = self.X.shape[1]
172177
self.available_predictors = list(range(self.num_variates))
173178

174179
# if data is binary
175180
self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))
176181

177-
y_unique = np.unique(self.bart.Y)
182+
y_unique = np.unique(self.Y)
178183
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
179184
self.leaf_sd *= 3 / self.m**0.5
180185
else:
181-
self.leaf_sd *= self.bart.Y.std() / self.m**0.5
186+
self.leaf_sd *= self.Y.std() / self.m**0.5
182187

183188
self.running_sd = [
184189
RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape)
185190
]
186191

187192
self.sum_trees = np.full(
188-
(self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean
193+
(self.trees_shape, self.leaves_shape, self.Y.shape[0]), init_mean
189194
).astype(config.floatX)
190195
self.sum_trees_noi = self.sum_trees - init_mean
191196
self.a_tree = Tree.new_tree(

0 commit comments

Comments
 (0)