Skip to content

Commit 2f0b3aa

Browse files
authored
fix bug with shapes (#208)
1 parent 1ec251b commit 2f0b3aa

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

pymc_bart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
828828

829829
r2_mean = np.zeros(n_vars)
830830
r2_hdi = np.zeros((n_vars, 2))
831-
preds = np.zeros((n_vars, samples, bartrv.eval().shape[0]))
831+
preds = np.zeros((n_vars, samples, *bartrv.eval().T.shape))
832832

833833
if method == "backward_VI":
834834
if fixed >= n_vars:

tests/test_bart.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,4 @@ def test_categorical_model(separate_trees, split_rule):
255255

256256
# Fit should be good enough so right category is selected over 50% of time
257257
assert (idata.predictions.y.median(["chain", "draw"]) == Y).all()
258+
assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3)

0 commit comments

Comments
 (0)