diff --git a/pymc_bart/utils.py b/pymc_bart/utils.py index d9738dd..d9d5241 100644 --- a/pymc_bart/utils.py +++ b/pymc_bart/utils.py @@ -828,7 +828,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912 r2_mean = np.zeros(n_vars) r2_hdi = np.zeros((n_vars, 2)) - preds = np.zeros((n_vars, samples, bartrv.eval().shape[0])) + preds = np.zeros((n_vars, samples, *bartrv.eval().T.shape)) if method == "backward_VI": if fixed >= n_vars: diff --git a/tests/test_bart.py b/tests/test_bart.py index a003363..226d938 100644 --- a/tests/test_bart.py +++ b/tests/test_bart.py @@ -255,3 +255,4 @@ def test_categorical_model(separate_trees, split_rule): # Fit should be good enough so right category is selected over 50% of time assert (idata.predictions.y.median(["chain", "draw"]) == Y).all() + assert pmb.compute_variable_importance(idata, bartrv=lo, X=X)["preds"].shape == (5, 50, 9, 3)