Skip to content

Commit 4430607

Browse files
aseyboldtlucianopaz
authored andcommitted
fix: add lock for pymc init point func
1 parent f4afe2f commit 4430607

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

python/nutpie/compile_pymc.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dataclasses
22
import itertools
3+
import threading
34
import warnings
45
from collections.abc import Iterable
56
from dataclasses import dataclass
@@ -31,7 +32,7 @@ def intrinsic(f):
3132
from pytensor.tensor import TensorVariable, Variable
3233

3334

34-
def rv_dict_to_flat_array_wrapper(
35+
def _rv_dict_to_flat_array_wrapper(
3536
fn: Callable[[SeedType], dict[str, np.ndarray]],
3637
names: list[str],
3738
shapes: list[tuple[int]],
@@ -509,6 +510,8 @@ def compile_pymc_model(
509510
return_transformed=True,
510511
)
511512

513+
initial_point_fn = _wrap_with_lock(initial_point_fn)
514+
512515
if backend.lower() == "numba":
513516
if gradient_backend == "jax":
514517
raise ValueError("Gradient backend cannot be jax when using numba backend")
@@ -530,7 +533,18 @@ def compile_pymc_model(
530533
raise ValueError(f"Backend must be one of numba and jax. Got {backend}")
531534

532535

533-
def _compute_shapes(model):
536+
def _wrap_with_lock(func: Callable) -> Callable:
537+
lock = threading.Lock()
538+
539+
@wraps(func)
540+
def wrapper(*args, **kwargs):
541+
with lock:
542+
return func(*args, **kwargs)
543+
544+
return wrapper
545+
546+
547+
def _compute_shapes(model) -> dict[str, tuple[int, ...]]:
534548
import pytensor
535549
from pymc.initial_point import make_initial_point_fn
536550

@@ -663,7 +677,7 @@ def _make_functions(
663677

664678
num_free_vars = count
665679

666-
initial_point_fn = rv_dict_to_flat_array_wrapper(
680+
initial_point_fn = _rv_dict_to_flat_array_wrapper(
667681
pymc_initial_point_fn, names=joined_names, shapes=joined_shapes
668682
)
669683

0 commit comments

Comments
 (0)