Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@

from pymc.backends.arviz import dict_to_dataset, to_inference_data
from pymc.backends.base import MultiTrace
from pymc.distributions.distribution import _support_point
from pymc.logprob.abstract import _logcdf, _logprob
from pymc.model import Model, modelcontext
from pymc.sampling.parallel import _cpu_count
from pymc.smc.kernels import IMH
Expand Down Expand Up @@ -375,11 +377,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
# main process and our worker functions
_progress = manager.dict()

# check if model contains CustomDistributions defined without dist argument
custom_methods = _find_custom_methods(params[3])

# "manually" (de)serialize params before/after multiprocessing
params = tuple(cloudpickle.dumps(p) for p in params)
kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}

with ProcessPoolExecutor(max_workers=cores) as executor:
with ProcessPoolExecutor(
max_workers=cores,
initializer=_register_custom_methods,
initargs=(custom_methods,),
) as executor:
for c in range(chains): # iterate over the jobs we need to run
# set visible false so we don't have a lot of bars all at once:
task_id = progress.add_task(
Expand All @@ -406,3 +415,25 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
progress.update(status=f"Stage: {stage} Beta: {beta:.3f}", task_id=task_id)

return tuple(cloudpickle.loads(r.result()) for r in futures)


def _find_custom_methods(model):
custom_methods = {}
for rv in model.free_RVs + model.observed_RVs:
cls = rv.owner.op.__class__
if hasattr(cls, "_random_fn"):
custom_methods[cloudpickle.dumps(cls)] = (
cloudpickle.dumps(_logprob.registry[cls]),
cloudpickle.dumps(_logcdf.registry[cls]),
cloudpickle.dumps(_support_point.registry[cls]),
)

return custom_methods


def _register_custom_methods(custom_methods):
for cls, (logprob, logcdf, support_point) in custom_methods.items():
cls = cloudpickle.loads(cls)
_logprob.register(cls, cloudpickle.loads(logprob))
_logcdf.register(cls, cloudpickle.loads(logcdf))
_support_point.register(cls, cloudpickle.loads(support_point))