Skip to content

Commit 762401c

Browse files
Change initialization defaults to include randomness
1 parent 9f936b5 commit 762401c

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

python/nutpie/compile_pymc.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,16 @@ def with_data(self, **updates):
145145
user_data=user_data,
146146
)
147147

148-
def _make_sampler(self, settings, cores, progress_type):
149-
model = self._make_model()
148+
def _make_sampler(self, settings, init_mean, cores, progress_type):
149+
model = self._make_model(init_mean)
150150
return _lib.PySampler.from_pymc(
151151
settings,
152152
cores,
153153
model,
154154
progress_type,
155155
)
156156

157-
def _make_model(self):
157+
def _make_model(self, init_mean):
158158
expand_fn = _lib.ExpandFunc(
159159
self.n_dim,
160160
self.n_expanded,
@@ -434,7 +434,7 @@ def compile_pymc_model(
434434
gradient_backend: Literal["pytensor", "jax"] = "pytensor",
435435
overrides: dict[Union["Variable", str], np.ndarray | float | int] | None = None,
436436
jitter_rvs: set["TensorVariable"] | None = None,
437-
default_strategy: Literal["support_point", "prior"] = "support_point",
437+
default_strategy: Literal["support_point", "prior"] = "prior",
438438
**kwargs,
439439
) -> CompiledModel:
440440
"""Compile necessary functions for sampling a pymc model.
@@ -469,6 +469,9 @@ def compile_pymc_model(
469469
"and restart your kernel in case you are in an interactive session."
470470
)
471471

472+
if default_strategy == "support_point" and jitter_rvs is None:
473+
jitter_rvs = set(model.free_RVs)
474+
472475
initial_point_fn = make_initial_point_fn(
473476
model=model,
474477
overrides=overrides,

0 commit comments

Comments
 (0)