@@ -325,7 +325,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
325
325
return wei / wei .sum ()
326
326
327
327
def resample (
328
- self , particles : list [ParticleTree ], normalized_weights : npt .NDArray [ np . float64 ]
328
+ self , particles : list [ParticleTree ], normalized_weights : npt .NDArray
329
329
) -> list [ParticleTree ]:
330
330
"""
331
331
Use systematic resample for all but the first particle
@@ -347,7 +347,7 @@ def resample(
347
347
return particles
348
348
349
349
def get_particle_tree (
350
- self , particles : list [ParticleTree ], normalized_weights : npt .NDArray [ np . float64 ]
350
+ self , particles : list [ParticleTree ], normalized_weights : npt .NDArray
351
351
) -> tuple [ParticleTree , Tree ]:
352
352
"""
353
353
Sample a new particle and associated tree
@@ -359,7 +359,7 @@ def get_particle_tree(
359
359
360
360
return new_particle , new_particle .tree
361
361
362
- def systematic (self , normalized_weights : npt .NDArray [ np . float64 ] ) -> npt .NDArray [np .int_ ]:
362
+ def systematic (self , normalized_weights : npt .NDArray ) -> npt .NDArray [np .int_ ]:
363
363
"""
364
364
Systematic resampling.
365
365
@@ -411,7 +411,7 @@ def __init__(self, shape: tuple) -> None:
411
411
self .mean = np .zeros (shape ) # running mean
412
412
self .m_2 = np .zeros (shape ) # running second moment
413
413
414
- def update (self , new_value : npt .NDArray [ np . float64 ] ) -> Union [float , npt .NDArray [ np . float64 ] ]:
414
+ def update (self , new_value : npt .NDArray ) -> Union [float , npt .NDArray ]:
415
415
self .count = self .count + 1
416
416
self .mean , self .m_2 , std = _update (self .count , self .mean , self .m_2 , new_value )
417
417
return fast_mean (std )
@@ -420,21 +420,21 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
420
420
@njit
421
421
def _update (
422
422
count : int ,
423
- mean : npt .NDArray [ np . float64 ] ,
424
- m_2 : npt .NDArray [ np . float64 ] ,
425
- new_value : npt .NDArray [ np . float64 ] ,
426
- ) -> tuple [npt .NDArray [ np . float64 ] , npt .NDArray [ np . float64 ] , Union [float , npt .NDArray [ np . float64 ] ]]:
423
+ mean : npt .NDArray ,
424
+ m_2 : npt .NDArray ,
425
+ new_value : npt .NDArray ,
426
+ ) -> tuple [npt .NDArray , npt .NDArray , Union [float , npt .NDArray ]]:
427
427
delta = new_value - mean
428
428
mean += delta / count
429
429
delta2 = new_value - mean
430
430
m_2 += delta * delta2
431
431
432
432
std = (m_2 / count ) ** 0.5
433
- return mean . astype ( np . float64 ) , m_2 . astype ( np . float64 ) , std . astype ( np . float64 )
433
+ return mean , m_2 , std
434
434
435
435
436
436
class SampleSplittingVariable :
437
- def __init__ (self , alpha_vec : npt .NDArray [ np . float64 ] ) -> None :
437
+ def __init__ (self , alpha_vec : npt .NDArray ) -> None :
438
438
"""
439
439
Sample splitting variables proportional to `alpha_vec`.
440
440
@@ -547,16 +547,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
547
547
548
548
549
549
def draw_leaf_value (
550
- y_mu_pred : npt .NDArray [ np . float64 ] ,
551
- x_mu : npt .NDArray [ np . float64 ] ,
550
+ y_mu_pred : npt .NDArray ,
551
+ x_mu : npt .NDArray ,
552
552
m : int ,
553
- norm : npt .NDArray [ np . float64 ] ,
553
+ norm : npt .NDArray ,
554
554
shape : int ,
555
555
response : str ,
556
- ) -> tuple [npt .NDArray [ np . float64 ] , Optional [npt .NDArray [ np . float64 ] ]]:
556
+ ) -> tuple [npt .NDArray , Optional [npt .NDArray ]]:
557
557
"""Draw Gaussian distributed leaf values."""
558
558
linear_params = None
559
- mu_mean : npt .NDArray [ np . float64 ]
559
+ mu_mean : npt .NDArray
560
560
if y_mu_pred .size == 0 :
561
561
return np .zeros (shape ), linear_params
562
562
@@ -571,7 +571,7 @@ def draw_leaf_value(
571
571
572
572
573
573
@njit
574
- def fast_mean (ari : npt .NDArray [ np . float64 ] ) -> Union [float , npt .NDArray [ np . float64 ] ]:
574
+ def fast_mean (ari : npt .NDArray ) -> Union [float , npt .NDArray ]:
575
575
"""Use Numba to speed up the computation of the mean."""
576
576
if ari .ndim == 1 :
577
577
count = ari .shape [0 ]
@@ -590,11 +590,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float
590
590
591
591
@njit
592
592
def fast_linear_fit (
593
- x : npt .NDArray [ np . float64 ] ,
594
- y : npt .NDArray [ np . float64 ] ,
593
+ x : npt .NDArray ,
594
+ y : npt .NDArray ,
595
595
m : int ,
596
- norm : npt .NDArray [ np . float64 ] ,
597
- ) -> tuple [npt .NDArray [ np . float64 ] , list [npt .NDArray [ np . float64 ] ]]:
596
+ norm : npt .NDArray ,
597
+ ) -> tuple [npt .NDArray , list [npt .NDArray ]]:
598
598
n = len (x )
599
599
y = (y / m + np .expand_dims (norm , axis = 1 )).astype (np .float64 )
600
600
@@ -678,17 +678,17 @@ def update(self):
678
678
679
679
@njit
680
680
def inverse_cdf (
681
- single_uniform : npt .NDArray [ np . float64 ] , normalized_weights : npt .NDArray [ np . float64 ]
681
+ single_uniform : npt .NDArray , normalized_weights : npt .NDArray
682
682
) -> npt .NDArray [np .int_ ]:
683
683
"""
684
684
Inverse CDF algorithm for a finite distribution.
685
685
686
686
Parameters
687
687
----------
688
- single_uniform: npt.NDArray[np.float64]
688
+ single_uniform: npt.NDArray
689
689
Ordered points in [0,1]
690
690
691
- normalized_weights: npt.NDArray[np.float64] )
691
+ normalized_weights: npt.NDArray)
692
692
Normalized weights
693
693
694
694
Returns
@@ -711,7 +711,7 @@ def inverse_cdf(
711
711
712
712
713
713
@njit
714
- def jitter_duplicated (array : npt .NDArray [ np . float64 ] , std : float ) -> npt .NDArray [ np . float64 ] :
714
+ def jitter_duplicated (array : npt .NDArray , std : float ) -> npt .NDArray :
715
715
"""
716
716
Jitter duplicated values.
717
717
"""
@@ -727,7 +727,7 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray
727
727
728
728
729
729
@njit
730
- def are_whole_number (array : npt .NDArray [ np . float64 ] ) -> np .bool_ :
730
+ def are_whole_number (array : npt .NDArray ) -> np .bool_ :
731
731
"""Check if all values in array are whole numbers"""
732
732
return np .all (np .mod (array [~ np .isnan (array )], 1 ) == 0 )
733
733
0 commit comments