@@ -313,7 +313,7 @@ def normalize(self, particles: List[ParticleTree]) -> float:
         return wei / wei.sum()
 
     def resample(
-        self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_]
+        self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
     ) -> List[ParticleTree]:
         """
         Use systematic resample for all but the first particle
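The only functional change throughout these hunks is in the type annotations: np.float_ is replaced by np.float64, presumably to stay compatible with NumPy 2.0, where the np.float_ alias was removed. A minimal illustration (the variable name is ours, not from the diff):

import numpy as np
import numpy.typing as npt

# np.float_ was an alias of np.float64; NumPy 2.0 removed it, so the old
# annotation raises AttributeError there, while np.float64 works on 1.x and 2.x.
weights: npt.NDArray[np.float64] = np.asarray([0.25, 0.75], dtype=np.float64)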
@@ -335,7 +335,7 @@ def resample(
         return particles
 
     def get_particle_tree(
-        self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float_]
+        self, particles: List[ParticleTree], normalized_weights: npt.NDArray[np.float64]
     ) -> Tuple[ParticleTree, Tree]:
         """
         Sample a new particle and associated tree
@@ -347,7 +347,7 @@ def get_particle_tree(
 
         return new_particle, new_particle.tree
 
-    def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[np.int_]:
+    def systematic(self, normalized_weights: npt.NDArray[np.float64]) -> npt.NDArray[np.int_]:
         """
         Systematic resampling.
 
@@ -399,7 +399,7 @@ def __init__(self, shape: tuple) -> None:
         self.mean = np.zeros(shape)  # running mean
         self.m_2 = np.zeros(shape)  # running second moment
 
-    def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
+    def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
         self.count = self.count + 1
         self.mean, self.m_2, std = _update(self.count, self.mean, self.m_2, new_value)
         return fast_mean(std)
@@ -408,10 +408,10 @@ def update(self, new_value: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[
 @njit
 def _update(
     count: int,
-    mean: npt.NDArray[np.float_],
-    m_2: npt.NDArray[np.float_],
-    new_value: npt.NDArray[np.float_],
-) -> Tuple[npt.NDArray[np.float_], npt.NDArray[np.float_], Union[float, npt.NDArray[np.float_]]]:
+    mean: npt.NDArray[np.float64],
+    m_2: npt.NDArray[np.float64],
+    new_value: npt.NDArray[np.float64],
+) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64], Union[float, npt.NDArray[np.float64]]]:
     delta = new_value - mean
     mean += delta / count
     delta2 = new_value - mean
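The visible body of _update (delta, mean += delta / count, delta2) is the start of Welford's online mean/variance recurrence. A self-contained sketch of one step of that recurrence in plain NumPy, with the m_2 and std lines assumed from context rather than shown in this hunk:

import numpy as np

def welford_step(count, mean, m_2, new_value):
    # One Welford update: running mean, running sum of squared deviations,
    # and the standard deviation implied by the current count.
    delta = new_value - mean
    mean = mean + delta / count
    delta2 = new_value - mean
    m_2 = m_2 + delta * delta2
    return mean, m_2, np.sqrt(m_2 / count)

# Feed values one at a time, starting from zero-initialized state.
mean = m_2 = np.zeros(1)
for count, x in enumerate([1.0, 2.0, 4.0], start=1):
    mean, m_2, std = welford_step(count, mean, m_2, np.array([x]))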
@@ -422,7 +422,7 @@ def _update(
 
 
 class SampleSplittingVariable:
-    def __init__(self, alpha_vec: npt.NDArray[np.float_]) -> None:
+    def __init__(self, alpha_vec: npt.NDArray[np.float64]) -> None:
         """
         Sample splitting variables proportional to `alpha_vec`.
 
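The docstring says splitting variables are sampled proportionally to `alpha_vec`. A hypothetical stand-alone sketch of that kind of draw (np.random.choice is our shorthand here, not the class's actual mechanism):

import numpy as np

alpha_vec = np.array([1.0, 2.0, 1.0], dtype=np.float64)
probs = alpha_vec / alpha_vec.sum()                 # normalize to a distribution
split_var = np.random.choice(len(probs), p=probs)   # index drawn proportionally to alpha_vec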
@@ -535,13 +535,13 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
 
 
 def draw_leaf_value(
-    y_mu_pred: npt.NDArray[np.float_],
-    x_mu: npt.NDArray[np.float_],
+    y_mu_pred: npt.NDArray[np.float64],
+    x_mu: npt.NDArray[np.float64],
     m: int,
-    norm: npt.NDArray[np.float_],
+    norm: npt.NDArray[np.float64],
     shape: int,
     response: str,
-) -> Tuple[npt.NDArray[np.float_], Optional[npt.NDArray[np.float_]]]:
+) -> Tuple[npt.NDArray[np.float64], Optional[npt.NDArray[np.float64]]]:
     """Draw Gaussian distributed leaf values."""
     linear_params = None
     mu_mean = np.empty(shape)
@@ -559,7 +559,7 @@ def draw_leaf_value(
 
 
 @njit
-def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_]]:
+def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float64]]:
     """Use Numba to speed up the computation of the mean."""
     if ari.ndim == 1:
         count = ari.shape[0]
@@ -578,11 +578,11 @@ def fast_mean(ari: npt.NDArray[np.float_]) -> Union[float, npt.NDArray[np.float_
 
 @njit
 def fast_linear_fit(
-    x: npt.NDArray[np.float_],
-    y: npt.NDArray[np.float_],
+    x: npt.NDArray[np.float64],
+    y: npt.NDArray[np.float64],
     m: int,
-    norm: npt.NDArray[np.float_],
-) -> Tuple[npt.NDArray[np.float_], List[npt.NDArray[np.float_]]]:
+    norm: npt.NDArray[np.float64],
+) -> Tuple[npt.NDArray[np.float64], List[npt.NDArray[np.float64]]]:
     n = len(x)
     y = y / m + np.expand_dims(norm, axis=1)
 
@@ -666,17 +666,17 @@ def update(self):
 
 @njit
 def inverse_cdf(
-    single_uniform: npt.NDArray[np.float_], normalized_weights: npt.NDArray[np.float_]
+    single_uniform: npt.NDArray[np.float64], normalized_weights: npt.NDArray[np.float64]
 ) -> npt.NDArray[np.int_]:
     """
     Inverse CDF algorithm for a finite distribution.
 
     Parameters
     ----------
-    single_uniform: npt.NDArray[np.float_]
+    single_uniform: npt.NDArray[np.float64]
         Ordered points in [0,1]
 
-    normalized_weights: npt.NDArray[np.float_])
+    normalized_weights: npt.NDArray[np.float64])
         Normalized weights
 
     Returns
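Together with the systematic() method changed above, inverse_cdf implements systematic resampling: evenly spaced points with one shared random offset are pushed through the inverse of the cumulative weight distribution. A plain-NumPy sketch of the same idea (np.searchsorted stands in for the njit-compiled loop, whose body is not shown in this hunk):

import numpy as np

def systematic_resample_sketch(normalized_weights, rng=None):
    # One uniform offset shared by n evenly spaced points, then an
    # inverse-CDF lookup against the cumulative weights.
    if rng is None:
        rng = np.random.default_rng()
    n = len(normalized_weights)
    single_uniform = (rng.uniform() + np.arange(n)) / n
    cumulative = np.cumsum(normalized_weights)
    return np.searchsorted(cumulative, single_uniform)

idx = systematic_resample_sketch(np.array([0.1, 0.2, 0.3, 0.4]))  # indices in 0..3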
@@ -699,7 +699,7 @@ def inverse_cdf(
 
 
 @njit
-def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[np.float_]:
+def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray[np.float64]:
     """
     Jitter duplicated values.
     """
@@ -715,7 +715,7 @@ def jitter_duplicated(array: npt.NDArray[np.float_], std: float) -> npt.NDArray[
 
 
 @njit
-def are_whole_number(array: npt.NDArray[np.float_]) -> np.bool_:
+def are_whole_number(array: npt.NDArray[np.float64]) -> np.bool_:
     """Check if all values in array are whole numbers"""
     return np.all(np.mod(array[~np.isnan(array)], 1) == 0)
 