@@ -134,6 +134,7 @@ def __init__(
         model=None,
         random_seed=None,
         threshold=0.5,
+        compile_kwargs: dict | None = None,
     ):
         """
         Initialize the SMC_kernel class.
@@ -154,6 +155,8 @@ def __init__(
             Determines the change of beta from stage to stage, i.e. indirectly the number of stages,
             the higher the value of `threshold` the higher the number of stages. Defaults to 0.5.
             It should be between 0 and 1.
+        compile_kwargs: dict, optional
+            Keyword arguments passed to pytensor.function
 
         Attributes
         ----------
@@ -172,8 +175,8 @@ def __init__(
         self.model = modelcontext(model)
         self.variables = self.model.value_vars
 
-        self.var_info = {}
-        self.tempered_posterior = None
+        self.var_info: dict[str, tuple] = {}
+        self.tempered_posterior: np.ndarray
         self.prior_logp = None
         self.likelihood_logp = None
         self.tempered_posterior_logp = None
@@ -184,6 +187,7 @@ def __init__(
         self.iteration = 0
         self.resampling_indexes = None
         self.weights = np.ones(self.draws) / self.draws
+        self.compile_kwargs = compile_kwargs if compile_kwargs is not None else {}
 
     def initialize_population(self) -> dict[str, np.ndarray]:
         """Create an initial population from the prior distribution."""
@@ -239,10 +243,10 @@ def _initialize_kernel(self):
         shared = make_shared_replacements(initial_point, self.variables, self.model)
 
         self.prior_logp_func = _logp_forw(
-            initial_point, [self.model.varlogp], self.variables, shared
+            initial_point, [self.model.varlogp], self.variables, shared, self.compile_kwargs
         )
         self.likelihood_logp_func = _logp_forw(
-            initial_point, [self.model.datalogp], self.variables, shared
+            initial_point, [self.model.datalogp], self.variables, shared, self.compile_kwargs
         )
 
         priors = [self.prior_logp_func(sample) for sample in self.tempered_posterior]
@@ -606,7 +610,7 @@ def systematic_resampling(weights, rng):
     return new_indices
 
 
-def _logp_forw(point, out_vars, in_vars, shared):
+def _logp_forw(point, out_vars, in_vars, shared, compile_kwargs=None):
     """Compile PyTensor function of the model and the input and output variables.
 
     Parameters
@@ -617,7 +621,12 @@ def _logp_forw(point, out_vars, in_vars, shared):
         Containing Distribution for the input variables
     shared : list
         Containing TensorVariable for depended shared data
+    compile_kwargs: dict, optional
+        Additional keyword arguments passed to pytensor.function
     """
+    if compile_kwargs is None:
+        compile_kwargs = {}
+
     # Replace integer inputs with rounded float inputs
     if any(var.dtype in discrete_types for var in in_vars):
         replace_int_input = {}
@@ -636,6 +645,6 @@ def _logp_forw(point, out_vars, in_vars, shared):
     out_list, inarray0 = join_nonshared_inputs(
         point=point, outputs=out_vars, inputs=in_vars, shared_inputs=shared
     )
-    f = compile([inarray0], out_list[0])
+    f = compile([inarray0], out_list[0], **compile_kwargs)
     f.trust_input = True
     return f
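Because `compile_kwargs` is unpacked directly into the `compile` call, which is assumed here to delegate to `pytensor.function`, any keyword that `pytensor.function` accepts should be valid. A standalone sketch with a toy graph, separate from the change itself:

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
out = (x**2).sum()

# Mirrors compile([inarray0], out_list[0], **compile_kwargs) above.
compile_kwargs = {"mode": "FAST_COMPILE", "on_unused_input": "ignore"}
f = pytensor.function([x], out, **compile_kwargs)

print(f([1.0, 2.0, 3.0]))  # 14.0
```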