Skip to content

Commit 3fb92c8

Browse files
author
Juan Orduz
committed
some fixes
1 parent a6e6fea commit 3fb92c8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pymc_bart/pgbart.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def draw_leaf_value(
556556
) -> tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
557557
"""Draw Gaussian distributed leaf values."""
558558
linear_params = None
559-
mu_mean = np.empty(shape)
559+
560560
if y_mu_pred.size == 0:
561561
return np.zeros(shape), linear_params
562562

@@ -567,7 +567,7 @@ def draw_leaf_value(
567567
else:
568568
mu_mean, linear_params = fast_linear_fit(x=x_mu, y=y_mu_pred, m=m, norm=norm)
569569

570-
return mu_mean, linear_params
570+
return (mu_mean).astype(np.float64), linear_params
571571

572572

573573
@njit

0 commit comments

Comments
 (0)