3131
3232from pymc .backends .arviz import dict_to_dataset , to_inference_data
3333from pymc .backends .base import MultiTrace
34+ from pymc .distributions .distribution import _support_point
35+ from pymc .logprob .abstract import _logcdf , _logprob
3436from pymc .model import Model , modelcontext
3537from pymc .sampling .parallel import _cpu_count
3638from pymc .smc .kernels import IMH
@@ -375,11 +377,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
375377 # main process and our worker functions
376378 _progress = manager .dict ()
377379
380+ # check if model contains CustomDistributions defined without dist argument
381+ custom_methods = _find_custom_methods (params [3 ])
382+
378383 # "manually" (de)serialize params before/after multiprocessing
379384 params = tuple (cloudpickle .dumps (p ) for p in params )
380385 kernel_kwargs = {key : cloudpickle .dumps (value ) for key , value in kernel_kwargs .items ()}
381386
382- with ProcessPoolExecutor (max_workers = cores ) as executor :
387+ with ProcessPoolExecutor (
388+ max_workers = cores ,
389+ initializer = _register_custom_methods ,
390+ initargs = (custom_methods ,),
391+ ) as executor :
383392 for c in range (chains ): # iterate over the jobs we need to run
384393 # set visible false so we don't have a lot of bars all at once:
385394 task_id = progress .add_task (
@@ -406,3 +415,25 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
406415 progress .update (status = f"Stage: { stage } Beta: { beta :.3f} " , task_id = task_id )
407416
408417 return tuple (cloudpickle .loads (r .result ()) for r in futures )
418+
419+
420+ def _find_custom_methods (model ):
421+ custom_methods = {}
422+ for rv in model .free_RVs + model .observed_RVs :
423+ cls = rv .owner .op .__class__
424+ if hasattr (cls , "_random_fn" ):
425+ custom_methods [cloudpickle .dumps (cls )] = (
426+ cloudpickle .dumps (_logprob .registry [cls ]),
427+ cloudpickle .dumps (_logcdf .registry [cls ]),
428+ cloudpickle .dumps (_support_point .registry [cls ]),
429+ )
430+
431+ return custom_methods
432+
433+
434+ def _register_custom_methods (custom_methods ):
435+ for cls , (logprob , logcdf , support_point ) in custom_methods .items ():
436+ cls = cloudpickle .loads (cls )
437+ _logprob .register (cls , cloudpickle .loads (logprob ))
438+ _logcdf .register (cls , cloudpickle .loads (logcdf ))
439+ _support_point .register (cls , cloudpickle .loads (support_point ))
0 commit comments