1616
1717import numpy as np
1818import numpy .typing as npt
19+ import pymc as pm
20+ import pytensor .tensor as pt
1921from numba import njit
2022from pymc .initial_point import PointType
2123from pymc .model import Model , modelcontext
@@ -120,15 +122,15 @@ class PGBART(ArrayStepShared):
120122 "tune" : (bool , []),
121123 }
122124
123- def __init__ ( # noqa: PLR0915
125+ def __init__ ( # noqa: PLR0912, PLR0915
124126 self ,
125- vars = None , # pylint: disable=redefined-builtin
127+ vars : list [ pm . Distribution ] | None = None ,
126128 num_particles : int = 10 ,
127129 batch : tuple [float , float ] = (0.1 , 0.1 ),
128130 model : Optional [Model ] = None ,
129131 initial_point : PointType | None = None ,
130- compile_kwargs : dict | None = None , # pylint: disable=unused-argument
131- ):
132+ compile_kwargs : dict | None = None ,
133+ ) -> None :
132134 model = modelcontext (model )
133135 if initial_point is None :
134136 initial_point = model .initial_point ()
@@ -137,6 +139,10 @@ def __init__( # noqa: PLR0915
137139 else :
138140 vars = [model .rvs_to_values .get (var , var ) for var in vars ]
139141 vars = inputvars (vars )
142+
143+ if vars is None :
144+ raise ValueError ("Unable to find variables to sample" )
145+
140146 value_bart = vars [0 ]
141147 self .bart = model .values_to_rvs [value_bart ].owner .op
142148
@@ -325,7 +331,7 @@ def normalize(self, particles: list[ParticleTree]) -> float:
325331 return wei / wei .sum ()
326332
327333 def resample (
328- self , particles : list [ParticleTree ], normalized_weights : npt .NDArray [ np . float64 ]
334+ self , particles : list [ParticleTree ], normalized_weights : npt .NDArray
329335 ) -> list [ParticleTree ]:
330336 """
331337 Use systematic resample for all but the first particle
@@ -347,7 +353,7 @@ def resample(
347353 return particles
348354
349355 def get_particle_tree (
350- self , particles : list [ParticleTree ], normalized_weights : npt .NDArray [ np . float64 ]
356+ self , particles : list [ParticleTree ], normalized_weights : npt .NDArray
351357 ) -> tuple [ParticleTree , Tree ]:
352358 """
353359 Sample a new particle and associated tree
@@ -359,7 +365,7 @@ def get_particle_tree(
359365
360366 return new_particle , new_particle .tree
361367
362- def systematic (self , normalized_weights : npt .NDArray [ np . float64 ] ) -> npt .NDArray [np .int_ ]:
368+ def systematic (self , normalized_weights : npt .NDArray ) -> npt .NDArray [np .int_ ]:
363369 """
364370 Systematic resampling.
365371
@@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
395401 particle .log_weight = new_likelihood
396402
397403 @staticmethod
398- def competence (var , has_grad ) :
404+ def competence (var : pm . Distribution , has_grad : bool ) -> Competence :
399405 """PGBART is only suitable for BART distributions."""
400406 dist = getattr (var .owner , "op" , None )
401407 if isinstance (dist , BARTRV ):
@@ -406,12 +412,12 @@ def competence(var, has_grad):
406412class RunningSd :
407413 """Welford's online algorithm for computing the variance/standard deviation"""
408414
409- def __init__ (self , shape : tuple ) -> None :
415+ def __init__ (self , shape : tuple [ int , ...] ) -> None :
410416 self .count = 0 # number of data points
411417 self .mean = np .zeros (shape ) # running mean
412418 self .m_2 = np .zeros (shape ) # running second moment
413419
414- def update (self , new_value : npt .NDArray [ np . float64 ] ) -> Union [float , npt .NDArray [ np . float64 ] ]:
420+ def update (self , new_value : npt .NDArray ) -> Union [float , npt .NDArray ]:
415421 self .count = self .count + 1
416422 self .mean , self .m_2 , std = _update (self .count , self .mean , self .m_2 , new_value )
417423 return fast_mean (std )
@@ -420,10 +426,10 @@ def update(self, new_value: npt.NDArray[np.float64]) -> Union[float, npt.NDArray
420426@njit
421427def _update (
422428 count : int ,
423- mean : npt .NDArray [ np . float64 ] ,
424- m_2 : npt .NDArray [ np . float64 ] ,
425- new_value : npt .NDArray [ np . float64 ] ,
426- ) -> tuple [npt .NDArray [ np . float64 ] , npt .NDArray [ np . float64 ] , Union [float , npt .NDArray [ np . float64 ] ]]:
429+ mean : npt .NDArray ,
430+ m_2 : npt .NDArray ,
431+ new_value : npt .NDArray ,
432+ ) -> tuple [npt .NDArray , npt .NDArray , Union [float , npt .NDArray ]]:
427433 delta = new_value - mean
428434 mean += delta / count
429435 delta2 = new_value - mean
@@ -434,7 +440,7 @@ def _update(
434440
435441
436442class SampleSplittingVariable :
437- def __init__ (self , alpha_vec : npt .NDArray [ np . float64 ] ) -> None :
443+ def __init__ (self , alpha_vec : npt .NDArray ) -> None :
438444 """
439445 Sample splitting variables proportional to `alpha_vec`.
440446
@@ -547,16 +553,16 @@ def filter_missing_values(available_splitting_values, idx_data_points, missing_d
547553
548554
549555def draw_leaf_value (
550- y_mu_pred : npt .NDArray [ np . float64 ] ,
551- x_mu : npt .NDArray [ np . float64 ] ,
556+ y_mu_pred : npt .NDArray ,
557+ x_mu : npt .NDArray ,
552558 m : int ,
553- norm : npt .NDArray [ np . float64 ] ,
559+ norm : npt .NDArray ,
554560 shape : int ,
555561 response : str ,
556- ) -> tuple [npt .NDArray [ np . float64 ] , Optional [npt .NDArray [ np . float64 ] ]]:
562+ ) -> tuple [npt .NDArray , Optional [npt .NDArray ]]:
557563 """Draw Gaussian distributed leaf values."""
558564 linear_params = None
559- mu_mean = np . empty ( shape )
565+ mu_mean : npt . NDArray
560566 if y_mu_pred .size == 0 :
561567 return np .zeros (shape ), linear_params
562568
@@ -571,7 +577,7 @@ def draw_leaf_value(
571577
572578
573579@njit
574- def fast_mean (ari : npt .NDArray [ np . float64 ] ) -> Union [float , npt .NDArray [ np . float64 ] ]:
580+ def fast_mean (ari : npt .NDArray ) -> Union [float , npt .NDArray ]:
575581 """Use Numba to speed up the computation of the mean."""
576582 if ari .ndim == 1 :
577583 count = ari .shape [0 ]
@@ -590,11 +596,11 @@ def fast_mean(ari: npt.NDArray[np.float64]) -> Union[float, npt.NDArray[np.float
590596
591597@njit
592598def fast_linear_fit (
593- x : npt .NDArray [ np . float64 ] ,
594- y : npt .NDArray [ np . float64 ] ,
599+ x : npt .NDArray ,
600+ y : npt .NDArray ,
595601 m : int ,
596- norm : npt .NDArray [ np . float64 ] ,
597- ) -> tuple [npt .NDArray [ np . float64 ] , list [npt .NDArray [ np . float64 ] ]]:
602+ norm : npt .NDArray ,
603+ ) -> tuple [npt .NDArray , list [npt .NDArray ]]:
598604 n = len (x )
599605 y = y / m + np .expand_dims (norm , axis = 1 )
600606
@@ -678,17 +684,17 @@ def update(self):
678684
679685@njit
680686def inverse_cdf (
681- single_uniform : npt .NDArray [ np . float64 ] , normalized_weights : npt .NDArray [ np . float64 ]
687+ single_uniform : npt .NDArray , normalized_weights : npt .NDArray
682688) -> npt .NDArray [np .int_ ]:
683689 """
684690 Inverse CDF algorithm for a finite distribution.
685691
686692 Parameters
687693 ----------
688- single_uniform: npt.NDArray[np.float64]
694+ single_uniform: npt.NDArray
689695 Ordered points in [0,1]
690696
691- normalized_weights: npt.NDArray[np.float64] )
697+ normalized_weights: npt.NDArray)
692698 Normalized weights
693699
694700 Returns
@@ -711,7 +717,7 @@ def inverse_cdf(
711717
712718
713719@njit
714- def jitter_duplicated (array : npt .NDArray [ np . float64 ] , std : float ) -> npt .NDArray [ np . float64 ] :
720+ def jitter_duplicated (array : npt .NDArray , std : float ) -> npt .NDArray :
715721 """
716722 Jitter duplicated values.
717723 """
@@ -727,12 +733,17 @@ def jitter_duplicated(array: npt.NDArray[np.float64], std: float) -> npt.NDArray
727733
728734
729735@njit
730- def are_whole_number (array : npt .NDArray [ np . float64 ] ) -> np .bool_ :
736+ def are_whole_number (array : npt .NDArray ) -> np .bool_ :
731737 """Check if all values in array are whole numbers"""
732738 return np .all (np .mod (array [~ np .isnan (array )], 1 ) == 0 )
733739
734740
735- def logp (point , out_vars , vars , shared ): # pylint: disable=redefined-builtin
741+ def logp (
742+ point ,
743+ out_vars : list [pm .Distribution ],
744+ vars : list [pm .Distribution ],
745+ shared : list [pt .TensorVariable ],
746+ ):
736747 """Compile PyTensor function of the model and the input and output variables.
737748
738749 Parameters
0 commit comments