1313# limitations under the License.
1414"""Additive effects for the multidimensional Marketing Mix Model."""
1515
16- from typing import Protocol
16+ from typing import Any , Protocol
1717
1818import pandas as pd
1919import pymc as pm
2020import xarray as xr
21+ from pydantic import BaseModel , InstanceOf
2122from pytensor import tensor as pt
2223
2324from pymc_marketing .mmm .events import EventEffect , days_from_reference
@@ -318,14 +319,8 @@ def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
318319 pm .set_data ({f"{ self .prefix } _t" : t }, model = model )
319320
320321
321- def create_event_mu_effect (
322- df_events : pd .DataFrame ,
323- prefix : str ,
324- effect : EventEffect ,
325- ) -> MuEffect :
326- """Create an event effect for the MMM.
327-
328- This class has the ability to create data and mean effects for the MMM model.
322+ class EventAdditiveEffect (BaseModel ):
323+ """Event effect class for the MMM.
329324
330325 Parameters
331326 ----------
@@ -338,105 +333,112 @@ def create_event_mu_effect(
338333 The prefix to use for the event effect and associated variables.
339334 effect : EventEffect
340335 The event effect to apply.
341-
342- Returns
343- -------
344- MuEffect
345- The event effect which is used in the MMM.
336+ reference_date : str
337+ The arbitrary reference date to calculate distance from events in days. Default
338+ is "2025-01-01".
346339
347340 """
348- if missing_columns := set (["start_date" , "end_date" , "name" ]).difference (
349- df_events .columns ,
350- ):
351- raise ValueError (f"Columns { missing_columns } are missing in df_events." )
352341
353- effect .basis .prefix = prefix
342+ df_events : InstanceOf [pd .DataFrame ]
343+ prefix : str
344+ effect : EventEffect
345+ reference_date : str = "2025-01-01"
354346
355- reference_date = "2025-01-01"
356- start_dates = pd .to_datetime (df_events ["start_date" ])
357- end_dates = pd .to_datetime (df_events ["end_date" ])
347+ def model_post_init (self , context : Any , / ) -> None :
348+ """Post initialization of the model."""
349+ if missing_columns := set (["start_date" , "end_date" , "name" ]).difference (
350+ self .df_events .columns
351+ ):
352+ raise ValueError (f"Columns { missing_columns } are missing in df_events." )
358353
359- class Effect :
360- """Event effect class for the MMM."""
354+ self .effect .basis .prefix = self .prefix
361355
362- def create_data (self , mmm : MMM ) -> None :
363- """Create the required data in the model.
356+ @property
357+ def start_dates (self ) -> pd .Series :
358+ """The start dates of the events."""
359+ return pd .to_datetime (self .df_events ["start_date" ])
364360
365- Parameters
366- ----------
367- mmm : MMM
368- The MMM model instance.
361+ @ property
362+ def end_dates ( self ) -> pd . Series :
363+ """The end dates of the events."""
364+ return pd . to_datetime ( self . df_events [ "end_date" ])
369365
370- """
371- model : pm . Model = mmm . model
366+ def create_data ( self , mmm : MMM ) -> None :
367+ """Create the required data in the model.
372368
373- model_dates = pd .to_datetime (model .coords ["date" ])
369+ Parameters
370+ ----------
371+ mmm : MMM
372+ The MMM model instance.
374373
375- model .add_coord (prefix , df_events ["name" ].to_numpy ())
374+ """
375+ model : pm .Model = mmm .model
376376
377- if "days" not in model :
378- pm .Data (
379- "days" ,
380- days_from_reference (model_dates , reference_date ),
381- dims = "date" ,
382- )
377+ model_dates = pd .to_datetime (model .coords ["date" ])
383378
379+ model .add_coord (self .prefix , self .df_events ["name" ].to_numpy ())
380+
381+ if "days" not in model :
384382 pm .Data (
385- f"{ prefix } _start_diff" ,
386- days_from_reference (start_dates , reference_date ),
387- dims = prefix ,
388- )
389- pm .Data (
390- f"{ prefix } _end_diff" ,
391- days_from_reference (end_dates , reference_date ),
392- dims = prefix ,
383+ "days" ,
384+ days_from_reference (model_dates , self .reference_date ),
385+ dims = "date" ,
393386 )
394387
395- def create_effect (self , mmm : MMM ) -> pt .TensorVariable :
396- """Create the event effect in the model.
397-
398- Parameters
399- ----------
400- mmm : MMM
401- The MMM model instance.
388+ pm .Data (
389+ f"{ self .prefix } _start_diff" ,
390+ days_from_reference (self .start_dates , self .reference_date ),
391+ dims = self .prefix ,
392+ )
393+ pm .Data (
394+ f"{ self .prefix } _end_diff" ,
395+ days_from_reference (self .end_dates , self .reference_date ),
396+ dims = self .prefix ,
397+ )
402398
403- Returns
404- -------
405- pt.TensorVariable
406- The average event effect in the model.
399+ def create_effect (self , mmm : MMM ) -> pt .TensorVariable :
400+ """Create the event effect in the model.
407401
408- """
409- model : pm .Model = mmm .model
402+ Parameters
403+ ----------
404+ mmm : MMM
405+ The MMM model instance.
410406
411- s_ref = model ["days" ][:, None ] - model [f"{ prefix } _start_diff" ]
412- e_ref = model ["days" ][:, None ] - model [f"{ prefix } _end_diff" ]
407+ Returns
408+ -------
409+ pt.TensorVariable
410+ The average event effect in the model.
413411
414- def create_basis_matrix (s_ref , e_ref ):
415- return pt .where (
416- (s_ref >= 0 ) & (e_ref <= 0 ),
417- 0 ,
418- pt .where (pt .abs (s_ref ) < pt .abs (e_ref ), s_ref , e_ref ),
419- )
412+ """
413+ model : pm .Model = mmm .model
420414
421- X = create_basis_matrix ( s_ref , e_ref )
422- event_effect = effect . apply ( X , name = prefix )
415+ start_ref = model [ "days" ][:, None ] - model [ f" { self . prefix } _start_diff" ]
416+ end_ref = model [ "days" ][:, None ] - model [ f" { self . prefix } _end_diff" ]
423417
424- total_effect = pm .Deterministic (
425- f"{ prefix } _total_effect" ,
426- event_effect .sum (axis = 1 ),
427- dims = "date" ,
418+ def create_basis_matrix (start_ref , end_ref ):
419+ return pt .where (
420+ (start_ref >= 0 ) & (end_ref <= 0 ),
421+ 0 ,
422+ pt .where (pt .abs (start_ref ) < pt .abs (end_ref ), start_ref , end_ref ),
428423 )
429424
430- dim_handler = create_dim_handler (( "date" , * mmm . dims ) )
431- return dim_handler ( total_effect , "date" )
425+ X = create_basis_matrix ( start_ref , end_ref )
426+ event_effect = self . effect . apply ( X , name = self . prefix )
432427
433- def set_data (self , mmm : MMM , model : pm .Model , X : xr .Dataset ) -> None :
434- """Set the data for new predictions."""
435- new_dates = pd .to_datetime (model .coords ["date" ])
428+ total_effect = pm .Deterministic (
429+ f"{ self .prefix } _total_effect" ,
430+ event_effect .sum (axis = 1 ),
431+ dims = "date" ,
432+ )
436433
437- new_data = {
438- "days" : days_from_reference (new_dates , reference_date ),
439- }
440- pm .set_data (new_data = new_data , model = model )
434+ dim_handler = create_dim_handler (("date" , * mmm .dims ))
435+ return dim_handler (total_effect , "date" )
436+
437+ def set_data (self , mmm : MMM , model : pm .Model , X : xr .Dataset ) -> None :
438+ """Set the data for new predictions."""
439+ new_dates = pd .to_datetime (model .coords ["date" ])
441440
442- return Effect ()
441+ new_data = {
442+ "days" : days_from_reference (new_dates , self .reference_date ),
443+ }
444+ pm .set_data (new_data = new_data , model = model )
0 commit comments