@@ -145,16 +145,16 @@ def with_data(self, **updates):
145145 user_data = user_data ,
146146 )
147147
148- def _make_sampler (self , settings , cores , progress_type ):
149- model = self ._make_model ()
148+ def _make_sampler (self , settings , init_mean , cores , progress_type ):
149+ model = self ._make_model (init_mean )
150150 return _lib .PySampler .from_pymc (
151151 settings ,
152152 cores ,
153153 model ,
154154 progress_type ,
155155 )
156156
157- def _make_model (self ):
157+ def _make_model (self , init_mean ):
158158 expand_fn = _lib .ExpandFunc (
159159 self .n_dim ,
160160 self .n_expanded ,
@@ -434,7 +434,7 @@ def compile_pymc_model(
434434 gradient_backend : Literal ["pytensor" , "jax" ] = "pytensor" ,
435435 overrides : dict [Union ["Variable" , str ], np .ndarray | float | int ] | None = None ,
436436 jitter_rvs : set ["TensorVariable" ] | None = None ,
437- default_strategy : Literal ["support_point" , "prior" ] = "support_point " ,
437+ default_strategy : Literal ["support_point" , "prior" ] = "prior " ,
438438 ** kwargs ,
439439) -> CompiledModel :
440440 """Compile necessary functions for sampling a pymc model.
@@ -469,6 +469,9 @@ def compile_pymc_model(
469469 "and restart your kernel in case you are in an interactive session."
470470 )
471471
472+ if default_strategy == "support_point" and jitter_rvs is None :
473+ jitter_rvs = set (model .free_RVs )
474+
472475 initial_point_fn = make_initial_point_fn (
473476 model = model ,
474477 overrides = overrides ,
0 commit comments