Skip to content

Commit 8600ef3

Browse files
Allow VariableFactory in FourierBase (#1304)
* Allow VariableFactory in FourierBase * change docstrings and loosen serialization * tests for changes --------- Co-authored-by: Will Dean <[email protected]> Co-authored-by: Will Dean <[email protected]>
1 parent b3f740d commit 8600ef3

File tree

2 files changed

+60
-10
lines changed

2 files changed

+60
-10
lines changed

pymc_marketing/mmm/fourier.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@
223223

224224
from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR
225225
from pymc_marketing.plot import SelToString, plot_curve, plot_hdi, plot_samples
226-
from pymc_marketing.prior import Prior, create_dim_handler
226+
from pymc_marketing.prior import Prior, VariableFactory, create_dim_handler
227227

228228
X_NAME: str = "day"
229229
NON_GRID_NAMES: frozenset[str] = frozenset({X_NAME})
@@ -274,8 +274,8 @@ class FourierBase(BaseModel):
274274
prefix : str, optional
275275
Alternative prefix for the fourier seasonality, by default None or
276276
"fourier"
277-
prior : Prior, optional
278-
Prior distribution for the fourier seasonality beta parameters, by
277+
prior : Prior | VariableFactory, optional
278+
Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
279279
default `Prior("Laplace", mu=0, b=1)`
280280
variable_name : str, optional
281281
Name of the variable that multiplies the fourier modes. By default None,
@@ -286,17 +286,21 @@ class FourierBase(BaseModel):
286286
n_order: int = Field(..., gt=0)
287287
days_in_period: float = Field(..., gt=0)
288288
prefix: str = Field("fourier")
289-
prior: InstanceOf[Prior] = Field(Prior("Laplace", mu=0, b=1))
289+
prior: InstanceOf[Prior] | InstanceOf[VariableFactory] = Field(
290+
Prior("Laplace", mu=0, b=1)
291+
)
290292
variable_name: str | None = Field(None)
291293

292294
def model_post_init(self, __context: Any) -> None:
293295
"""Model post initialization for a Pydantic model."""
294296
if self.variable_name is None:
295297
self.variable_name = f"{self.prefix}_beta"
296298

297-
if not self.prior.dims:
299+
if not self.prior.dims and isinstance(self.prior, Prior):
298300
self.prior = self.prior.deepcopy()
299301
self.prior.dims = self.prefix
302+
elif not self.prior.dims:
303+
self.prior.dims = self.prefix
300304

301305
@model_validator(mode="after")
302306
def _check_variable_name(self) -> Self:
@@ -311,7 +315,7 @@ def _check_prior_has_right_dimensions(self) -> Self:
311315
return self
312316

313317
@field_serializer("prior", when_used="json")
314-
def serialize_prior(prior: Prior) -> dict[str, Any]:
318+
def serialize_prior(prior: Any) -> dict[str, Any]:
315319
"""Serialize the prior distribution.
316320
317321
Parameters
@@ -325,6 +329,9 @@ def serialize_prior(prior: Prior) -> dict[str, Any]:
325329
The serialized prior distribution.
326330
327331
"""
332+
if hasattr(prior, "to_dict"):
333+
return prior.to_dict()
334+
328335
return prior.to_json()
329336

330337
@property
@@ -718,8 +725,8 @@ class YearlyFourier(FourierBase):
718725
prefix : str, optional
719726
Alternative prefix for the fourier seasonality, by default None or
720727
"fourier"
721-
prior : Prior, optional
722-
Prior distribution for the fourier seasonality beta parameters, by
728+
prior : Prior | VariableFactory, optional
729+
Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
723730
default `Prior("Laplace", mu=0, b=1)`
724731
name : str, optional
725732
Name of the variable that multiplies the fourier modes, by default None
@@ -774,8 +781,8 @@ class MonthlyFourier(FourierBase):
774781
prefix : str, optional
775782
Alternative prefix for the fourier seasonality, by default None or
776783
"fourier"
777-
prior : Prior, optional
778-
Prior distribution for the fourier seasonality beta parameters, by
784+
prior : Prior | VariableFactory, optional
785+
Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
779786
default `Prior("Laplace", mu=0, b=1)`
780787
name : str, optional
781788
Name of the variable that multiplies the fourier modes, by default None

tests/mmm/test_fourier.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,46 @@ def test_fourier_base_instantiation():
358358
prior=Prior("Laplace", mu=0, b=1, dims="fourier"),
359359
)
360360
assert "Can't instantiate abstract class FourierBase" in str(exc_info.value)
361+
362+
363+
class ArbitraryCode:
364+
def __init__(self, dims: tuple[str, ...]) -> None:
365+
self.dims = dims
366+
367+
def create_variable(self, name: str):
368+
return pm.Normal(name, dims=self.dims)
369+
370+
371+
def test_fourier_arbitrary_prior() -> None:
372+
prior = ArbitraryCode(dims=("fourier",))
373+
fourier = YearlyFourier(n_order=4, prior=prior)
374+
375+
x = np.arange(10)
376+
with pm.Model():
377+
y = fourier.apply(x)
378+
379+
assert y.eval().shape == (10,)
380+
381+
382+
def test_fourier_dims_modified() -> None:
383+
prior = ArbitraryCode(dims=())
384+
YearlyFourier(n_order=4, prior=prior)
385+
assert prior.dims == "fourier"
386+
387+
388+
class SerializableArbitraryCode(ArbitraryCode):
389+
def to_dict(self):
390+
return {"dims": self.dims, "msg": "Hello, World!"}
391+
392+
393+
def test_fourier_serializable_arbitrary_prior() -> None:
394+
prior = SerializableArbitraryCode(dims=("fourier",))
395+
fourier = YearlyFourier(n_order=4, prior=prior)
396+
397+
assert fourier.model_dump(mode="json") == {
398+
"n_order": 4,
399+
"days_in_period": 365.25,
400+
"prefix": "fourier",
401+
"prior": {"dims": ["fourier"], "msg": "Hello, World!"},
402+
"variable_name": "fourier_beta",
403+
}

0 commit comments

Comments
 (0)