@@ -1,5 +1,6 @@
 import dataclasses
 import itertools
+import threading
 import warnings
 from collections.abc import Iterable
 from dataclasses import dataclass
@@ -31,7 +32,7 @@ def intrinsic(f):
 from pytensor.tensor import TensorVariable, Variable


-def rv_dict_to_flat_array_wrapper(
+def _rv_dict_to_flat_array_wrapper(
     fn: Callable[[SeedType], dict[str, np.ndarray]],
     names: list[str],
     shapes: list[tuple[int]],
@@ -509,6 +510,8 @@ def compile_pymc_model(
         return_transformed=True,
     )

+    initial_point_fn = _wrap_with_lock(initial_point_fn)
+
     if backend.lower() == "numba":
         if gradient_backend == "jax":
             raise ValueError("Gradient backend cannot be jax when using numba backend")
@@ -530,7 +533,18 @@ def compile_pymc_model(
         raise ValueError(f"Backend must be one of numba and jax. Got {backend}")


-def _compute_shapes(model):
+def _wrap_with_lock(func: Callable) -> Callable:
+    lock = threading.Lock()
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        with lock:
+            return func(*args, **kwargs)
+
+    return wrapper
+
+
+def _compute_shapes(model) -> dict[str, tuple[int, ...]]:
     import pytensor
     from pymc.initial_point import make_initial_point_fn

@@ -663,7 +677,7 @@ def _make_functions(

     num_free_vars = count

-    initial_point_fn = rv_dict_to_flat_array_wrapper(
+    initial_point_fn = _rv_dict_to_flat_array_wrapper(
         pymc_initial_point_fn, names=joined_names, shapes=joined_shapes
     )

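
The renamed `_rv_dict_to_flat_array_wrapper` appears in this diff only by its signature: it adapts a function returning a dict of named arrays into one returning a single flat array. A plausible standalone sketch of what such a wrapper could look like, assuming it concatenates the named arrays into one 1-d position vector; the `SeedType = int` alias and the wrapper body below are assumptions for illustration, not the module's actual implementation:

from collections.abc import Callable

import numpy as np

SeedType = int  # assumption: stand-in for the module's real seed type alias


def _rv_dict_to_flat_array_wrapper(
    fn: Callable[[SeedType], dict[str, np.ndarray]],
    names: list[str],
    shapes: list[tuple[int]],
) -> Callable[[SeedType], np.ndarray]:
    """Adapt a dict-returning initial-point function to return a flat array."""

    def wrapper(seed: SeedType) -> np.ndarray:
        point = fn(seed)
        arrays = []
        for name, shape in zip(names, shapes):
            value = np.asarray(point[name])
            # Sanity-check the declared shape before flattening.
            assert value.shape == tuple(shape)
            arrays.append(value.reshape(-1))
        return np.concatenate(arrays) if arrays else np.empty(0)

    return wrapper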
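The new `_wrap_with_lock` helper (and the `initial_point_fn = _wrap_with_lock(initial_point_fn)` call) follows the standard decorator pattern for serializing access to a function that is not thread-safe: every call must acquire one shared `threading.Lock` before running. A minimal self-contained sketch of the same pattern; the `increment` function and the 100-thread driver below are hypothetical, used only to show that concurrent callers are serialized:

import threading
from functools import wraps


def _wrap_with_lock(func):
    """Serialize all calls to func behind a single shared lock."""
    lock = threading.Lock()

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Only one thread at a time can hold the lock,
        # so calls to func never overlap.
        with lock:
            return func(*args, **kwargs)

    return wrapper


counter = 0


@_wrap_with_lock
def increment():
    # A read-modify-write cycle that would be racy without the lock.
    global counter
    current = counter
    counter = current + 1


threads = [threading.Thread(target=increment) for _ in range(100)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert counter == 100  # all 100 increments survive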