Skip to content

Commit 85f9016

Browse files
authored
Deserialize arbitrary objects (#1296)
* support for deserialization of three classes
* add the deserialize logic
* correct the type hint
* add an example
* separate out the error
* catch error at deserialization
* relax the input type
* add test suite
* add test for arbitrary serialization via Prior
* use deserialize within from_json
* test for deserialize support within Prior
* use general deserialization in individual media transformations
* test both deserialize funcs
* test arb deserialization in adstock
* add similar deserialization check for saturation
* better naming of the tests
* support parsing of hsgp kwargs
* add to the module level docstring
* allow VariableFactory in parse_model_config
* add module to documentation
* add to the module documentation
* Reorder the top level documentation
* add more docstring about the functions
1 parent 35c4e14 commit 85f9016

File tree

13 files changed

+597
-25
lines changed

13 files changed

+597
-25
lines changed

docs/source/api/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@
1515
prior
1616
metrics
1717
mlflow
18+
deserialize
1819
```

pymc_marketing/deserialize.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Deserialize into a PyMC-Marketing object.
15+
16+
This is a two step process:
17+
18+
1. Determine if the data is of the correct type.
19+
2. Deserialize the data into a python object for PyMC-Marketing.
20+
21+
This is used to deserialize JSON data into PyMC-Marketing objects
22+
throughout the package.
23+
24+
Examples
25+
--------
26+
Make use of the already registered PyMC-Marketing deserializers:
27+
28+
.. code-block:: python
29+
30+
from pymc_marketing.deserialize import deserialize
31+
32+
prior_class_data = {
33+
"dist": "Normal",
34+
"kwargs": {"mu": 0, "sigma": 1}
35+
}
36+
prior = deserialize(prior_class_data)
37+
# Prior("Normal", mu=0, sigma=1)
38+
39+
Register custom class deserialization:
40+
41+
.. code-block:: python
42+
43+
from pymc_marketing.deserialize import register_deserialization
44+
45+
class MyClass:
46+
def __init__(self, value: int):
47+
self.value = value
48+
49+
def to_dict(self) -> dict:
50+
# Example of what the to_dict method might look like.
51+
return {"value": self.value}
52+
53+
register_deserialization(
54+
is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
55+
deserialize=lambda data: MyClass(value=data["value"]),
56+
)
57+
58+
Deserialize data into that custom class:
59+
60+
.. code-block:: python
61+
62+
from pymc_marketing.deserialize import deserialize
63+
64+
data = {"value": 42}
65+
obj = deserialize(data)
66+
assert isinstance(obj, MyClass)
67+
68+
69+
"""
70+
71+
from collections.abc import Callable
72+
from dataclasses import dataclass
73+
from typing import Any
74+
75+
IsType = Callable[[Any], bool]
76+
Deserialize = Callable[[Any], Any]
77+
78+
79+
@dataclass
class Deserializer:
    """Object to store information required for deserialization.

    Pairs a type-check predicate with the function that turns matching
    data into a Python object.

    All deserializers should be stored via the :func:`register_deserialization`
    function instead of creating this object directly.

    Attributes
    ----------
    is_type : IsType
        Function to determine if the data is of the correct type.
    deserialize : Deserialize
        Function to deserialize the data.

    Examples
    --------
    .. code-block:: python

        from typing import Any

        class MyClass:
            def __init__(self, value: int):
                self.value = value

        from pymc_marketing.deserialize import Deserializer

        def is_type(data: Any) -> bool:
            return data.keys() == {"value"} and isinstance(data["value"], int)

        def deserialize(data: dict) -> MyClass:
            return MyClass(value=data["value"])

        deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)

    """

    is_type: IsType
    deserialize: Deserialize
117+
118+
119+
DESERIALIZERS: list[Deserializer] = []
120+
121+
122+
class DeserializableError(Exception):
    """Error raised when data cannot be deserialized.

    Parameters
    ----------
    data : Any
        The data that could not be deserialized.

    Attributes
    ----------
    data : Any
        The offending data, kept on the exception so callers can inspect it.

    """

    def __init__(self, data: Any):
        self.data = data
        super().__init__(
            f"Couldn't deserialize {data}. Use register_deserialization to add a deserialization mapping."
        )
130+
131+
132+
def deserialize(data: Any) -> Any:
    """Deserialize a dictionary into a Python object.

    Use the :func:`register_deserialization` function to add custom deserializations.

    Deserialization is a two step process due to the dynamic nature of the data:

    1. Determine if the data is of the correct type.
    2. Deserialize the data into a Python object.

    Each registered deserialization is checked in registration order. The first
    one whose type check accepts the data is used; later ones are not tried,
    even if that deserializer then fails. A type check that raises is treated
    as "does not match".

    Parameters
    ----------
    data : Any
        The data to deserialize.

    Returns
    -------
    Any
        The deserialized object.

    Raises
    ------
    DeserializableError
        Raised when the data doesn't match any registered deserializations
        or fails to be deserialized by the matching one. In the latter case
        the original exception is chained as ``__cause__``.

    Examples
    --------
    Deserialize a :class:`pymc_marketing.prior.Prior` object:

    .. code-block:: python

        from pymc_marketing.deserialize import deserialize

        data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
        prior = deserialize(data)
        # Prior("Normal", mu=0, sigma=1)

    """
    for mapping in DESERIALIZERS:
        try:
            is_type = mapping.is_type(data)
        except Exception:
            # Predicates are written for one data shape and may blow up on
            # anything else; that simply means "not this type".
            is_type = False

        if not is_type:
            continue

        try:
            return mapping.deserialize(data)
        except Exception as e:
            raise DeserializableError(data) from e

    # The loop never breaks, so reaching this point means nothing matched.
    raise DeserializableError(data)
192+
193+
194+
def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
    """Register an arbitrary deserialization.

    Use the :func:`deserialize` function to then deserialize data using all registered
    deserialize functions.

    Classes from PyMC-Marketing have their deserialization mappings registered
    automatically. However, custom classes will need to be registered manually
    using this function before they can be deserialized.

    The mapping is appended to the module-level ``DESERIALIZERS`` list, so
    mappings are checked by :func:`deserialize` in the order they were registered.

    Parameters
    ----------
    is_type : Callable[[Any], bool]
        Function to determine if the data is of the correct type.
    deserialize : Callable[[Any], Any]
        Function to deserialize the data of that type.

    Examples
    --------
    Register a custom class deserialization:

    .. code-block:: python

        from pymc_marketing.deserialize import register_deserialization

        class MyClass:
            def __init__(self, value: int):
                self.value = value

            def to_dict(self) -> dict:
                # Example of what the to_dict method might look like.
                return {"value": self.value}

        register_deserialization(
            is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
            deserialize=lambda data: MyClass(value=data["value"]),
        )

    Use that custom class deserialization:

    .. code-block:: python

        from pymc_marketing.deserialize import deserialize

        data = {"value": 42}
        obj = deserialize(data)
        assert isinstance(obj, MyClass)

    """
    mapping = Deserializer(is_type=is_type, deserialize=deserialize)
    DESERIALIZERS.append(mapping)

pymc_marketing/hsgp_kwargs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import pymc as pm
1919
from pydantic import BaseModel, Field, InstanceOf
2020

21+
from pymc_marketing.deserialize import register_deserialization
22+
2123

2224
class HSGPKwargs(BaseModel):
2325
"""HSGP keyword arguments for the time-varying prior.
@@ -80,3 +82,20 @@ class HSGPKwargs(BaseModel):
8082
cov_func: InstanceOf[pm.gp.cov.Covariance] | str | None = Field(
8183
None, description="Gaussian process Covariance function"
8284
)
85+
86+
87+
def _is_hsgp_kwargs(data) -> bool:
88+
return isinstance(data, dict) and data.keys() == {
89+
"m",
90+
"L",
91+
"eta_lam",
92+
"ls_mu",
93+
"ls_sigma",
94+
"cov_func",
95+
}
96+
97+
98+
register_deserialization(
99+
is_type=_is_hsgp_kwargs,
100+
deserialize=lambda data: HSGPKwargs.model_validate(data),
101+
)

pymc_marketing/mmm/components/adstock.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@ def function(self, x, alpha):
5656
import xarray as xr
5757
from pydantic import Field, validate_call
5858

59+
from pymc_marketing.deserialize import deserialize, register_deserialization
5960
from pymc_marketing.mmm.components.base import (
6061
SupportedPrior,
6162
Transformation,
62-
_deserialize,
6363
)
6464
from pymc_marketing.mmm.transformers import (
6565
ConvMode,
@@ -345,6 +345,16 @@ def adstock_from_dict(data: dict) -> AdstockTransformation:
345345
cls = ADSTOCK_TRANSFORMATIONS[lookup_name]
346346

347347
if "priors" in data:
348-
data["priors"] = {k: _deserialize(v) for k, v in data["priors"].items()}
348+
data["priors"] = {k: deserialize(v) for k, v in data["priors"].items()}
349349

350350
return cls(**data)
351+
352+
353+
def _is_adstock(data) -> bool:
    """Return whether ``data`` is a dict naming a known adstock transformation.

    The dict check comes first so that strings or lists that happen to contain
    ``"lookup_name"`` give ``False`` instead of raising on the item lookup.
    """
    return (
        isinstance(data, dict)
        and "lookup_name" in data
        and data["lookup_name"] in ADSTOCK_TRANSFORMATIONS
    )


# Make serialized adstock transformations recognisable by ``deserialize``.
register_deserialization(
    is_type=_is_adstock,
    deserialize=adstock_from_dict,
)

pymc_marketing/mmm/components/base.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -592,10 +592,3 @@ def _serialize_value(value: Any) -> Any:
592592
return value.tolist()
593593

594594
return value
595-
596-
597-
def _deserialize(value):
598-
try:
599-
return Prior.from_json(value)
600-
except Exception:
601-
return value

pymc_marketing/mmm/components/saturation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ def function(self, x, b):
7676
import xarray as xr
7777
from pydantic import Field, InstanceOf, validate_call
7878

79+
from pymc_marketing.deserialize import deserialize, register_deserialization
7980
from pymc_marketing.mmm.components.base import (
8081
Transformation,
81-
_deserialize,
8282
)
8383
from pymc_marketing.mmm.transformers import (
8484
hill_function,
@@ -483,6 +483,13 @@ def saturation_from_dict(data: dict) -> SaturationTransformation:
483483

484484
if "priors" in data:
485485
data["priors"] = {
486-
key: _deserialize(value) for key, value in data["priors"].items()
486+
key: deserialize(value) for key, value in data["priors"].items()
487487
}
488488
return cls(**data)
489+
490+
491+
def _is_saturation(data) -> bool:
    """Return whether ``data`` is a dict naming a known saturation transformation.

    The dict check comes first so that strings or lists that happen to contain
    ``"lookup_name"`` give ``False`` instead of raising on the item lookup.
    """
    return (
        isinstance(data, dict)
        and "lookup_name" in data
        and data["lookup_name"] in SATURATION_TRANSFORMATIONS
    )


# Make serialized saturation transformations recognisable by ``deserialize``.
register_deserialization(
    is_type=_is_saturation,
    deserialize=saturation_from_dict,
)

pymc_marketing/model_config.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,16 @@
1616
import warnings
1717
from typing import Any
1818

19+
from pymc_marketing.deserialize import deserialize
1920
from pymc_marketing.hsgp_kwargs import HSGPKwargs
20-
from pymc_marketing.prior import Prior
21+
from pymc_marketing.prior import Prior, VariableFactory
2122

2223

2324
class ModelConfigError(Exception):
2425
"""Exception raised for errors in model configuration."""
2526

2627

27-
ModelConfig = dict[str, HSGPKwargs | Prior | Any]
28+
ModelConfig = dict[str, VariableFactory | HSGPKwargs | Prior | Any]
2829

2930

3031
def parse_model_config(
@@ -121,11 +122,11 @@ def handle_prior_config(name, prior_config):
121122
if name in non_distributions or name in hsgp_kwargs_fields:
122123
return prior_config
123124

124-
if isinstance(prior_config, Prior):
125+
if isinstance(prior_config, Prior) or isinstance(prior_config, VariableFactory):
125126
return prior_config
126127

127128
try:
128-
dist = Prior.from_json(prior_config)
129+
dist = deserialize(prior_config)
129130
except Exception as e:
130131
parse_errors.append(f"Parameter {name}: {e}")
131132
else:

0 commit comments

Comments
 (0)