File tree Expand file tree Collapse file tree 2 files changed +13
-3
lines changed Expand file tree Collapse file tree 2 files changed +13
-3
lines changed Original file line number Diff line number Diff line change @@ -309,9 +309,19 @@ def _traverse_tree(
309
309
)
310
310
if excluded is not None and idx_split_variable in excluded :
311
311
prop_nvalue_left = self .get_node (left_node_index ).nvalue / node .nvalue
312
- stack .append ((left_node_index , weights * prop_nvalue_left , idx_split_variable ))
313
312
stack .append (
314
- (right_node_index , weights * (1 - prop_nvalue_left ), idx_split_variable )
313
+ (
314
+ left_node_index ,
315
+ (weights * prop_nvalue_left ).astype (np .float64 ),
316
+ idx_split_variable ,
317
+ )
318
+ )
319
+ stack .append (
320
+ (
321
+ right_node_index ,
322
+ (weights * (1 - prop_nvalue_left )).astype (np .float64 ),
323
+ idx_split_variable ,
324
+ )
315
325
)
316
326
else :
317
327
to_left = (
Original file line number Diff line number Diff line change @@ -848,7 +848,7 @@ def compute_variable_importance( # noqa: PLR0915 PLR0912
848
848
idxs = np .argsort (
849
849
idata ["sample_stats" ]["variable_inclusion" ].mean (("chain" , "draw" )).values
850
850
)
851
- subsets = [idxs [:- i ]. tolist ( ) for i in range (1 , len (idxs ))]
851
+ subsets : list [ list [ int ]] = [list ( idxs [:- i ]) for i in range (1 , len (idxs ))]
852
852
subsets .append (None ) # type: ignore
853
853
854
854
if method == "backward_VI" :
You can’t perform that action at this time.
0 commit comments