diff --git a/pymc_extras/prior.py b/pymc_extras/prior.py index 9e523ddee..bbac00c1d 100644 --- a/pymc_extras/prior.py +++ b/pymc_extras/prior.py @@ -1176,7 +1176,7 @@ def from_dict(cls, data: dict[str, Any]) -> Censored: """Create a censored distribution from a dictionary.""" data = data["data"] return cls( # type: ignore - distribution=Prior.from_dict(data["dist"]), + distribution=deserialize(data["dist"]), lower=data["lower"], upper=data["upper"], ) diff --git a/tests/test_prior.py b/tests/test_prior.py index f0201a7c4..0ff142718 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -1168,3 +1168,38 @@ def test_import_incorrect_directly() -> None: match = "PyMC doesn't have a distribution of name 'SomeIncorrectDistribution'" with pytest.raises(UnsupportedDistributionError, match=match): from pymc_extras.prior import SomeIncorrectDistribution # noqa: F401 + + +@pytest.fixture +def alternative_prior_deserialize(): + def is_type(data): + return isinstance(data, dict) and "distribution" in data + + def deserialize(data): + return Prior(**data) + + register_deserialization(is_type=is_type, deserialize=deserialize) + + yield + + DESERIALIZERS.pop() + + +def test_censored_with_alternative(alternative_prior_deserialize) -> None: + data = { + "class": "Censored", + "data": { + "dist": { + "distribution": "Normal", + }, + "lower": 0, + "upper": 10, + }, + } + + instance = deserialize(data) + + assert isinstance(instance, Censored) + assert instance.lower == 0 + assert instance.upper == 10 + assert instance.distribution == Prior("Normal")