Skip to content

Commit 2caed89

Browse files
author
Juan Orduz
committed
some fixes
1 parent 3fb92c8 commit 2caed89

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

pymc_bart/pgbart.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -556,18 +556,18 @@ def draw_leaf_value(
556556
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
557557
"""Draw Gaussian distributed leaf values."""
558558
linear_params = None
559-
559+
mu_mean: npt.NDArray[np.float64] = np.empty(shape)
560560
if y_mu_pred.size == 0:
561561
return np.zeros(shape), linear_params
562562

563563
if y_mu_pred.size == 1:
564-
mu_mean = np.full(shape, y_mu_pred.item() / m) + norm
564+
mu_mean = (np.full(shape, y_mu_pred.item() / m) + norm).astype(np.float64)
565565
elif y_mu_pred.size < 3 or response == "constant":
566566
mu_mean = fast_mean(y_mu_pred) / m + norm
567567
else:
568568
mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m, norm=norm)
569569

570-
return (mu_mean).astype(np.float64), linear_params
570+
return mu_mean, linear_params
571571

572572

573573
@njit

pymc_bart/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -826,9 +826,9 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
826826
else:
827827
labels = np.arange(n_vars).astype(str)
828828

829-
r2_mean = np.zeros(n_vars)
830-
r2_hdi = np.zeros((n_vars, 2))
831-
preds = np.zeros((n_vars, samples, *bartrv.eval().T.shape))
829+
r2_mean: npt.NDArray[np.float64] = np.zeros(n_vars)
830+
r2_hdi: npt.NDArray[np.float64] = np.zeros((n_vars, 2))
831+
preds: npt.NDArray[np.float64] = np.zeros((n_vars, samples, *bartrv.eval().T.shape))
832832

833833
if method == "backward_VI":
834834
if fixed >= n_vars:

0 commit comments

Comments
 (0)