2020import pandas as pd
2121import xarray as xr
2222from matplotlib import pyplot as plt
23+ from patsy import dmatrices
2324from sklearn .base import RegressorMixin
2425
25- from causalpy .custom_exceptions import DataException
26+ from causalpy .custom_exceptions import DataException , FormulaException
2627from causalpy .pymc_models import PyMCModel
2728from causalpy .utils import round_num
2829
@@ -58,12 +59,15 @@ class EventStudy(BaseExperiment):
5859 ----------
5960 data : pd.DataFrame
6061 Panel data with unit, time, outcome, and treatment time columns.
62+ formula : str
63+ A patsy-style formula specifying the model. Should include the outcome variable
64+ on the left-hand side and fixed effects on the right-hand side. Use ``C(column)``
65+ syntax for categorical fixed effects. Example: ``"y ~ C(unit) + C(time)"``.
66+ Event-time dummies are added automatically by the class.
6167 unit_col : str
62- Name of the column identifying units.
68+ Name of the column identifying units (must match a term in the formula).
6369 time_col : str
64- Name of the column identifying time periods.
65- outcome_col : str
66- Name of the outcome variable column.
70+ Name of the column identifying time periods (must match a term in the formula).
6771 treat_time_col : str
6872 Name of the column containing treatment time for each unit.
6973 Use NaN or np.inf for never-treated (control) units.
@@ -85,9 +89,9 @@ class EventStudy(BaseExperiment):
8589 ... )
8690 >>> result = cp.EventStudy(
8791 ... df,
92+ ... formula="y ~ C(unit) + C(time)",
8893 ... unit_col="unit",
8994 ... time_col="time",
90- ... outcome_col="y",
9195 ... treat_time_col="treat_time",
9296 ... event_window=(-5, 5),
9397 ... reference_event_time=-1,
@@ -109,9 +113,9 @@ class EventStudy(BaseExperiment):
109113 def __init__ (
110114 self ,
111115 data : pd .DataFrame ,
116+ formula : str ,
112117 unit_col : str ,
113118 time_col : str ,
114- outcome_col : str ,
115119 treat_time_col : str ,
116120 event_window : tuple [int , int ] = (- 5 , 5 ),
117121 reference_event_time : int = - 1 ,
@@ -121,9 +125,9 @@ def __init__(
121125 super ().__init__ (model = model )
122126 self .data = data .copy ()
123127 self .expt_type = "Event Study"
128+ self .formula = formula
124129 self .unit_col = unit_col
125130 self .time_col = time_col
126- self .outcome_col = outcome_col
127131 self .treat_time_col = treat_time_col
128132 self .event_window = event_window
129133 self .reference_event_time = reference_event_time
@@ -156,14 +160,15 @@ def __init__(
156160 self ._extract_event_time_coefficients ()
157161
158162 def input_validation (self ) -> None :
159- """Validate input data and parameters."""
160- # Check required columns exist
161- required_cols = [
162- self .unit_col ,
163- self .time_col ,
164- self .outcome_col ,
165- self .treat_time_col ,
166- ]
163+ """Validate input data, formula, and parameters."""
164+ # Check formula is provided
165+ if not self .formula or "~" not in self .formula :
166+ raise FormulaException (
167+ "Formula must be provided in the form 'outcome ~ predictors'"
168+ )
169+
170+ # Check required columns exist for event time computation
171+ required_cols = [self .unit_col , self .time_col , self .treat_time_col ]
167172 for col in required_cols :
168173 if col not in self .data .columns :
169174 raise DataException (f"Required column '{ col } ' not found in data" )
@@ -209,28 +214,19 @@ def _compute_event_time(self) -> None:
209214 ) & (self .data ["_event_time" ] <= self .event_window [1 ])
210215
211216 def _build_design_matrix (self ) -> None :
212- """Build design matrix with unit FE, time FE, and event-time dummies."""
213- # Get unique units and times
214- units = sorted (self .data [self .unit_col ].unique ())
215- times = sorted (self .data [self .time_col ].unique ())
216-
217- # Reference categories (first unit and first time)
218- ref_unit = units [0 ]
219- ref_time = times [0 ]
220-
221- # Build unit fixed effect dummies (excluding reference)
222- unit_dummies = pd .get_dummies (
223- self .data [self .unit_col ], prefix = "unit" , dtype = float
224- )
225- unit_cols_to_keep = [c for c in unit_dummies .columns if c != f"unit_{ ref_unit } " ]
226- unit_dummies = unit_dummies [unit_cols_to_keep ]
227-
228- # Build time fixed effect dummies (excluding reference)
229- time_dummies = pd .get_dummies (
230- self .data [self .time_col ], prefix = "time" , dtype = float
217+ """Build design matrix using patsy formula plus event-time dummies."""
218+ # Parse formula with patsy to get y and X (including FEs and covariates)
219+ y , X = dmatrices (self .formula , self .data )
220+ self ._y_design_info = y .design_info
221+ self ._x_design_info = X .design_info
222+
223+ # Extract outcome variable name from formula
224+ self .outcome_variable_name = y .design_info .column_names [0 ]
225+
226+ # Convert patsy output to DataFrames for manipulation
227+ X_df = pd .DataFrame (
228+ X , columns = X .design_info .column_names , index = self .data .index
231229 )
232- time_cols_to_keep = [c for c in time_dummies .columns if c != f"time_{ ref_time } " ]
233- time_dummies = time_dummies [time_cols_to_keep ]
234230
235231 # Build event-time dummies (excluding reference event time)
236232 event_times = list (range (self .event_window [0 ], self .event_window [1 ] + 1 ))
@@ -245,9 +241,8 @@ def _build_design_matrix(self) -> None:
245241 (self .data ["_event_time" ] == k ) & self .data ["_in_event_window" ]
246242 ).astype (float )
247243
248- # Combine all features: intercept + unit FE + time FE + event-time dummies
249- X_df = pd .DataFrame ({"intercept" : 1.0 }, index = self .data .index )
250- X_df = pd .concat ([X_df , unit_dummies , time_dummies , event_time_dummies ], axis = 1 )
244+ # Combine patsy design matrix with event-time dummies
245+ X_df = pd .concat ([X_df , event_time_dummies ], axis = 1 )
251246
252247 self .labels = list (X_df .columns )
253248 self .event_time_labels = [f"event_time_{ k } " for k in event_times_non_ref ]
@@ -262,7 +257,7 @@ def _build_design_matrix(self) -> None:
262257 },
263258 )
264259
265- y_values = np .asarray (self . data [ self . outcome_col ]. values ).reshape (- 1 , 1 )
260+ y_values = np .asarray (y ).reshape (- 1 , 1 )
266261 self .y = xr .DataArray (
267262 y_values ,
268263 dims = ["obs_ind" , "treated_units" ],
@@ -311,6 +306,7 @@ def summary(self, round_to: int | None = 2) -> None:
311306 Number of decimals for rounding. Defaults to 2.
312307 """
313308 print (f"{ self .expt_type :=^80} " )
309+ print (f"Formula: { self .formula } " )
314310 print (f"Event window: { self .event_window } " )
315311 print (f"Reference event time: { self .reference_event_time } " )
316312 print ("\n Event-time coefficients (beta_k):" )
0 commit comments