@@ -187,18 +187,18 @@ def astep(self, _):
187
187
188
188
_ , normalized_weights = self .normalize (particles )
189
189
# Get the new tree and update
190
- new_particle , new_tree = self .get_particle_tree (particles , normalized_weights )
191
- self .all_particles [tree_id ] = new_particle
190
+ self .all_particles [tree_id ], new_tree = self .get_particle_tree (
191
+ particles , normalized_weights
192
+ )
192
193
self .sum_trees = self .sum_trees_noi + new_tree ._predict ()
193
194
self .all_trees [tree_id ] = new_tree .trim ()
194
- used_variates = new_tree .get_split_variables ()
195
195
196
196
if self .tune :
197
197
self .ssv = SampleSplittingVariable (self .alpha_vec )
198
- for index in used_variates :
198
+ for index in new_tree . get_split_variables () :
199
199
self .alpha_vec [index ] += 1
200
200
else :
201
- for index in used_variates :
201
+ for index in new_tree . get_split_variables () :
202
202
variable_inclusion [index ] += 1
203
203
204
204
if not self .tune :
@@ -284,12 +284,9 @@ def init_particles(self, tree_id: int) -> np.ndarray:
284
284
particles = [p0 , p1 ]
285
285
286
286
for _ in self .indices :
287
- pt = ParticleTree (self .a_tree )
288
- if self .tune :
289
- pt .kfactor = self .uniform .random ()
290
- else :
291
- pt .kfactor = p0 .kfactor
292
- particles .append (pt )
287
+ particles .append (
288
+ ParticleTree (self .a_tree , self .uniform .random () if self .tune else p0 .kfactor )
289
+ )
293
290
294
291
return np .array (particles )
295
292
@@ -305,10 +302,10 @@ def update_weight(self, particle, old=False):
305
302
)
306
303
if old :
307
304
particle .log_weight = new_likelihood
308
- particle .old_likelihood_logp = new_likelihood
309
305
else :
310
306
particle .log_weight += new_likelihood - particle .old_likelihood_logp
311
- particle .old_likelihood_logp = new_likelihood
307
+
308
+ particle .old_likelihood_logp = new_likelihood
312
309
313
310
@staticmethod
314
311
def competence (var , has_grad ):
@@ -324,21 +321,19 @@ class ParticleTree:
324
321
325
322
__slots__ = "tree" , "expansion_nodes" , "log_weight" , "old_likelihood_logp" , "kfactor"
326
323
327
- def __init__ (self , tree ):
324
+ def __init__ (self , tree , kfactor = 0.75 ):
328
325
self .tree = tree .copy ()
329
326
self .expansion_nodes = [0 ]
330
327
self .log_weight = 0
331
328
self .old_likelihood_logp = 0
332
- self .kfactor = 0.75
329
+ self .kfactor = kfactor
333
330
334
331
def copy (self ):
335
332
p = ParticleTree (self .tree )
336
- p .expansion_nodes , p .log_weight , p .old_likelihood_logp , p .kfactor = (
337
- self .expansion_nodes .copy (),
338
- self .log_weight ,
339
- self .old_likelihood_logp ,
340
- self .kfactor ,
341
- )
333
+ p .expansion_nodes = self .expansion_nodes .copy ()
334
+ p .log_weight = self .log_weight
335
+ p .old_likelihood_logp = self .old_likelihood_logp
336
+ p .kfactor = self .kfactor
342
337
return p
343
338
344
339
def sample_tree (
@@ -360,7 +355,7 @@ def sample_tree(
360
355
prob_leaf = prior_prob_leaf_node [get_depth (index_leaf_node )]
361
356
362
357
if prob_leaf < np .random .random ():
363
- index_selected_predictor = grow_tree (
358
+ idx_new_nodes = grow_tree (
364
359
self .tree ,
365
360
index_leaf_node ,
366
361
ssv ,
@@ -373,9 +368,8 @@ def sample_tree(
373
368
self .kfactor ,
374
369
shape ,
375
370
)
376
- if index_selected_predictor is not None :
377
- new_indexes = self .tree .idx_leaf_nodes [- 2 :]
378
- self .expansion_nodes .extend (new_indexes )
371
+ if idx_new_nodes is not None :
372
+ self .expansion_nodes .extend (idx_new_nodes )
379
373
tree_grew = True
380
374
381
375
return tree_grew
@@ -389,8 +383,7 @@ def sample_leafs(self, sum_trees, m, normal, shape):
389
383
node_value = draw_leaf_value (
390
384
sum_trees [:, idx_data_points ],
391
385
m ,
392
- normal ,
393
- self .kfactor ,
386
+ normal .random () * self .kfactor ,
394
387
shape ,
395
388
)
396
389
leaf .value = node_value
@@ -463,55 +456,43 @@ def grow_tree(
463
456
split_value = get_split_value (available_splitting_values , idx_data_points , missing_data )
464
457
465
458
if split_value is None :
466
- index_selected_predictor = None
467
- else :
468
- new_idx_data_points = get_new_idx_data_points (
469
- split_value , idx_data_points , selected_predictor , X
470
- )
471
- current_node_children = (
472
- current_node .get_idx_left_child (),
473
- current_node .get_idx_right_child (),
459
+ return None
460
+ new_idx_data_points = get_new_idx_data_points (
461
+ available_splitting_values , split_value , idx_data_points
462
+ )
463
+ current_node_children = (
464
+ current_node .get_idx_left_child (),
465
+ current_node .get_idx_right_child (),
466
+ )
467
+
468
+ new_nodes = []
469
+ for idx in range (2 ):
470
+ idx_data_point = new_idx_data_points [idx ]
471
+ node_value = draw_leaf_value (
472
+ sum_trees [:, idx_data_point ],
473
+ m ,
474
+ normal .random () * kfactor ,
475
+ shape ,
474
476
)
475
477
476
- new_nodes = []
477
- for idx in range (2 ):
478
- idx_data_point = new_idx_data_points [idx ]
479
- node_value = draw_leaf_value (
480
- sum_trees [:, idx_data_point ],
481
- m ,
482
- normal ,
483
- kfactor ,
484
- shape ,
485
- )
486
-
487
- new_node = Node .new_leaf_node (
488
- index = current_node_children [idx ],
489
- value = node_value ,
490
- idx_data_points = idx_data_point ,
491
- )
492
- new_nodes .append (new_node )
493
-
494
- new_split_node = Node .new_split_node (
495
- index = index_leaf_node ,
496
- split_value = split_value ,
497
- idx_split_variable = selected_predictor ,
478
+ new_node = Node .new_leaf_node (
479
+ index = current_node_children [idx ],
480
+ value = node_value ,
481
+ idx_data_points = idx_data_point ,
498
482
)
483
+ new_nodes .append (new_node )
499
484
500
- # update tree nodes and indexes
501
- tree .delete_leaf_node (index_leaf_node )
502
- tree .set_node (index_leaf_node , new_split_node )
503
- tree .set_node (new_nodes [0 ].index , new_nodes [0 ])
504
- tree .set_node (new_nodes [1 ].index , new_nodes [1 ])
505
-
506
- return index_selected_predictor
485
+ tree .grow_leaf_node (current_node , selected_predictor , split_value , index_leaf_node )
486
+ tree .set_node (new_nodes [0 ].index , new_nodes [0 ])
487
+ tree .set_node (new_nodes [1 ].index , new_nodes [1 ])
507
488
489
+ return [new_nodes [0 ].index , new_nodes [1 ].index ]
508
490
509
- def get_new_idx_data_points (split_value , idx_data_points , selected_predictor , X ):
510
- left_idx = X [idx_data_points , selected_predictor ] <= split_value
511
- left_node_idx_data_points = idx_data_points [left_idx ]
512
- right_node_idx_data_points = idx_data_points [~ left_idx ]
513
491
514
- return left_node_idx_data_points , right_node_idx_data_points
492
+ @njit
493
+ def get_new_idx_data_points (available_splitting_values , split_value , idx_data_points ):
494
+ split_idx = available_splitting_values <= split_value
495
+ return idx_data_points [split_idx ], idx_data_points [~ split_idx ]
515
496
516
497
517
498
def get_split_value (available_splitting_values , idx_data_points , missing_data ):
@@ -529,19 +510,18 @@ def get_split_value(available_splitting_values, idx_data_points, missing_data):
529
510
return split_value
530
511
531
512
532
- def draw_leaf_value (y_mu_pred , m , normal , kfactor , shape ):
513
+ @njit
514
+ def draw_leaf_value (y_mu_pred , m , norm , shape ):
533
515
"""Draw Gaussian distributed leaf values."""
534
516
if y_mu_pred .size == 0 :
535
517
return np .zeros (shape )
518
+
519
+ if y_mu_pred .size == 1 :
520
+ mu_mean = np .full (shape , y_mu_pred .item () / m )
536
521
else :
537
- norm = normal .random () * kfactor
538
- if y_mu_pred .size == 1 :
539
- mu_mean = np .full (shape , y_mu_pred .item () / m )
540
- else :
541
- mu_mean = fast_mean (y_mu_pred ) / m
522
+ mu_mean = fast_mean (y_mu_pred ) / m
542
523
543
- draw = norm + mu_mean
544
- return draw
524
+ return norm + mu_mean
545
525
546
526
547
527
@njit
0 commit comments