Skip to content

Commit 4ebe1a7

Browse files
Rojan Shrestha
authored and committed
Added post_treatment_variable_name parameter and sklearn model summary for did
1 parent 09adfd7 commit 4ebe1a7

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

causalpy/experiments/diff_in_diff.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
from causalpy.custom_exceptions import (
2828
DataException,
29-
FormulaException,
3029
)
3130
from causalpy.plot_utils import plot_xY
3231
from causalpy.pymc_models import PyMCModel
@@ -84,6 +83,7 @@ def __init__(
8483
formula: str,
8584
time_variable_name: str,
8685
group_variable_name: str,
86+
post_treatment_variable_name: str = "post_treatment",
8787
model=None,
8888
**kwargs,
8989
) -> None:
@@ -95,6 +95,7 @@ def __init__(
9595
self.formula = formula
9696
self.time_variable_name = time_variable_name
9797
self.group_variable_name = group_variable_name
98+
self.post_treatment_variable_name = post_treatment_variable_name
9899
self.input_validation()
99100

100101
y, X = dmatrices(formula, self.data)
@@ -128,6 +129,12 @@ def __init__(
128129
}
129130
self.model.fit(X=self.X, y=self.y, coords=COORDS)
130131
elif isinstance(self.model, RegressorMixin):
132+
# For scikit-learn models, automatically set fit_intercept=False
133+
# This ensures the intercept is included in the coefficients array rather than being a separate intercept_ attribute
134+
# Without this, the intercept is not included in the coefficients array and would therefore be displayed as 0 in the model summary.
135+
# TODO: later, this should be handled in ScikitLearnAdaptor itself
136+
if hasattr(self.model, "fit_intercept"):
137+
self.model.fit_intercept = False
131138
self.model.fit(X=self.X, y=self.y)
132139
else:
133140
raise ValueError("Model type not recognized")
@@ -173,7 +180,7 @@ def __init__(
173180
# just the treated group
174181
.query(f"{self.group_variable_name} == 1")
175182
# just the treatment period(s)
176-
.query("post_treatment == True")
183+
.query(f"{self.post_treatment_variable_name} == True")
177184
# drop the outcome variable
178185
.drop(self.outcome_variable_name, axis=1)
179186
# We may have multiple units per time point, we only want one time point
@@ -189,7 +196,10 @@ def __init__(
189196
# INTERVENTION: set the interaction term between the group and the
190197
# post_treatment variable to zero. This is the counterfactual.
191198
for i, label in enumerate(self.labels):
192-
if "post_treatment" in label and self.group_variable_name in label:
199+
if (
200+
self.post_treatment_variable_name in label
201+
and self.group_variable_name in label
202+
):
193203
new_x.iloc[:, i] = 0
194204
self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))
195205

@@ -198,32 +208,53 @@ def __init__(
198208
# This is the coefficient on the interaction term
199209
coeff_names = self.model.idata.posterior.coords["coeffs"].data
200210
for i, label in enumerate(coeff_names):
201-
if "post_treatment" in label and self.group_variable_name in label:
211+
if (
212+
self.post_treatment_variable_name in label
213+
and self.group_variable_name in label
214+
):
202215
self.causal_impact = self.model.idata.posterior["beta"].isel(
203216
{"coeffs": i}
204217
)
205218
elif isinstance(self.model, RegressorMixin):
206219
# This is the coefficient on the interaction term
207-
# TODO: CHECK FOR CORRECTNESS
208-
self.causal_impact = (
209-
self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
210-
).item()
220+
# Map each design-matrix label to its fitted coefficient: {label: value}
221+
coef_map = dict(zip(self.labels, self.model.get_coeffs()))
222+
# Build the interaction-term name from the user-provided variable names, then find the coefficient label that contains it
223+
interaction_term = (
224+
f"{self.group_variable_name}:{self.post_treatment_variable_name}"
225+
)
226+
matched_key = next((k for k in coef_map if interaction_term in k), None)
227+
att = coef_map.get(matched_key)
228+
self.causal_impact = att
211229
else:
212230
raise ValueError("Model type not recognized")
213231

214232
return
215233

216234
def input_validation(self):
217235
"""Validate the input data and model formula for correctness"""
218-
if "post_treatment" not in self.formula:
219-
raise FormulaException(
220-
"A predictor called `post_treatment` should be in the formula"
221-
)
222-
223-
if "post_treatment" not in self.data.columns:
224-
raise DataException(
225-
"Require a boolean column labelling observations which are `treated`"
226-
)
236+
if (
237+
self.post_treatment_variable_name not in self.formula
238+
or self.post_treatment_variable_name not in self.data.columns
239+
):
240+
if self.post_treatment_variable_name == "post_treatment":
241+
# Default case - user didn't specify custom name, so guide them to use "post_treatment"
242+
raise DataException(
243+
"Missing 'post_treatment' in formula or dataset.\n"
244+
"Note: post_treatment_variable_name might have been set to 'post_treatment' by default.\n"
245+
"1) Add 'post_treatment' to formula (e.g., 'y ~ 1 + group*post_treatment')\n"
246+
"2) and ensure dataset has boolean column 'post_treatment'.\n"
247+
"To use custom name, provide additional argument post_treatment_variable_name='your_post_treatment_variable_name'."
248+
)
249+
else:
250+
# Custom case - user specified custom name, so remind them what they specified
251+
raise DataException(
252+
f"Missing required variable '{self.post_treatment_variable_name}' in formula or dataset.\n\n"
253+
f"Since you specified post_treatment_variable_name='{self.post_treatment_variable_name}', "
254+
f"please ensure:\n"
255+
f"1) formula includes '{self.post_treatment_variable_name}'\n"
256+
f"2) dataset has boolean column named '{self.post_treatment_variable_name}'"
257+
)
227258

228259
if "unit" not in self.data.columns:
229260
raise DataException(

docs/source/_static/interrogate_badge.svg

Lines changed: 4 additions & 4 deletions
Loading

0 commit comments

Comments
 (0)