@@ -567,13 +567,17 @@ class InterventionTimeEstimator(PyMCModel):
567567 """
568568
569569 def __init__ (
570- self , time_variable_name : str , treatment_type_effect = None , sample_kwargs = None
570+ self ,
571+ time_variable_name : str ,
572+ treatment_effect_type : str | list [str ],
573+ treatment_effect_param = None ,
574+ sample_kwargs = None ,
571575 ):
572576 """
573577 Initializes the InterventionTimeEstimator model.
574578
575579 :param time_variable_name: name of the column representing time among the covariates
576- :param treatment_type_effect : Optional dictionary that specifies prior parameters for the
580+ :param treatment_effect_type : Optional dictionary that specifies prior parameters for the
577581 intervention effects. Expected keys are:
578582 - "level": [mu, sigma]
579583 - "trend": [mu, sigma]
@@ -584,13 +588,57 @@ def __init__(
584588 """
585589 self .time_variable_name = time_variable_name
586590
587- if treatment_type_effect is None :
588- treatment_type_effect = {}
589- if not isinstance (treatment_type_effect , dict ):
590- raise TypeError ("treatment_type_effect must be a dictionary." )
591-
592591 super ().__init__ (sample_kwargs )
593- self .treatment_type_effect = treatment_type_effect
592+
593+ # Hardcoded default priors
594+ self .DEFAULT_BETA_PRIOR = (0 , 5 )
595+ self .DEFAULT_LEVEL_PRIOR = (0 , 5 )
596+ self .DEFAULT_TREND_PRIOR = (0 , 0.5 )
597+ self .DEFAULT_IMPULSE_PRIOR = (0 , 5 , 5 )
598+
599+ # Make sure we get a list of all expected effects
600+ if isinstance (treatment_effect_type , str ):
601+ self .treatment_effect_type = [treatment_effect_type ]
602+ else :
603+ self .treatment_effect_type = treatment_effect_type
604+
605+ # Defining the priors here
606+ self .treatment_effect_param = (
607+ {} if treatment_effect_param is None else treatment_effect_param
608+ )
609+
610+ if "level" in self .treatment_effect_type :
611+ if (
612+ "level" not in self .treatment_effect_param
613+ or len (self .treatment_effect_param ["level" ]) != 2
614+ ):
615+ self .treatment_effect_param ["level" ] = self .DEFAULT_LEVEL_PRIOR
616+ else :
617+ self .treatment_effect_param ["level" ] = self .treatment_effect_param [
618+ "level"
619+ ]
620+
621+ if "trend" in self .treatment_effect_type :
622+ if (
623+ "trend" not in self .treatment_effect_param
624+ or len (self .treatment_effect_param ["trend" ]) != 2
625+ ):
626+ self .treatment_effect_param ["trend" ] = self .DEFAULT_TREND_PRIOR
627+ else :
628+ self .treatment_effect_param ["trend" ] = self .treatment_effect_param [
629+ "trend"
630+ ]
631+
632+ if "impulse" in self .treatment_effect_type :
633+ if (
634+ "impulse" not in self .treatment_effect_param
635+ or len (self .treatment_effect_param ["impulse" ]) != 3
636+ ):
637+ self .treatment_effect_param ["impulse" ] = self .DEFAULT_IMPULSE_PRIOR
638+ else :
639+ self .treatment_effect_param ["impulse" ] = self .treatment_effect_param [
640+ "impulse"
641+ ]
594642
595643 def build_model (self , X , y , coords ):
596644 """
@@ -603,12 +651,8 @@ def build_model(self, X, y, coords):
603651 Assumes the following attributes are already defined in self:
604652 - self.timeline: the index of the column in X representing time.
605653 - self.time_range: a tuple (lower_bound, upper_bound) for the intervention time.
606- - self.treatment_type_effect : a dictionary specifying which intervention effects to include and their priors.
654+ - self.treatment_effect_type : a dictionary specifying which intervention effects to include and their priors.
607655 """
608- DEFAULT_BETA_PRIOR = (0 , 5 )
609- DEFAULT_LEVEL_PRIOR = (0 , 5 )
610- DEFAULT_TREND_PRIOR = (0 , 0.5 )
611- DEFAULT_IMPULSE_PRIOR = (0 , 5 , 5 )
612656
613657 with self :
614658 self .add_coords (coords )
@@ -626,52 +670,37 @@ def build_model(self, X, y, coords):
626670 )
627671 beta = pm .Normal (
628672 name = "beta" ,
629- mu = DEFAULT_BETA_PRIOR [0 ],
630- sigma = DEFAULT_BETA_PRIOR [1 ],
673+ mu = self . DEFAULT_BETA_PRIOR [0 ],
674+ sigma = self . DEFAULT_BETA_PRIOR [1 ],
631675 dims = "coeffs" ,
632676 )
633677
634678 # --- Intervention effect ---
635679 mu_in_components = []
636680
637- if "level" in self .treatment_type_effect :
638- mu , sigma = (
639- DEFAULT_LEVEL_PRIOR
640- if len (self .treatment_type_effect ["level" ]) != 2
641- else (
642- self .treatment_type_effect ["level" ][0 ],
643- self .treatment_type_effect ["level" ][1 ],
644- )
645- )
681+ if "level" in self .treatment_effect_param :
646682 level = pm .Normal (
647683 "level" ,
648- mu = mu ,
649- sigma = sigma ,
684+ mu = self . treatment_effect_param [ "level" ][ 0 ] ,
685+ sigma = self . treatment_effect_param [ "level" ][ 1 ] ,
650686 )
651687 mu_in_components .append (level )
652- if "trend" in self .treatment_type_effect :
653- mu , sigma = (
654- DEFAULT_TREND_PRIOR
655- if len (self .treatment_type_effect ["trend" ]) != 2
656- else (
657- self .treatment_type_effect ["trend" ][0 ],
658- self .treatment_type_effect ["trend" ][1 ],
659- )
688+ if "trend" in self .treatment_effect_param :
689+ trend = pm .Normal (
690+ "trend" ,
691+ mu = self .treatment_effect_param ["trend" ][0 ],
692+ sigma = self .treatment_effect_param ["trend" ][1 ],
660693 )
661- trend = pm .Normal ("trend" , mu = mu , sigma = sigma )
662694 mu_in_components .append (trend * (t - treatment_time ))
663- if "impulse" in self .treatment_type_effect :
664- mu , sigma1 , sigma2 = (
665- DEFAULT_IMPULSE_PRIOR
666- if len (self .treatment_type_effect ["impulse" ]) != 3
667- else (
668- self .treatment_type_effect ["impulse" ][0 ],
669- self .treatment_type_effect ["impulse" ][1 ],
670- self .treatment_type_effect ["impulse" ][2 ],
671- )
695+ if "impulse" in self .treatment_effect_param :
696+ impulse_amplitude = pm .Normal (
697+ "impulse_amplitude" ,
698+ mu = self .treatment_effect_param ["impulse" ][0 ],
699+ sigma = self .treatment_effect_param ["impulse" ][1 ],
700+ )
701+ decay_rate = pm .HalfNormal (
702+ "decay_rate" , sigma = self .treatment_effect_param ["impulse" ][2 ]
672703 )
673- impulse_amplitude = pm .Normal ("impulse_amplitude" , mu = mu , sigma = sigma1 )
674- decay_rate = pm .HalfNormal ("decay_rate" , sigma = sigma2 )
675704 impulse = pm .Deterministic (
676705 "impulse" ,
677706 impulse_amplitude
@@ -687,7 +716,7 @@ def build_model(self, X, y, coords):
687716 mu_in = (
688717 pm .Deterministic (name = "mu_in" , var = sum (mu_in_components ))
689718 if len (mu_in_components ) > 0
690- else 0
719+ else pm . Data ( name = "mu_in" , vars = 0 )
691720 )
692721 # Compute and store the sum of the base time series and the intervention's effect
693722 mu_ts = pm .Deterministic ("mu_ts" , mu + weight * mu_in , dims = "obs_ind" )
0 commit comments