@@ -96,6 +96,8 @@ def custom_transform(x):
9696from pydantic .dataclasses import dataclass
9797from pymc .distributions .shape_utils import Dims
9898
99+ from pymc_extras .deserialize import deserialize , register_deserialization
100+
99101
100102class UnsupportedShapeError (Exception ):
101103 """Error for when the shapes from variables are not compatible."""
@@ -685,6 +687,134 @@ def preliz(self):
685687
686688 return getattr (pz , self .distribution )(** self .parameters )
687689
690+ def to_dict (self ) -> dict [str , Any ]:
691+ """Convert the prior to dictionary format.
692+
693+ Returns
694+ -------
695+ dict[str, Any]
696+ The dictionary format of the prior.
697+
698+ Examples
699+ --------
700+ Convert a prior to the dictionary format.
701+
702+ .. code-block:: python
703+
704+ from pymc_extras.prior import Prior
705+
706+ dist = Prior("Normal", mu=0, sigma=1)
707+
708+ dist.to_dict()
709+
710+ Convert a hierarchical prior to the dictionary format.
711+
712+ .. code-block:: python
713+
714+ dist = Prior(
715+ "Normal",
716+ mu=Prior("Normal"),
717+ sigma=Prior("HalfNormal"),
718+ dims="channel",
719+ )
720+
721+ dist.to_dict()
722+
723+ """
724+ data : dict [str , Any ] = {
725+ "dist" : self .distribution ,
726+ }
727+ if self .parameters :
728+
729+ def handle_value (value ):
730+ if isinstance (value , Prior ):
731+ return value .to_dict ()
732+
733+ if isinstance (value , pt .TensorVariable ):
734+ value = value .eval ()
735+
736+ if isinstance (value , np .ndarray ):
737+ return value .tolist ()
738+
739+ if hasattr (value , "to_dict" ):
740+ return value .to_dict ()
741+
742+ return value
743+
744+ data ["kwargs" ] = {
745+ param : handle_value (value ) for param , value in self .parameters .items ()
746+ }
747+ if not self .centered :
748+ data ["centered" ] = False
749+
750+ if self .dims :
751+ data ["dims" ] = self .dims
752+
753+ if self .transform :
754+ data ["transform" ] = self .transform
755+
756+ return data
757+
758+ @classmethod
759+ def from_dict (cls , data ) -> Prior :
760+ """Create a Prior from the dictionary format.
761+
762+ Parameters
763+ ----------
764+ data : dict[str, Any]
765+ The dictionary format of the prior.
766+
767+ Returns
768+ -------
769+ Prior
770+ The prior distribution.
771+
772+ Examples
773+ --------
774+ Convert prior in the dictionary format to a Prior instance.
775+
776+ .. code-block:: python
777+
778+ from pymc_extras.prior import Prior
779+
780+ data = {
781+ "dist": "Normal",
782+ "kwargs": {"mu": 0, "sigma": 1},
783+ }
784+
785+ dist = Prior.from_dict(data)
786+ dist
787+ # Prior("Normal", mu=0, sigma=1)
788+
789+ """
790+ if not isinstance (data , dict ):
791+ msg = (
792+ "Must be a dictionary representation of a prior distribution. "
793+ f"Not of type: { type (data )} "
794+ )
795+ raise ValueError (msg )
796+
797+ dist = data ["dist" ]
798+ kwargs = data .get ("kwargs" , {})
799+
800+ def handle_value (value ):
801+ if isinstance (value , dict ):
802+ return deserialize (value )
803+
804+ if isinstance (value , list ):
805+ return np .array (value )
806+
807+ return value
808+
809+ kwargs = {param : handle_value (value ) for param , value in kwargs .items ()}
810+ centered = data .get ("centered" , True )
811+ dims = data .get ("dims" )
812+ if isinstance (dims , list ):
813+ dims = tuple (dims )
814+ transform = data .get ("transform" )
815+
816+ return cls (dist , dims = dims , centered = centered , transform = transform , ** kwargs )
817+
688818 def constrain (self , lower : float , upper : float , mass : float = 0.95 , kwargs = None ) -> Prior :
689819 """Create a new prior with a given mass constrained within the given bounds.
690820
@@ -1022,6 +1152,34 @@ def create_variable(self, name: str) -> pt.TensorVariable:
10221152 dims = self .dims ,
10231153 )
10241154
1155+ def to_dict (self ) -> dict [str , Any ]:
1156+ """Convert the censored distribution to a dictionary."""
1157+
1158+ def handle_value (value ):
1159+ if isinstance (value , pt .TensorVariable ):
1160+ return value .eval ().tolist ()
1161+
1162+ return value
1163+
1164+ return {
1165+ "class" : "Censored" ,
1166+ "data" : {
1167+ "dist" : self .distribution .to_dict (),
1168+ "lower" : handle_value (self .lower ),
1169+ "upper" : handle_value (self .upper ),
1170+ },
1171+ }
1172+
1173+ @classmethod
1174+ def from_dict (cls , data : dict [str , Any ]) -> Censored :
1175+ """Create a censored distribution from a dictionary."""
1176+ data = data ["data" ]
1177+ return cls ( # type: ignore
1178+ distribution = Prior .from_dict (data ["dist" ]),
1179+ lower = data ["lower" ],
1180+ upper = data ["upper" ],
1181+ )
1182+
10251183 def sample_prior (
10261184 self ,
10271185 coords = None ,
@@ -1184,3 +1342,15 @@ def create_variable(self, name: str) -> pt.TensorVariable:
11841342 """
11851343 var = self .dist .create_variable (f"{ name } _unscaled" )
11861344 return pm .Deterministic (name , var * self .factor , dims = self .dims )
1345+
1346+
1347+ def _is_prior_type (data : dict ) -> bool :
1348+ return "dist" in data
1349+
1350+
1351+ def _is_censored_type (data : dict ) -> bool :
1352+ return data .keys () == {"class" , "data" } and data ["class" ] == "Censored"
1353+
1354+
1355+ register_deserialization (is_type = _is_prior_type , deserialize = Prior .from_dict )
1356+ register_deserialization (is_type = _is_censored_type , deserialize = Censored .from_dict )
0 commit comments