@@ -16,6 +16,8 @@

 import numpy as np
 import numpy.typing as npt
+import pymc as pm
+import pytensor as pt
 from numba import njit
 from pymc.initial_point import PointType
 from pymc.model import Model, modelcontext
@@ -120,15 +122,15 @@ class PGBART(ArrayStepShared):
         "tune": (bool, []),
     }

-    def __init__(  # noqa: PLR0915
+    def __init__(  # noqa: PLR0912, PLR0915
         self,
-        vars=None,  # pylint: disable=redefined-builtin
+        vars: list[pm.Distribution] | None = None,
         num_particles: int = 10,
         batch: tuple[float, float] = (0.1, 0.1),
         model: Optional[Model] = None,
         initial_point: PointType | None = None,
-        compile_kwargs: dict | None = None,  # pylint: disable=unused-argument
-    ):
+        compile_kwargs: dict | None = None,
+    ) -> None:
         model = modelcontext(model)
         if initial_point is None:
             initial_point = model.initial_point()
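
For orientation, the typed signature above matches how the sampler is constructed in user code. A minimal usage sketch, assuming pymc-bart's public API (`pmb.BART`, `pmb.PGBART`) and illustrative data shapes:

import numpy as np
import pymc as pm
import pymc_bart as pmb

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
Y = X[:, 0] + rng.normal(scale=0.1, size=100)

with pm.Model():
    mu = pmb.BART("mu", X, Y, m=20)            # BART prior over the mean
    pm.Normal("y", mu=mu, sigma=0.1, observed=Y)
    step = pmb.PGBART([mu], num_particles=10)  # vars: list[...] | None in the new signature
    idata = pm.sample(step=[step], tune=100, draws=100)

Passing vars=None instead lets PyMC pick the BART variables automatically via the competence hook shown further down.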
@@ -137,6 +139,10 @@ def __init__(  # noqa: PLR0915
         else:
             vars = [model.rvs_to_values.get(var, var) for var in vars]
             vars = inputvars(vars)
+
+        if vars is None:
+            raise ValueError("Unable to find variables to sample")
+
         value_bart = vars[0]
         self.bart = model.values_to_rvs[value_bart].owner.op

@@ -395,7 +401,7 @@ def update_weight(self, particle: ParticleTree, odim: int) -> None:
         particle.log_weight = new_likelihood

     @staticmethod
-    def competence(var, has_grad):
+    def competence(var: pm.Distribution, has_grad: bool) -> Competence:
         """PGBART is only suitable for BART distributions."""
         dist = getattr(var.owner, "op", None)
         if isinstance(dist, BARTRV):
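
competence is the hook PyMC's step-assignment logic calls for each free variable; PGBART reports ideal competence for BART variables and incompatible otherwise. A simplified sketch of the selection, assuming the Competence enum from pymc.step_methods.compound (pick_step_class is a hypothetical helper, not PyMC API; the real assignment also handles ties and gradient availability):

from pymc.step_methods.compound import Competence

def pick_step_class(var, has_grad: bool, candidates: list) -> type:
    # Choose the step-method class reporting the highest competence for var.
    return max(candidates, key=lambda cls: cls.competence(var, has_grad).value)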
@@ -406,7 +412,7 @@ def competence(var, has_grad):
 class RunningSd:
     """Welford's online algorithm for computing the variance/standard deviation"""

-    def __init__(self, shape: tuple) -> None:
+    def __init__(self, shape: tuple[int, ...]) -> None:
         self.count = 0  # number of data points
         self.mean = np.zeros(shape)  # running mean
         self.m_2 = np.zeros(shape)  # running second moment
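
The class implements Welford's single-pass update: each new observation shifts the running mean and accumulates squared deviations in m_2, so the variance is available at any point without storing the data. A minimal sketch of that update, assuming an update method that returns the current standard deviation (attribute names follow __init__ above):

import numpy as np

class RunningSdSketch:
    def __init__(self, shape: tuple[int, ...]) -> None:
        self.count = 0
        self.mean = np.zeros(shape)
        self.m_2 = np.zeros(shape)

    def update(self, new_value: np.ndarray) -> np.ndarray:
        self.count += 1
        delta = new_value - self.mean
        self.mean += delta / self.count        # shift the running mean
        delta2 = new_value - self.mean
        self.m_2 += delta * delta2             # accumulate squared deviations
        return np.sqrt(self.m_2 / self.count)  # population standard deviation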
@@ -561,7 +567,7 @@ def draw_leaf_value(
         return np.zeros(shape), linear_params

     if y_mu_pred.size == 1:
-        mu_mean = (np.full(shape, y_mu_pred.item() / m) + norm).astype(np.float64)
+        mu_mean = np.full(shape, y_mu_pred.item() / m) + norm
     elif y_mu_pred.size < 3 or response == "constant":
         mu_mean = fast_mean(y_mu_pred) / m + norm
     else:
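
The size-1 branch spreads the single predicted value across all leaf dimensions, scaled down by the number of trees m, then adds Gaussian noise. A tiny numeric illustration with made-up values, which also shows why the dropped .astype(np.float64) was redundant — the arithmetic already yields float64:

import numpy as np

y_mu_pred = np.array([6.0])  # single predicted value
m, shape = 3, 2              # 3 trees, 2 leaf dimensions (illustrative)
norm = np.array([0.05, -0.02])

mu_mean = np.full(shape, y_mu_pred.item() / m) + norm
print(mu_mean, mu_mean.dtype)  # [2.05 1.98] float64 -- no explicit cast needed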
@@ -585,7 +591,7 @@ def fast_mean(ari: npt.NDArray) -> Union[float, npt.NDArray]:
         for j in range(ari.shape[0]):
             for i in range(count):
                 res[j] += ari[j, i]
-        return (res / count).astype(np.float64)
+        return res / count


 @njit
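
The same reasoning justifies removing the cast here: res starts life as np.zeros(...), which is float64, and dividing a float64 array by an integer count keeps the dtype. A quick check:

import numpy as np

res = np.zeros(4)
res += np.arange(4)                   # accumulation mirrors the loop above
assert (res / 4).dtype == np.float64  # already float64, no .astype needed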
@@ -596,7 +602,7 @@ def fast_linear_fit(
     norm: npt.NDArray,
 ) -> tuple[npt.NDArray, list[npt.NDArray]]:
     n = len(x)
-    y = (y / m + np.expand_dims(norm, axis=1)).astype(np.float64)
+    y = y / m + np.expand_dims(norm, axis=1)

     xbar = np.sum(x) / n
     ybar = np.sum(y, axis=1) / n
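
From xbar and ybar the function can proceed to the closed-form least-squares line. A hedged sketch of that textbook computation (the actual fast_linear_fit handles batched y and returns fitted values plus parameters, which this sketch omits):

import numpy as np

def linear_fit_sketch(x: np.ndarray, y: np.ndarray) -> tuple[float, float]:
    # Closed-form simple linear regression: intercept and slope.
    n = len(x)
    xbar = np.sum(x) / n
    ybar = np.sum(y) / n
    slope = np.sum((x - xbar) * (y - ybar)) / np.sum((x - xbar) ** 2)
    intercept = ybar - slope * xbar
    return intercept, slope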
@@ -732,7 +738,9 @@ def are_whole_number(array: npt.NDArray) -> np.bool_:
     return np.all(np.mod(array[~np.isnan(array)], 1) == 0)


-def logp(point, out_vars, vars, shared):  # pylint: disable=redefined-builtin
+def logp(
+    point, out_vars: list[pm.Distribution], vars: list[pm.Distribution], shared: list[pt.Tensor]
+):
     """Compile PyTensor function of the model and the input and output variables.

     Parameters
Parameters
0 commit comments