1717import numpy as np
1818import numpy .typing as npt
1919from numba import njit
20+ from pymc .initial_point import PointType
2021from pymc .model import Model , modelcontext
2122from pymc .pytensorf import inputvars , join_nonshared_inputs , make_shared_replacements
2223from pymc .step_methods .arraystep import ArrayStepShared
@@ -125,9 +126,12 @@ def __init__( # noqa: PLR0915
125126 num_particles : int = 10 ,
126127 batch : tuple [float , float ] = (0.1 , 0.1 ),
127128 model : Optional [Model ] = None ,
129+ initial_point : PointType | None = None ,
130+ compile_kwargs : dict | None = None , # pylint: disable=unused-argument
128131 ):
129132 model = modelcontext (model )
130- initial_values = model .initial_point ()
133+ if initial_point is None :
134+ initial_point = model .initial_point ()
131135 if vars is None :
132136 vars = model .value_vars
133137 else :
@@ -150,7 +154,7 @@ def __init__( # noqa: PLR0915
150154 self .m = self .bart .m
151155 self .response = self .bart .response
152156
153- shape = initial_values [value_bart .name ].shape
157+ shape = initial_point [value_bart .name ].shape
154158
155159 self .shape = 1 if len (shape ) == 1 else shape [0 ]
156160
@@ -217,8 +221,8 @@ def __init__( # noqa: PLR0915
217221
218222 self .num_particles = num_particles
219223 self .indices = list (range (1 , num_particles ))
220- shared = make_shared_replacements (initial_values , vars , model )
221- self .likelihood_logp = logp (initial_values , [model .datalogp ], vars , shared )
224+ shared = make_shared_replacements (initial_point , vars , model )
225+ self .likelihood_logp = logp (initial_point , [model .datalogp ], vars , shared )
222226 self .all_particles = [
223227 [ParticleTree (self .a_tree ) for _ in range (self .m )] for _ in range (self .trees_shape )
224228 ]
0 commit comments