Skip to content

Commit 1ff4549

Browse files
author
Juan Orduz
committed
some fixes
1 parent 5b58bba commit 1ff4549

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

pymc_bart/tree.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,19 @@ def _traverse_tree(
309309
)
310310
if excluded is not None and idx_split_variable in excluded:
311311
prop_nvalue_left = self.get_node(left_node_index).nvalue / node.nvalue
312-
stack.append((left_node_index, weights * prop_nvalue_left, idx_split_variable))
313312
stack.append(
314-
(right_node_index, weights * (1 - prop_nvalue_left), idx_split_variable)
313+
(
314+
left_node_index,
315+
(weights * prop_nvalue_left).astype(np.float64),
316+
idx_split_variable,
317+
)
318+
)
319+
stack.append(
320+
(
321+
right_node_index,
322+
(weights * (1 - prop_nvalue_left)).astype(np.float64),
323+
idx_split_variable,
324+
)
315325
)
316326
else:
317327
to_left = (

pymc_bart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
848848
idxs = np.argsort(
849849
idata["sample_stats"]["variable_inclusion"].mean(("chain", "draw")).values
850850
)
851-
subsets = [idxs[:-i].tolist() for i in range(1, len(idxs))]
851+
subsets: list[list[int]] = [list(idxs[:-i]) for i in range(1, len(idxs))]
852852
subsets.append(None) # type: ignore
853853

854854
if method == "backward_VI":

0 commit comments

Comments
 (0)