33 changes: 32 additions & 1 deletion pymc/smc/sampling.py
@@ -36,6 +36,8 @@
 
 from pymc.backends.arviz import dict_to_dataset, to_inference_data
 from pymc.backends.base import MultiTrace
+from pymc.distributions.distribution import _support_point
+from pymc.logprob.abstract import _logcdf, _logprob
 from pymc.model import Model, modelcontext
 from pymc.sampling.parallel import _cpu_count
 from pymc.smc.kernels import IMH
@@ -383,11 +385,18 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
     # main process and our worker functions
     _progress = manager.dict()
 
+    # check if model contains CustomDistributions defined without dist argument
+    custom_methods = _find_custom_methods(params[3])
+
     # "manually" (de)serialize params before/after multiprocessing
     params = tuple(cloudpickle.dumps(p) for p in params)
     kernel_kwargs = {key: cloudpickle.dumps(value) for key, value in kernel_kwargs.items()}
 
-    with ProcessPoolExecutor(max_workers=cores) as executor:
+    with ProcessPoolExecutor(
+        max_workers=cores,
+        initializer=_register_custom_methods,
+        initargs=(custom_methods,),
+    ) as executor:
         for c in range(chains):  # iterate over the jobs we need to run
             # set visible false so we don't have a lot of bars all at once:
             task_id = progress.add_task(f"Chain {c}", status="Stage: 0 Beta: 0")
@@ -414,3 +423,25 @@ def run_chains(chains, progressbar, params, random_seed, kernel_kwargs, cores):
             )
 
     return tuple(cloudpickle.loads(r.result()) for r in futures)
+
+
+def _find_custom_methods(model):
+    custom_methods = {}
+    for rv in model.free_RVs + model.observed_RVs:
+        cls = rv.owner.op.__class__
+        if hasattr(cls, "_random_fn"):
+            custom_methods[cloudpickle.dumps(cls)] = (
+                cloudpickle.dumps(_logprob.registry[cls]),
+                cloudpickle.dumps(_logcdf.registry[cls]),
+                cloudpickle.dumps(_support_point.registry[cls]),
+            )
+
+    return custom_methods
+
+
+def _register_custom_methods(custom_methods):
+    for cls, (logprob, logcdf, support_point) in custom_methods.items():
+        cls = cloudpickle.loads(cls)
+        _logprob.register(cls, cloudpickle.loads(logprob))
+        _logcdf.register(cls, cloudpickle.loads(logcdf))
+        _support_point.register(cls, cloudpickle.loads(support_point))
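
For context, here is a minimal, self-contained sketch of the pattern this diff relies on — nothing below is PyMC code, and every name (logdensity, MyOp, _register, _evaluate) is a hypothetical stand-in. Dispatch registrations made at runtime in the parent process are cloudpickled and replayed in each worker through ProcessPoolExecutor's initializer/initargs hooks, mirroring _find_custom_methods/_register_custom_methods above:

```python
from concurrent.futures import ProcessPoolExecutor
from functools import singledispatch

import cloudpickle


@singledispatch
def logdensity(op, value):
    # base case: no implementation registered for this Op type
    raise NotImplementedError(f"No logdensity for {type(op)}")


class MyOp:
    # stand-in for a CustomDist Op; PyMC creates such classes dynamically,
    # which is why the real code cloudpickles the class itself (by value)
    pass


def _register(payload):
    # runs once in every worker process, before any task executes
    for cls, fn in payload.items():
        logdensity.register(cloudpickle.loads(cls), cloudpickle.loads(fn))


def _evaluate(op_bytes, value):
    return logdensity(cloudpickle.loads(op_bytes), value)


if __name__ == "__main__":
    # this registration happens at runtime, so spawn-based workers
    # will not see it unless it is replayed via the initializer
    logdensity.register(MyOp, lambda op, value: -(value**2))

    payload = {cloudpickle.dumps(MyOp): cloudpickle.dumps(logdensity.registry[MyOp])}
    with ProcessPoolExecutor(
        max_workers=2, initializer=_register, initargs=(payload,)
    ) as executor:
        result = executor.submit(_evaluate, cloudpickle.dumps(MyOp()), 3.0)
        print(result.result())  # -9.0
```

The initializer runs in each worker before any submitted task, so by the time the pickled SMC kernels are deserialized, the logprob/logcdf/support_point overloads are already back in place.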
12 changes: 12 additions & 0 deletions tests/smc/test_smc.py
@@ -133,6 +133,18 @@ def test_unobserved_categorical(self):
 
         assert np.all(np.median(trace["mu"], axis=0) == [1, 2])
 
+    def test_parallel_custom(self):
+        def _logp(value, mu):
+            return -((value - mu) ** 2)
+
+        def _random(mu, rng=None, size=None):
+            return rng.normal(loc=mu, scale=1, size=size)
+
+        with pm.Model():
+            mu = pm.CustomDist("mu", 0, logp=_logp, random=_random)
ricardoV94 (Member) commented on May 27, 2024:

    There are two kinds of CustomDist: if you pass dist instead of random, you get a different Op type, which I guess would still fail after this PR.

    Edit: I see you mentioned this in your top message.
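
To make that distinction concrete, here is a hedged sketch of the two CustomDist flavors, using toy callables and assuming a recent PyMC version; the claim that each flavor yields a different Op class comes from the comment above:

```python
import pymc as pm

with pm.Model():
    # "random" flavor: a user-supplied sampler is attached to the Op
    # (this is the case handled by this PR via the _random_fn check)
    x = pm.CustomDist(
        "x",
        0.0,
        logp=lambda value, mu: -((value - mu) ** 2),
        random=lambda mu, rng=None, size=None: rng.normal(mu, 1, size=size),
    )

    # "dist" flavor: the variable is built from a PyTensor graph and the
    # logp is derived automatically, producing a different (symbolic) Op type
    y = pm.CustomDist(
        "y",
        0.0,
        dist=lambda mu, size: pm.Normal.dist(mu, 1, size=size),
    )
```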

ricardoV94 (Member) commented:

    Why does it fail with random but not with dist? Testing locally, it seems to fail with both for me.
PR author (Contributor) commented:

    I just tested both the code I used in the linked issue and your example below. I still get no errors when using pm.Potential or the dist argument. Originally I was using pymc==5.10.0, but I have now also tested with pymc==5.13.0. Maybe there are differences between Windows and Linux, if you're on a different platform than I am?

    I'm not completely sure why using dist works for me, but based on some quick testing, DistributionMeta.__new__ is called when e.g. Normal is defined, and the overloads for the builtin distributions are registered there. I'm not well versed in multiprocessing or in how Python does importing, but my hunch is that the worker processes automatically import things from pymc, and the overloads get registered as a side effect. For user-defined logprob etc. this is not the case, since the registration isn't done at import time.
pm.CustomDist("y", mu, logp=_logp, class_name="", random=_random, observed=[1, 2])
pm.sample_smc(draws=6, cores=2)

     def test_marginal_likelihood(self):
         """
         Verifies that the log marginal likelihood function
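
The import-time hunch in the author's last comment can be illustrated with a minimal sketch using plain functools.singledispatch (no PyMC; all names are made up). A registration performed at module import is re-created in every spawned worker, because workers re-import the module; a registration performed at runtime in the parent is not. Under Linux's default fork start method, both are inherited, which would also explain the platform differences discussed above:

```python
from concurrent.futures import ProcessPoolExecutor
from functools import singledispatch


@singledispatch
def describe(value):
    return "unregistered"


# registered at import time in *every* process, analogous to
# DistributionMeta.__new__ registering overloads for builtin distributions
@describe.register(int)
def _(value):
    return "int handler (import-time)"


def probe(value):
    return describe(value)


if __name__ == "__main__":
    # registered only at runtime in the parent, analogous to a
    # user-defined logp registration
    describe.register(float, lambda value: "float handler (runtime)")

    with ProcessPoolExecutor(max_workers=1) as executor:
        # under spawn (Windows/macOS default): "int handler (import-time)"
        print(executor.submit(probe, 1).result())
        # under spawn: "unregistered" — the runtime registration is missing;
        # under fork (Linux default) it would be inherited instead
        print(executor.submit(probe, 1.0).result())
```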