Skip to content

Commit 794726f

Browse files
committed
2 parents 950cb70 + 9fbca4b commit 794726f

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,14 @@ ci:
1212

1313
repos:
1414
- repo: https://github.com/astral-sh/ruff-pre-commit
15-
rev: v0.5.4
15+
rev: v0.6.1
1616
hooks:
1717
- id: ruff
1818
args: ["--fix", "--output-format=full"]
1919
- id: ruff-format
2020
args: ["--line-length=100"]
2121
- repo: https://github.com/pre-commit/mirrors-mypy
22-
rev: v1.11.0
22+
rev: v1.11.1
2323
hooks:
2424
- id: mypy
2525
args: [--ignore-missing-imports]

pymc_bart/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
"plot_pdp",
3737
"plot_variable_importance",
3838
]
39-
__version__ = "0.5.14"
39+
__version__ = "0.6.0"
4040

4141

4242
pm.STEP_METHODS = list(pm.STEP_METHODS) + [PGBART]

pymc_bart/pgbart.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,11 @@ def __init__( # noqa: PLR0915
138138
else:
139139
self.X = self.bart.X
140140

141+
if isinstance(self.bart.Y, Variable):
142+
self.Y = self.bart.Y.eval()
143+
else:
144+
self.Y = self.bart.Y
145+
141146
self.missing_data = np.any(np.isnan(self.X))
142147
self.m = self.bart.m
143148
self.response = self.bart.response
@@ -166,26 +171,26 @@ def __init__( # noqa: PLR0915
166171
if rule is ContinuousSplitRule:
167172
self.X[:, idx] = jitter_duplicated(self.X[:, idx], np.nanstd(self.X[:, idx]))
168173

169-
init_mean = self.bart.Y.mean()
174+
init_mean = self.Y.mean()
170175
self.num_observations = self.X.shape[0]
171176
self.num_variates = self.X.shape[1]
172177
self.available_predictors = list(range(self.num_variates))
173178

174179
# if data is binary
175180
self.leaf_sd = np.ones((self.trees_shape, self.leaves_shape))
176181

177-
y_unique = np.unique(self.bart.Y)
182+
y_unique = np.unique(self.Y)
178183
if y_unique.size == 2 and np.all(y_unique == [0, 1]):
179184
self.leaf_sd *= 3 / self.m**0.5
180185
else:
181-
self.leaf_sd *= self.bart.Y.std() / self.m**0.5
186+
self.leaf_sd *= self.Y.std() / self.m**0.5
182187

183188
self.running_sd = [
184189
RunningSd((self.leaves_shape, self.num_observations)) for _ in range(self.trees_shape)
185190
]
186191

187192
self.sum_trees = np.full(
188-
(self.trees_shape, self.leaves_shape, self.bart.Y.shape[0]), init_mean
193+
(self.trees_shape, self.leaves_shape, self.Y.shape[0]), init_mean
189194
).astype(config.floatX)
190195
self.sum_trees_noi = self.sum_trees - init_mean
191196
self.a_tree = Tree.new_tree(

0 commit comments

Comments
 (0)