Skip to content

Commit e7e0562

Browse files
committed
Refactor EventStudy to use patsy formula for FEs
The EventStudy class now requires a patsy-style formula (e.g. "y ~ C(unit) + C(time)") to specify the outcome and fixed effects; the separate outcome_col argument is removed. The design matrix is built with patsy's dmatrices, and the event-time dummies are then appended to it. Input validation raises a FormulaException when the formula is empty or contains no "~", and the tests and documentation are updated to reflect the new API and output format.
1 parent e2c59f4 commit e7e0562

File tree

3 files changed

+132
-109
lines changed

3 files changed

+132
-109
lines changed

causalpy/experiments/event_study.py

Lines changed: 37 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
import pandas as pd
2121
import xarray as xr
2222
from matplotlib import pyplot as plt
23+
from patsy import dmatrices
2324
from sklearn.base import RegressorMixin
2425

25-
from causalpy.custom_exceptions import DataException
26+
from causalpy.custom_exceptions import DataException, FormulaException
2627
from causalpy.pymc_models import PyMCModel
2728
from causalpy.utils import round_num
2829

@@ -58,12 +59,15 @@ class EventStudy(BaseExperiment):
5859
----------
5960
data : pd.DataFrame
6061
Panel data with unit, time, outcome, and treatment time columns.
62+
formula : str
63+
A patsy-style formula specifying the model. Should include the outcome variable
64+
on the left-hand side and fixed effects on the right-hand side. Use ``C(column)``
65+
syntax for categorical fixed effects. Example: ``"y ~ C(unit) + C(time)"``.
66+
Event-time dummies are added automatically by the class.
6167
unit_col : str
62-
Name of the column identifying units.
68+
Name of the column identifying units (must match a term in the formula).
6369
time_col : str
64-
Name of the column identifying time periods.
65-
outcome_col : str
66-
Name of the outcome variable column.
70+
Name of the column identifying time periods (must match a term in the formula).
6771
treat_time_col : str
6872
Name of the column containing treatment time for each unit.
6973
Use NaN or np.inf for never-treated (control) units.
@@ -85,9 +89,9 @@ class EventStudy(BaseExperiment):
8589
... )
8690
>>> result = cp.EventStudy(
8791
... df,
92+
... formula="y ~ C(unit) + C(time)",
8893
... unit_col="unit",
8994
... time_col="time",
90-
... outcome_col="y",
9195
... treat_time_col="treat_time",
9296
... event_window=(-5, 5),
9397
... reference_event_time=-1,
@@ -109,9 +113,9 @@ class EventStudy(BaseExperiment):
109113
def __init__(
110114
self,
111115
data: pd.DataFrame,
116+
formula: str,
112117
unit_col: str,
113118
time_col: str,
114-
outcome_col: str,
115119
treat_time_col: str,
116120
event_window: tuple[int, int] = (-5, 5),
117121
reference_event_time: int = -1,
@@ -121,9 +125,9 @@ def __init__(
121125
super().__init__(model=model)
122126
self.data = data.copy()
123127
self.expt_type = "Event Study"
128+
self.formula = formula
124129
self.unit_col = unit_col
125130
self.time_col = time_col
126-
self.outcome_col = outcome_col
127131
self.treat_time_col = treat_time_col
128132
self.event_window = event_window
129133
self.reference_event_time = reference_event_time
@@ -156,14 +160,15 @@ def __init__(
156160
self._extract_event_time_coefficients()
157161

158162
def input_validation(self) -> None:
159-
"""Validate input data and parameters."""
160-
# Check required columns exist
161-
required_cols = [
162-
self.unit_col,
163-
self.time_col,
164-
self.outcome_col,
165-
self.treat_time_col,
166-
]
163+
"""Validate input data, formula, and parameters."""
164+
# Check formula is provided
165+
if not self.formula or "~" not in self.formula:
166+
raise FormulaException(
167+
"Formula must be provided in the form 'outcome ~ predictors'"
168+
)
169+
170+
# Check required columns exist for event time computation
171+
required_cols = [self.unit_col, self.time_col, self.treat_time_col]
167172
for col in required_cols:
168173
if col not in self.data.columns:
169174
raise DataException(f"Required column '{col}' not found in data")
@@ -209,28 +214,19 @@ def _compute_event_time(self) -> None:
209214
) & (self.data["_event_time"] <= self.event_window[1])
210215

211216
def _build_design_matrix(self) -> None:
212-
"""Build design matrix with unit FE, time FE, and event-time dummies."""
213-
# Get unique units and times
214-
units = sorted(self.data[self.unit_col].unique())
215-
times = sorted(self.data[self.time_col].unique())
216-
217-
# Reference categories (first unit and first time)
218-
ref_unit = units[0]
219-
ref_time = times[0]
220-
221-
# Build unit fixed effect dummies (excluding reference)
222-
unit_dummies = pd.get_dummies(
223-
self.data[self.unit_col], prefix="unit", dtype=float
224-
)
225-
unit_cols_to_keep = [c for c in unit_dummies.columns if c != f"unit_{ref_unit}"]
226-
unit_dummies = unit_dummies[unit_cols_to_keep]
227-
228-
# Build time fixed effect dummies (excluding reference)
229-
time_dummies = pd.get_dummies(
230-
self.data[self.time_col], prefix="time", dtype=float
217+
"""Build design matrix using patsy formula plus event-time dummies."""
218+
# Parse formula with patsy to get y and X (including FEs and covariates)
219+
y, X = dmatrices(self.formula, self.data)
220+
self._y_design_info = y.design_info
221+
self._x_design_info = X.design_info
222+
223+
# Extract outcome variable name from formula
224+
self.outcome_variable_name = y.design_info.column_names[0]
225+
226+
# Convert patsy output to DataFrames for manipulation
227+
X_df = pd.DataFrame(
228+
X, columns=X.design_info.column_names, index=self.data.index
231229
)
232-
time_cols_to_keep = [c for c in time_dummies.columns if c != f"time_{ref_time}"]
233-
time_dummies = time_dummies[time_cols_to_keep]
234230

235231
# Build event-time dummies (excluding reference event time)
236232
event_times = list(range(self.event_window[0], self.event_window[1] + 1))
@@ -245,9 +241,8 @@ def _build_design_matrix(self) -> None:
245241
(self.data["_event_time"] == k) & self.data["_in_event_window"]
246242
).astype(float)
247243

248-
# Combine all features: intercept + unit FE + time FE + event-time dummies
249-
X_df = pd.DataFrame({"intercept": 1.0}, index=self.data.index)
250-
X_df = pd.concat([X_df, unit_dummies, time_dummies, event_time_dummies], axis=1)
244+
# Combine patsy design matrix with event-time dummies
245+
X_df = pd.concat([X_df, event_time_dummies], axis=1)
251246

252247
self.labels = list(X_df.columns)
253248
self.event_time_labels = [f"event_time_{k}" for k in event_times_non_ref]
@@ -262,7 +257,7 @@ def _build_design_matrix(self) -> None:
262257
},
263258
)
264259

265-
y_values = np.asarray(self.data[self.outcome_col].values).reshape(-1, 1)
260+
y_values = np.asarray(y).reshape(-1, 1)
266261
self.y = xr.DataArray(
267262
y_values,
268263
dims=["obs_ind", "treated_units"],
@@ -311,6 +306,7 @@ def summary(self, round_to: int | None = 2) -> None:
311306
Number of decimals for rounding. Defaults to 2.
312307
"""
313308
print(f"{self.expt_type:=^80}")
309+
print(f"Formula: {self.formula}")
314310
print(f"Event window: {self.event_window}")
315311
print(f"Reference event time: {self.reference_event_time}")
316312
print("\nEvent-time coefficients (beta_k):")

causalpy/tests/test_event_study.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sklearn.linear_model import LinearRegression
2222

2323
import causalpy as cp
24-
from causalpy.custom_exceptions import DataException
24+
from causalpy.custom_exceptions import DataException, FormulaException
2525
from causalpy.data.simulate_data import generate_event_study_data
2626

2727
sample_kwargs = {"tune": 20, "draws": 20, "chains": 2, "cores": 2, "progressbar": False}
@@ -96,9 +96,26 @@ def test_event_study_missing_column():
9696
with pytest.raises(DataException, match="Required column 'treat_time' not found"):
9797
cp.EventStudy(
9898
df,
99+
formula="y ~ C(unit) + C(time)",
100+
unit_col="unit",
101+
time_col="time",
102+
treat_time_col="treat_time",
103+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
104+
)
105+
106+
107+
def test_event_study_missing_formula():
108+
"""Test that missing formula raises FormulaException."""
109+
df = pd.DataFrame(
110+
{"unit": [0, 1], "time": [0, 0], "y": [1.0, 2.0], "treat_time": [5.0, np.nan]}
111+
)
112+
113+
with pytest.raises(FormulaException, match="Formula must be provided"):
114+
cp.EventStudy(
115+
df,
116+
formula="", # Empty formula
99117
unit_col="unit",
100118
time_col="time",
101-
outcome_col="y",
102119
treat_time_col="treat_time",
103120
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
104121
)
@@ -111,9 +128,9 @@ def test_event_study_invalid_event_window():
111128
with pytest.raises(DataException, match="event_window\\[0\\].*must be less than"):
112129
cp.EventStudy(
113130
df,
131+
formula="y ~ C(unit) + C(time)",
114132
unit_col="unit",
115133
time_col="time",
116-
outcome_col="y",
117134
treat_time_col="treat_time",
118135
event_window=(5, -5), # Invalid: min > max
119136
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
@@ -127,9 +144,9 @@ def test_event_study_reference_outside_window():
127144
with pytest.raises(DataException, match="reference_event_time.*must be within"):
128145
cp.EventStudy(
129146
df,
147+
formula="y ~ C(unit) + C(time)",
130148
unit_col="unit",
131149
time_col="time",
132-
outcome_col="y",
133150
treat_time_col="treat_time",
134151
event_window=(-3, 3),
135152
reference_event_time=-5, # Outside window
@@ -151,9 +168,9 @@ def test_event_study_duplicate_observations():
151168
with pytest.raises(DataException, match="duplicate unit-time observations"):
152169
cp.EventStudy(
153170
df,
171+
formula="y ~ C(unit) + C(time)",
154172
unit_col="unit",
155173
time_col="time",
156-
outcome_col="y",
157174
treat_time_col="treat_time",
158175
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
159176
)
@@ -171,9 +188,9 @@ def test_event_study_pymc(mock_pymc_sample):
171188

172189
result = cp.EventStudy(
173190
df,
191+
formula="y ~ C(unit) + C(time)",
174192
unit_col="unit",
175193
time_col="time",
176-
outcome_col="y",
177194
treat_time_col="treat_time",
178195
event_window=(-5, 5),
179196
reference_event_time=-1,
@@ -199,9 +216,9 @@ def test_event_study_pymc_summary(mock_pymc_sample):
199216

200217
result = cp.EventStudy(
201218
df,
219+
formula="y ~ C(unit) + C(time)",
202220
unit_col="unit",
203221
time_col="time",
204-
outcome_col="y",
205222
treat_time_col="treat_time",
206223
event_window=(-3, 3),
207224
reference_event_time=-1,
@@ -226,9 +243,9 @@ def test_event_study_pymc_plot(mock_pymc_sample):
226243

227244
result = cp.EventStudy(
228245
df,
246+
formula="y ~ C(unit) + C(time)",
229247
unit_col="unit",
230248
time_col="time",
231-
outcome_col="y",
232249
treat_time_col="treat_time",
233250
event_window=(-3, 3),
234251
reference_event_time=-1,
@@ -248,9 +265,9 @@ def test_event_study_pymc_get_plot_data(mock_pymc_sample):
248265

249266
result = cp.EventStudy(
250267
df,
268+
formula="y ~ C(unit) + C(time)",
251269
unit_col="unit",
252270
time_col="time",
253-
outcome_col="y",
254271
treat_time_col="treat_time",
255272
event_window=(-3, 3),
256273
reference_event_time=-1,
@@ -272,9 +289,9 @@ def test_event_study_sklearn():
272289

273290
result = cp.EventStudy(
274291
df,
292+
formula="y ~ C(unit) + C(time)",
275293
unit_col="unit",
276294
time_col="time",
277-
outcome_col="y",
278295
treat_time_col="treat_time",
279296
event_window=(-5, 5),
280297
reference_event_time=-1,
@@ -298,9 +315,9 @@ def test_event_study_sklearn_summary():
298315

299316
result = cp.EventStudy(
300317
df,
318+
formula="y ~ C(unit) + C(time)",
301319
unit_col="unit",
302320
time_col="time",
303-
outcome_col="y",
304321
treat_time_col="treat_time",
305322
event_window=(-3, 3),
306323
reference_event_time=-1,
@@ -324,9 +341,9 @@ def test_event_study_sklearn_plot():
324341

325342
result = cp.EventStudy(
326343
df,
344+
formula="y ~ C(unit) + C(time)",
327345
unit_col="unit",
328346
time_col="time",
329-
outcome_col="y",
330347
treat_time_col="treat_time",
331348
event_window=(-3, 3),
332349
reference_event_time=-1,
@@ -365,9 +382,9 @@ def test_event_study_sklearn_recovers_effects():
365382

366383
result = cp.EventStudy(
367384
df,
385+
formula="y ~ C(unit) + C(time)",
368386
unit_col="unit",
369387
time_col="time",
370-
outcome_col="y",
371388
treat_time_col="treat_time",
372389
event_window=(-3, 3),
373390
reference_event_time=-1,
@@ -396,9 +413,9 @@ def test_event_study_narrow_event_window():
396413

397414
result = cp.EventStudy(
398415
df,
416+
formula="y ~ C(unit) + C(time)",
399417
unit_col="unit",
400418
time_col="time",
401-
outcome_col="y",
402419
treat_time_col="treat_time",
403420
event_window=(-1, 1),
404421
reference_event_time=-1,
@@ -419,9 +436,9 @@ def test_event_study_all_control_units():
419436
# The model should still run but event-time dummies will all be 0
420437
result = cp.EventStudy(
421438
df,
439+
formula="y ~ C(unit) + C(time)",
422440
unit_col="unit",
423441
time_col="time",
424-
outcome_col="y",
425442
treat_time_col="treat_time",
426443
event_window=(-3, 3),
427444
reference_event_time=-1,

0 commit comments

Comments
 (0)