Skip to content

Commit 411aac7

Browse files
committed
Updating treatment type effect input
1 parent d0f4a58 commit 411aac7

File tree

1 file changed

+76
-47
lines changed

1 file changed

+76
-47
lines changed

causalpy/pymc_models.py

Lines changed: 76 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -567,13 +567,17 @@ class InterventionTimeEstimator(PyMCModel):
567567
"""
568568

569569
def __init__(
570-
self, time_variable_name: str, treatment_type_effect=None, sample_kwargs=None
570+
self,
571+
time_variable_name: str,
572+
treatment_effect_type: str | list[str],
573+
treatment_effect_param=None,
574+
sample_kwargs=None,
571575
):
572576
"""
573577
Initializes the InterventionTimeEstimator model.
574578
575579
:param time_variable_name: name of the column representing time among the covariates
576-
:param treatment_type_effect: Optional dictionary that specifies prior parameters for the
580+
:param treatment_effect_type: Optional dictionary that specifies prior parameters for the
577581
intervention effects. Expected keys are:
578582
- "level": [mu, sigma]
579583
- "trend": [mu, sigma]
@@ -584,13 +588,57 @@ def __init__(
584588
"""
585589
self.time_variable_name = time_variable_name
586590

587-
if treatment_type_effect is None:
588-
treatment_type_effect = {}
589-
if not isinstance(treatment_type_effect, dict):
590-
raise TypeError("treatment_type_effect must be a dictionary.")
591-
592591
super().__init__(sample_kwargs)
593-
self.treatment_type_effect = treatment_type_effect
592+
593+
# Hardcoded default priors
594+
self.DEFAULT_BETA_PRIOR = (0, 5)
595+
self.DEFAULT_LEVEL_PRIOR = (0, 5)
596+
self.DEFAULT_TREND_PRIOR = (0, 0.5)
597+
self.DEFAULT_IMPULSE_PRIOR = (0, 5, 5)
598+
599+
# Make sure we get a list of all expected effects
600+
if isinstance(treatment_effect_type, str):
601+
self.treatment_effect_type = [treatment_effect_type]
602+
else:
603+
self.treatment_effect_type = treatment_effect_type
604+
605+
# Defining the priors here
606+
self.treatment_effect_param = (
607+
{} if treatment_effect_param is None else treatment_effect_param
608+
)
609+
610+
if "level" in self.treatment_effect_type:
611+
if (
612+
"level" not in self.treatment_effect_param
613+
or len(self.treatment_effect_param["level"]) != 2
614+
):
615+
self.treatment_effect_param["level"] = self.DEFAULT_LEVEL_PRIOR
616+
else:
617+
self.treatment_effect_param["level"] = self.treatment_effect_param[
618+
"level"
619+
]
620+
621+
if "trend" in self.treatment_effect_type:
622+
if (
623+
"trend" not in self.treatment_effect_param
624+
or len(self.treatment_effect_param["trend"]) != 2
625+
):
626+
self.treatment_effect_param["trend"] = self.DEFAULT_TREND_PRIOR
627+
else:
628+
self.treatment_effect_param["trend"] = self.treatment_effect_param[
629+
"trend"
630+
]
631+
632+
if "impulse" in self.treatment_effect_type:
633+
if (
634+
"impulse" not in self.treatment_effect_param
635+
or len(self.treatment_effect_param["impulse"]) != 3
636+
):
637+
self.treatment_effect_param["impulse"] = self.DEFAULT_IMPULSE_PRIOR
638+
else:
639+
self.treatment_effect_param["impulse"] = self.treatment_effect_param[
640+
"impulse"
641+
]
594642

595643
def build_model(self, X, y, coords):
596644
"""
@@ -603,12 +651,8 @@ def build_model(self, X, y, coords):
603651
Assumes the following attributes are already defined in self:
604652
- self.timeline: the index of the column in X representing time.
605653
- self.time_range: a tuple (lower_bound, upper_bound) for the intervention time.
606-
- self.treatment_type_effect: a dictionary specifying which intervention effects to include and their priors.
654+
- self.treatment_effect_type: a dictionary specifying which intervention effects to include and their priors.
607655
"""
608-
DEFAULT_BETA_PRIOR = (0, 5)
609-
DEFAULT_LEVEL_PRIOR = (0, 5)
610-
DEFAULT_TREND_PRIOR = (0, 0.5)
611-
DEFAULT_IMPULSE_PRIOR = (0, 5, 5)
612656

613657
with self:
614658
self.add_coords(coords)
@@ -626,52 +670,37 @@ def build_model(self, X, y, coords):
626670
)
627671
beta = pm.Normal(
628672
name="beta",
629-
mu=DEFAULT_BETA_PRIOR[0],
630-
sigma=DEFAULT_BETA_PRIOR[1],
673+
mu=self.DEFAULT_BETA_PRIOR[0],
674+
sigma=self.DEFAULT_BETA_PRIOR[1],
631675
dims="coeffs",
632676
)
633677

634678
# --- Intervention effect ---
635679
mu_in_components = []
636680

637-
if "level" in self.treatment_type_effect:
638-
mu, sigma = (
639-
DEFAULT_LEVEL_PRIOR
640-
if len(self.treatment_type_effect["level"]) != 2
641-
else (
642-
self.treatment_type_effect["level"][0],
643-
self.treatment_type_effect["level"][1],
644-
)
645-
)
681+
if "level" in self.treatment_effect_param:
646682
level = pm.Normal(
647683
"level",
648-
mu=mu,
649-
sigma=sigma,
684+
mu=self.treatment_effect_param["level"][0],
685+
sigma=self.treatment_effect_param["level"][1],
650686
)
651687
mu_in_components.append(level)
652-
if "trend" in self.treatment_type_effect:
653-
mu, sigma = (
654-
DEFAULT_TREND_PRIOR
655-
if len(self.treatment_type_effect["trend"]) != 2
656-
else (
657-
self.treatment_type_effect["trend"][0],
658-
self.treatment_type_effect["trend"][1],
659-
)
688+
if "trend" in self.treatment_effect_param:
689+
trend = pm.Normal(
690+
"trend",
691+
mu=self.treatment_effect_param["trend"][0],
692+
sigma=self.treatment_effect_param["trend"][1],
660693
)
661-
trend = pm.Normal("trend", mu=mu, sigma=sigma)
662694
mu_in_components.append(trend * (t - treatment_time))
663-
if "impulse" in self.treatment_type_effect:
664-
mu, sigma1, sigma2 = (
665-
DEFAULT_IMPULSE_PRIOR
666-
if len(self.treatment_type_effect["impulse"]) != 3
667-
else (
668-
self.treatment_type_effect["impulse"][0],
669-
self.treatment_type_effect["impulse"][1],
670-
self.treatment_type_effect["impulse"][2],
671-
)
695+
if "impulse" in self.treatment_effect_param:
696+
impulse_amplitude = pm.Normal(
697+
"impulse_amplitude",
698+
mu=self.treatment_effect_param["impulse"][0],
699+
sigma=self.treatment_effect_param["impulse"][1],
700+
)
701+
decay_rate = pm.HalfNormal(
702+
"decay_rate", sigma=self.treatment_effect_param["impulse"][2]
672703
)
673-
impulse_amplitude = pm.Normal("impulse_amplitude", mu=mu, sigma=sigma1)
674-
decay_rate = pm.HalfNormal("decay_rate", sigma=sigma2)
675704
impulse = pm.Deterministic(
676705
"impulse",
677706
impulse_amplitude
@@ -687,7 +716,7 @@ def build_model(self, X, y, coords):
687716
mu_in = (
688717
pm.Deterministic(name="mu_in", var=sum(mu_in_components))
689718
if len(mu_in_components) > 0
690-
else 0
719+
else pm.Data(name="mu_in", vars=0)
691720
)
692721
# Compute and store the sum of the base time series and the intervention's effect
693722
mu_ts = pm.Deterministic("mu_ts", mu + weight * mu_in, dims="obs_ind")

0 commit comments

Comments
 (0)