Skip to content

Commit b2fecad

Browse files
williambdeanandreacate
authored andcommitted
Add deserialize module (pymc-devs#489)
1 parent 8371dbe commit b2fecad

File tree

3 files changed

+403
-0
lines changed

3 files changed

+403
-0
lines changed

docs/api_reference.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ Prior
6161
Censored
6262
Scaled
6363

64+
Deserialize
65+
===========
66+
67+
.. currentmodule:: pymc_extras.deserialize
68+
.. autosummary::
69+
:toctree: generated/
70+
71+
deserialize
72+
register_deserialization
73+
Deserializer
74+
6475

6576
Transforms
6677
==========

pymc_extras/prior.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ def custom_transform(x):
9696
from pydantic.dataclasses import dataclass
9797
from pymc.distributions.shape_utils import Dims
9898

99+
from pymc_extras.deserialize import deserialize, register_deserialization
100+
99101

100102
class UnsupportedShapeError(Exception):
101103
"""Error for when the shapes from variables are not compatible."""
@@ -685,6 +687,134 @@ def preliz(self):
685687

686688
return getattr(pz, self.distribution)(**self.parameters)
687689

690+
def to_dict(self) -> dict[str, Any]:
691+
"""Convert the prior to dictionary format.
692+
693+
Returns
694+
-------
695+
dict[str, Any]
696+
The dictionary format of the prior.
697+
698+
Examples
699+
--------
700+
Convert a prior to the dictionary format.
701+
702+
.. code-block:: python
703+
704+
from pymc_extras.prior import Prior
705+
706+
dist = Prior("Normal", mu=0, sigma=1)
707+
708+
dist.to_dict()
709+
710+
Convert a hierarchical prior to the dictionary format.
711+
712+
.. code-block:: python
713+
714+
dist = Prior(
715+
"Normal",
716+
mu=Prior("Normal"),
717+
sigma=Prior("HalfNormal"),
718+
dims="channel",
719+
)
720+
721+
dist.to_dict()
722+
723+
"""
724+
data: dict[str, Any] = {
725+
"dist": self.distribution,
726+
}
727+
if self.parameters:
728+
729+
def handle_value(value):
730+
if isinstance(value, Prior):
731+
return value.to_dict()
732+
733+
if isinstance(value, pt.TensorVariable):
734+
value = value.eval()
735+
736+
if isinstance(value, np.ndarray):
737+
return value.tolist()
738+
739+
if hasattr(value, "to_dict"):
740+
return value.to_dict()
741+
742+
return value
743+
744+
data["kwargs"] = {
745+
param: handle_value(value) for param, value in self.parameters.items()
746+
}
747+
if not self.centered:
748+
data["centered"] = False
749+
750+
if self.dims:
751+
data["dims"] = self.dims
752+
753+
if self.transform:
754+
data["transform"] = self.transform
755+
756+
return data
757+
758+
@classmethod
759+
def from_dict(cls, data) -> Prior:
760+
"""Create a Prior from the dictionary format.
761+
762+
Parameters
763+
----------
764+
data : dict[str, Any]
765+
The dictionary format of the prior.
766+
767+
Returns
768+
-------
769+
Prior
770+
The prior distribution.
771+
772+
Examples
773+
--------
774+
Convert prior in the dictionary format to a Prior instance.
775+
776+
.. code-block:: python
777+
778+
from pymc_extras.prior import Prior
779+
780+
data = {
781+
"dist": "Normal",
782+
"kwargs": {"mu": 0, "sigma": 1},
783+
}
784+
785+
dist = Prior.from_dict(data)
786+
dist
787+
# Prior("Normal", mu=0, sigma=1)
788+
789+
"""
790+
if not isinstance(data, dict):
791+
msg = (
792+
"Must be a dictionary representation of a prior distribution. "
793+
f"Not of type: {type(data)}"
794+
)
795+
raise ValueError(msg)
796+
797+
dist = data["dist"]
798+
kwargs = data.get("kwargs", {})
799+
800+
def handle_value(value):
801+
if isinstance(value, dict):
802+
return deserialize(value)
803+
804+
if isinstance(value, list):
805+
return np.array(value)
806+
807+
return value
808+
809+
kwargs = {param: handle_value(value) for param, value in kwargs.items()}
810+
centered = data.get("centered", True)
811+
dims = data.get("dims")
812+
if isinstance(dims, list):
813+
dims = tuple(dims)
814+
transform = data.get("transform")
815+
816+
return cls(dist, dims=dims, centered=centered, transform=transform, **kwargs)
817+
688818
def constrain(self, lower: float, upper: float, mass: float = 0.95, kwargs=None) -> Prior:
689819
"""Create a new prior with a given mass constrained within the given bounds.
690820
@@ -1022,6 +1152,34 @@ def create_variable(self, name: str) -> pt.TensorVariable:
10221152
dims=self.dims,
10231153
)
10241154

1155+
def to_dict(self) -> dict[str, Any]:
1156+
"""Convert the censored distribution to a dictionary."""
1157+
1158+
def handle_value(value):
1159+
if isinstance(value, pt.TensorVariable):
1160+
return value.eval().tolist()
1161+
1162+
return value
1163+
1164+
return {
1165+
"class": "Censored",
1166+
"data": {
1167+
"dist": self.distribution.to_dict(),
1168+
"lower": handle_value(self.lower),
1169+
"upper": handle_value(self.upper),
1170+
},
1171+
}
1172+
1173+
@classmethod
1174+
def from_dict(cls, data: dict[str, Any]) -> Censored:
1175+
"""Create a censored distribution from a dictionary."""
1176+
data = data["data"]
1177+
return cls( # type: ignore
1178+
distribution=Prior.from_dict(data["dist"]),
1179+
lower=data["lower"],
1180+
upper=data["upper"],
1181+
)
1182+
10251183
def sample_prior(
10261184
self,
10271185
coords=None,
@@ -1184,3 +1342,15 @@ def create_variable(self, name: str) -> pt.TensorVariable:
11841342
"""
11851343
var = self.dist.create_variable(f"{name}_unscaled")
11861344
return pm.Deterministic(name, var * self.factor, dims=self.dims)
1345+
1346+
1347+
def _is_prior_type(data: dict) -> bool:
1348+
return "dist" in data
1349+
1350+
1351+
def _is_censored_type(data: dict) -> bool:
1352+
return data.keys() == {"class", "data"} and data["class"] == "Censored"
1353+
1354+
1355+
register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_dict)
1356+
register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict)

0 commit comments

Comments
 (0)