Skip to content

Commit d766722

Browse files
PoC: Use Pydantic as data validator (#809)
* prior with pydantic * dependencies * validate adstock * make mypy happy * add validation sample curve * make the prior type tighter * add test type * add validation init mmm * mmm * start with Fourier * fix type * fix test and imprtove docstrings * docstrings * types * self type * init validator * types model builder * improve docstrings * more input validations mmm init * validation budget optimizer * fix dummy example types * hsgp kwargs class * add kwargs * undo type hint in dict * fix fourier names * better docs * fix tests * add type hint * undo * fix type error * feedback2 * restrict signature * serialize fourier * docs and tests * fix docs * work on parsing * add hsgp to parsing config * add tests * uncomment * undo changes * undo model config parser * handle hsgp_kwargs * add hsgp flag * docs * undo type hint * improve hints * add more sections to docs * Update pymc_marketing/mmm/tvp.py Co-authored-by: Will Dean <[email protected]> * feedback 4 * fix test --------- Co-authored-by: Will Dean <[email protected]>
1 parent a97e272 commit d766722

22 files changed

+480
-235
lines changed

docs/source/api/index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
:toctree: generated/
99
1010
clv
11+
hsgp_kwargs
1112
mmm
1213
model_config
14+
model_builder
1315
prior
1416
```

docs/source/notebooks/mmm/mmm_example.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1119,7 +1119,7 @@
11191119
"source": [
11201120
"dummy_model = MMM(\n",
11211121
" date_column=\"\",\n",
1122-
" channel_columns=\"\",\n",
1122+
" channel_columns=[\"\"],\n",
11231123
" adstock=\"geometric\",\n",
11241124
" saturation=\"logistic\",\n",
11251125
" adstock_max_lag=4,\n",

pymc_marketing/clv/models/basic.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import arviz as az
2121
import pandas as pd
2222
import pymc as pm
23+
from pydantic import ConfigDict, InstanceOf, validate_call
2324
from pymc.backends import NDArray
2425
from pymc.backends.base import MultiTrace
2526
from pymc.model.core import Model
@@ -32,11 +33,12 @@
3233
class CLVModel(ModelBuilder):
3334
_model_type = "CLVModel"
3435

36+
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
3537
def __init__(
3638
self,
3739
data: pd.DataFrame,
3840
*,
39-
model_config: ModelConfig | None = None,
41+
model_config: InstanceOf[ModelConfig] | None = None,
4042
sampler_config: dict | None = None,
4143
non_distributions: list[str] | None = None,
4244
):
@@ -65,7 +67,7 @@ def _validate_cols(
6567
if data[required_col].nunique() != n:
6668
raise ValueError(f"Column {required_col} has duplicate entries")
6769

68-
def __repr__(self):
70+
def __repr__(self) -> str:
6971
if not hasattr(self, "model"):
7072
return self._model_type
7173
else:

pymc_marketing/hsgp_kwargs.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2024 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Class to store and validate keyword arguments for the Hilbert Space Gaussian Process (HSGP) components."""
15+
16+
from typing import Annotated
17+
18+
import pymc as pm
19+
from pydantic import BaseModel, Field, InstanceOf
20+
21+
22+
class HSGPKwargs(BaseModel):
23+
"""HSGP keyword arguments for the time-varying prior.
24+
25+
See [1]_ and [2]_ for the theoretical background on the Hilbert Space Gaussian Process (HSGP).
26+
See [6]_ for a practical guide through the method using code examples.
27+
See the :class:`~pymc.gp.HSGP` class for more information on the Hilbert Space Gaussian Process in PyMC.
28+
We also recommend the following resources for a more practical introduction to HSGP: [3]_, [4]_, [5]_.
29+
30+
References
31+
----------
32+
.. [1] Solin, A., Sarkka, S. (2019) Hilbert Space Methods for Reduced-Rank Gaussian Process Regression.
33+
.. [2] Riutort-Mayol, G., Andersen, M., Solin, A., and Vehtari, A. (2022). Practical Hilbert Space Approximate Bayesian Gaussian Processes for Probabilistic Programming.
34+
.. [3] PyMC Example Gallery: `"Gaussian Processes: HSGP Reference & First Steps" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Basic.html>`_.
35+
.. [4] PyMC Example Gallery: `"Gaussian Processes: HSGP Advanced Usage" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/HSGP-Advanced.html>`_.
36+
.. [5] PyMC Example Gallery: `"Baby Births Modelling with HSGPs" <https://www.pymc.io/projects/examples/en/latest/gaussian_processes/GP-Births.html>`_.
37+
.. [6] Orduz, J. `"A Conceptual and Practical Introduction to Hilbert Space GPs Approximation Methods" <https://juanitorduz.github.io/hsgp_intro/>`_.
38+
39+
Parameters
40+
----------
41+
m : int
42+
Number of basis functions. Default is 200.
43+
L : float, optional
44+
Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
45+
(considering that time-indices are zero-centered). Default is `X_mid * 2` (identical to `c=2` in HSGP).
46+
By default it is None.
47+
eta_lam : float
48+
Exponential prior for the variance. Default is 1.
49+
ls_mu : float
50+
Mean of the inverse gamma prior for the lengthscale. Default is 5.
51+
ls_sigma : float
52+
Standard deviation of the inverse gamma prior for the lengthscale. Default is 5.
53+
cov_func : ~pymc.gp.cov.Covariance, optional
54+
Gaussian process Covariance function. By default it is None.
55+
""" # noqa E501
56+
57+
m: int = Field(200, description="Number of basis functions")
58+
L: (
59+
Annotated[
60+
float,
61+
Field(
62+
gt=0,
63+
description="""
64+
Extent of basis functions. Set this to reflect the expected range of in+out-of-sample data
65+
(considering that time-indices are zero-centered).Default is `X_mid * 2` (identical to `c=2` in HSGP)
66+
""",
67+
),
68+
]
69+
| None
70+
) = None
71+
eta_lam: float = Field(1, gt=0, description="Exponential prior for the variance")
72+
ls_mu: float = Field(
73+
5, gt=0, description="Mean of the inverse gamma prior for the lengthscale"
74+
)
75+
ls_sigma: float = Field(
76+
5,
77+
gt=0,
78+
description="Standard deviation of the inverse gamma prior for the lengthscale",
79+
)
80+
cov_func: InstanceOf[pm.gp.cov.Covariance] | None = Field(
81+
None, description="Gaussian process Covariance function"
82+
)

pymc_marketing/mmm/budget_optimizer.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any
1818

1919
import numpy as np
20+
from pydantic import BaseModel, ConfigDict, Field
2021
from scipy.optimize import minimize
2122

2223
from pymc_marketing.mmm.components.adstock import AdstockTransformation
@@ -30,7 +31,7 @@ def __init__(self, message: str):
3031
super().__init__(message)
3132

3233

33-
class BudgetOptimizer:
34+
class BudgetOptimizer(BaseModel):
3435
"""
3536
A class for optimizing budget allocation in a marketing mix model.
3637
@@ -58,19 +59,21 @@ class BudgetOptimizer:
5859
Default is True.
5960
"""
6061

61-
def __init__(
62-
self,
63-
adstock: AdstockTransformation,
64-
saturation: SaturationTransformation,
65-
num_days: int,
66-
parameters: dict[str, dict[str, dict[str, float]]],
67-
adstock_first: bool = True,
68-
):
69-
self.adstock = adstock
70-
self.saturation = saturation
71-
self.num_days = num_days
72-
self.parameters = parameters
73-
self.adstock_first = adstock_first
62+
adstock: AdstockTransformation = Field(
63+
..., description="The adstock transformation class."
64+
)
65+
saturation: SaturationTransformation = Field(
66+
..., description="The saturation transformation class."
67+
)
68+
num_days: int = Field(..., gt=0, description="The number of days.")
69+
parameters: dict[str, dict[str, dict[str, float]]] = Field(
70+
..., description="A dictionary of parameters for each channel."
71+
)
72+
adstock_first: bool = Field(
73+
True,
74+
description="Whether to apply adstock transformation first or saturation transformation first.",
75+
)
76+
model_config = ConfigDict(arbitrary_types_allowed=True)
7477

7578
def objective(self, budgets: list[float]) -> float:
7679
"""

pymc_marketing/mmm/components/adstock.py

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def function(self, x, alpha):
5454

5555
import numpy as np
5656
import xarray as xr
57+
from pydantic import Field, InstanceOf, validate_call
5758

5859
from pymc_marketing.mmm.components.base import Transformation
5960
from pymc_marketing.mmm.transformers import (
@@ -81,13 +82,20 @@ class AdstockTransformation(Transformation):
8182
prefix: str = "adstock"
8283
lookup_name: str
8384

85+
@validate_call
8486
def __init__(
8587
self,
86-
l_max: int,
87-
normalize: bool = True,
88-
mode: ConvMode = ConvMode.After,
89-
priors: dict | None = None,
90-
prefix: str | None = None,
88+
l_max: int = Field(
89+
..., gt=0, description="Maximum lag for the adstock transformation."
90+
),
91+
normalize: bool = Field(
92+
True, description="Whether to normalize the adstock values."
93+
),
94+
mode: ConvMode = Field(ConvMode.After, description="Convolution mode."),
95+
priors: dict[str, str | InstanceOf[Prior]] | None = Field(
96+
default=None, description="Priors for the parameters."
97+
),
98+
prefix: str | None = Field(None, description="Prefix for the parameters."),
9199
) -> None:
92100
self.l_max = l_max
93101
self.normalize = normalize
@@ -368,16 +376,22 @@ def _get_adstock_function(
368376
if isinstance(function, AdstockTransformation):
369377
return function
370378

371-
if function not in ADSTOCK_TRANSFORMATIONS:
379+
elif isinstance(function, str):
380+
if function not in ADSTOCK_TRANSFORMATIONS:
381+
raise ValueError(
382+
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
383+
)
384+
385+
if kwargs:
386+
warnings.warn(
387+
"The preferred method of initializing a lagging function is to use the class directly.",
388+
DeprecationWarning,
389+
stacklevel=1,
390+
)
391+
392+
return ADSTOCK_TRANSFORMATIONS[function](**kwargs)
393+
394+
else:
372395
raise ValueError(
373396
f"Unknown adstock function: {function}. Choose from {list(ADSTOCK_TRANSFORMATIONS.keys())}"
374397
)
375-
376-
if kwargs:
377-
warnings.warn(
378-
"The preferred method of initializing a lagging function is to use the class directly.",
379-
DeprecationWarning,
380-
stacklevel=1,
381-
)
382-
383-
return ADSTOCK_TRANSFORMATIONS[function](**kwargs)

pymc_marketing/mmm/components/saturation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def function(self, x, b):
7171

7272
import numpy as np
7373
import xarray as xr
74+
from pydantic import Field, InstanceOf, validate_call
7475

7576
from pymc_marketing.mmm.components.base import Transformation
7677
from pymc_marketing.mmm.transformers import (
@@ -130,10 +131,13 @@ class InfiniteReturns(SaturationTransformation):
130131

131132
prefix: str = "saturation"
132133

134+
@validate_call
133135
def sample_curve(
134136
self,
135-
parameters: xr.Dataset,
136-
max_value: float = 1.0,
137+
parameters: InstanceOf[xr.Dataset] = Field(
138+
..., description="Parameters of the saturation transformation."
139+
),
140+
max_value: float = Field(1.0, gt=0, description="Maximum range value."),
137141
) -> xr.DataArray:
138142
"""Sample the curve of the saturation transformation given parameters.
139143

0 commit comments

Comments
 (0)