@@ -65,10 +65,6 @@ def new_leaf_node(
65
65
linear_params = linear_params ,
66
66
)
67
67
68
- @classmethod
69
- def new_split_node (cls , split_value : npt .NDArray [np .float_ ], idx_split_variable : int ) -> "Node" :
70
- return cls (value = split_value , idx_split_variable = idx_split_variable )
71
-
72
68
def is_split_node (self ) -> bool :
73
69
return self .idx_split_variable >= 0
74
70
@@ -282,42 +278,42 @@ def _traverse_tree(
282
278
"""
283
279
284
280
x_shape = (1 ,) if len (X .shape ) == 1 else X .shape [:- 1 ]
281
+ nd_dims = (...,) + (None ,) * len (x_shape )
285
282
286
- stack = [(0 , np .ones (x_shape ))] # (node_index, weight) initial state
283
+ stack = [(0 , np .ones (x_shape ), 0 )] # (node_index, weight, idx_split_variable ) initial state
287
284
p_d = (
288
285
np .zeros (shape + x_shape ) if isinstance (shape , tuple ) else np .zeros ((shape ,) + x_shape )
289
286
)
290
287
while stack :
291
- node_index , weights = stack .pop ()
288
+ node_index , weights , idx_split_variable = stack .pop ()
292
289
node = self .get_node (node_index )
293
290
if node .is_leaf_node ():
294
291
params = node .linear_params
295
- nd_dims = (...,) + (None ,) * len (x_shape )
296
292
if params is None :
297
293
p_d += weights * node .value [nd_dims ]
298
294
else :
299
- # this produce nonsensical results
300
295
p_d += weights * (
301
- params [0 ][nd_dims ] + params [1 ][nd_dims ] * X [..., node . idx_split_variable ]
296
+ params [0 ][nd_dims ] + params [1 ][nd_dims ] * X [..., idx_split_variable ]
302
297
)
303
- # this produce reasonable result
304
- # p_d += weight * node.value.mean()
305
298
else :
306
299
left_node_index , right_node_index = get_idx_left_child (
307
300
node_index
308
301
), get_idx_right_child (node_index )
302
+ idx_split_variable = node .idx_split_variable
309
303
if excluded is not None and node .idx_split_variable in excluded :
310
304
prop_nvalue_left = self .get_node (left_node_index ).nvalue / node .nvalue
311
- stack .append ((left_node_index , weights * prop_nvalue_left ))
312
- stack .append ((right_node_index , weights * (1 - prop_nvalue_left )))
305
+ stack .append ((left_node_index , weights * prop_nvalue_left , idx_split_variable ))
306
+ stack .append (
307
+ (right_node_index , weights * (1 - prop_nvalue_left ), idx_split_variable )
308
+ )
313
309
else :
314
310
to_left = (
315
- self .split_rules [node . idx_split_variable ]
316
- .divide (X [..., node . idx_split_variable ], node .value )
311
+ self .split_rules [idx_split_variable ]
312
+ .divide (X [..., idx_split_variable ], node .value )
317
313
.astype ("float" )
318
314
)
319
- stack .append ((left_node_index , weights * to_left ))
320
- stack .append ((right_node_index , weights * (1 - to_left )))
315
+ stack .append ((left_node_index , weights * to_left , idx_split_variable ))
316
+ stack .append ((right_node_index , weights * (1 - to_left ), idx_split_variable ))
321
317
322
318
if len (X .shape ) == 1 :
323
319
p_d = p_d [..., 0 ]
0 commit comments