diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 12eb36a0d..073d7dd27 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -84,6 +84,7 @@ def custom_transform(x): import copy from collections.abc import Callable +from functools import partial from inspect import signature from typing import Any, Protocol, runtime_checkable @@ -1354,3 +1355,34 @@ def _is_censored_type(data: dict) -> bool: register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict) register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict) + + +def __getattr__(name: str): + """Get Prior class through the module. + + Examples + -------- + Create a normal distribution. + + .. code-block:: python + + from pymc_extras.prior import Normal + + dist = Normal(mu=1, sigma=2) + + Create a hierarchical normal distribution. + + .. code-block:: python + + import pymc_extras.prior as pr + + dist = pr.Normal(mu=pr.Normal(), sigma=pr.HalfNormal(), dims="channel") + samples = dist.sample_prior(coords={"channel": ["C1", "C2", "C3"]}) + + """ + # Protect against doctest + if name == "__wrapped__": + return + + _get_pymc_distribution(name) + return partial(Prior, distribution=name) diff --git a/tests/test_prior.py b/tests/test_prior.py index 70729b9f9..b002fe691 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -12,6 +12,8 @@ from pydantic import ValidationError from pymc.model_graph import fast_eval +import pymc_extras.prior as pr + from pymc_extras.deserialize import ( DESERIALIZERS, deserialize, @@ -1141,3 +1143,22 @@ def test_scaled_sample_prior() -> None: assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} assert "scaled_var" in prior assert "scaled_var_unscaled" in prior + + +def test_getattr() -> None: + assert pr.Normal() == Prior("Normal") + + +def test_import_directly() -> None: + try: + from pymc_extras.prior import Normal + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + assert Normal() == Prior("Normal") + + +def test_import_incorrect_directly() -> None: + match = "PyMC doesn't have a distribution of name 'SomeIncorrectDistribution'" + with pytest.raises(UnsupportedDistributionError, match=match): + from pymc_extras.prior import SomeIncorrectDistribution # noqa: F401