223
223
224
224
from pymc_marketing .constants import DAYS_IN_MONTH , DAYS_IN_YEAR
225
225
from pymc_marketing .plot import SelToString , plot_curve , plot_hdi , plot_samples
226
- from pymc_marketing .prior import Prior , create_dim_handler
226
+ from pymc_marketing .prior import Prior , VariableFactory , create_dim_handler
227
227
228
228
X_NAME : str = "day"
229
229
NON_GRID_NAMES : frozenset [str ] = frozenset ({X_NAME })
@@ -274,8 +274,8 @@ class FourierBase(BaseModel):
274
274
prefix : str, optional
275
275
Alternative prefix for the fourier seasonality, by default None or
276
276
"fourier"
277
- prior : Prior, optional
278
- Prior distribution for the fourier seasonality beta parameters, by
277
+ prior : Prior | VariableFactory , optional
278
+ Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
279
279
default `Prior("Laplace", mu=0, b=1)`
280
280
variable_name : str, optional
281
281
Name of the variable that multiplies the fourier modes. By default None,
@@ -286,17 +286,21 @@ class FourierBase(BaseModel):
286
286
n_order : int = Field (..., gt = 0 )
287
287
days_in_period : float = Field (..., gt = 0 )
288
288
prefix : str = Field ("fourier" )
289
- prior : InstanceOf [Prior ] = Field (Prior ("Laplace" , mu = 0 , b = 1 ))
289
+ prior : InstanceOf [Prior ] | InstanceOf [VariableFactory ] = Field (
290
+ Prior ("Laplace" , mu = 0 , b = 1 )
291
+ )
290
292
variable_name : str | None = Field (None )
291
293
292
294
def model_post_init (self , __context : Any ) -> None :
293
295
"""Model post initialization for a Pydantic model."""
294
296
if self .variable_name is None :
295
297
self .variable_name = f"{ self .prefix } _beta"
296
298
297
- if not self .prior .dims :
299
+ if not self .prior .dims and isinstance ( self . prior , Prior ) :
298
300
self .prior = self .prior .deepcopy ()
299
301
self .prior .dims = self .prefix
302
+ elif not self .prior .dims :
303
+ self .prior .dims = self .prefix
300
304
301
305
@model_validator (mode = "after" )
302
306
def _check_variable_name (self ) -> Self :
@@ -311,7 +315,7 @@ def _check_prior_has_right_dimensions(self) -> Self:
311
315
return self
312
316
313
317
@field_serializer ("prior" , when_used = "json" )
314
- def serialize_prior (prior : Prior ) -> dict [str , Any ]:
318
+ def serialize_prior (prior : Any ) -> dict [str , Any ]:
315
319
"""Serialize the prior distribution.
316
320
317
321
Parameters
@@ -325,6 +329,9 @@ def serialize_prior(prior: Prior) -> dict[str, Any]:
325
329
The serialized prior distribution.
326
330
327
331
"""
332
+ if hasattr (prior , "to_dict" ):
333
+ return prior .to_dict ()
334
+
328
335
return prior .to_json ()
329
336
330
337
@property
@@ -718,8 +725,8 @@ class YearlyFourier(FourierBase):
718
725
prefix : str, optional
719
726
Alternative prefix for the fourier seasonality, by default None or
720
727
"fourier"
721
- prior : Prior, optional
722
- Prior distribution for the fourier seasonality beta parameters, by
728
+ prior : Prior | VariableFactory , optional
729
+ Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
723
730
default `Prior("Laplace", mu=0, b=1)`
724
731
name : str, optional
725
732
Name of the variable that multiplies the fourier modes, by default None
@@ -774,8 +781,8 @@ class MonthlyFourier(FourierBase):
774
781
prefix : str, optional
775
782
Alternative prefix for the fourier seasonality, by default None or
776
783
"fourier"
777
- prior : Prior, optional
778
- Prior distribution for the fourier seasonality beta parameters, by
784
+ prior : Prior | VariableFactory , optional
785
+ Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
779
786
default `Prior("Laplace", mu=0, b=1)`
780
787
name : str, optional
781
788
Name of the variable that multiplies the fourier modes, by default None
0 commit comments