2020import arviz as az
2121import numpy as np
2222import pandas as pd
23+ import xarray as xr
2324from matplotlib import pyplot as plt
24- from patsy import build_design_matrices , dmatrices
2525from sklearn .base import RegressorMixin
2626
2727from causalpy .custom_exceptions import BadIndexException
@@ -41,8 +41,10 @@ class SyntheticControl(BaseExperiment):
4141 A pandas dataframe
4242 :param treatment_time:
4343 The time when treatment occurred, should be in reference to the data index
44- :param formula:
45- A statistical model formula
44+ :param control_units:
45+ A list of control units to be used in the experiment
46+ :param treated_units:
47+ A list of treated units to be used in the experiment
4648 :param model:
4749 A PyMC model
4850
@@ -55,7 +57,8 @@ class SyntheticControl(BaseExperiment):
5557 >>> result = cp.SyntheticControl(
5658 ... df,
5759 ... treatment_time,
58- ... formula="actual ~ 0 + a + b + c + d + e + f + g",
60+ ... control_units=["a", "b", "c", "d", "e", "f", "g"],
61+ ... treated_units=["actual"],
5962 ... model=cp.pymc_models.WeightedSumFitter(
6063 ... sample_kwargs={
6164 ... "target_accept": 0.95,
@@ -66,63 +69,111 @@ class SyntheticControl(BaseExperiment):
6669 ... )
6770 """
6871
69- expt_type = "SyntheticControl"
7072 supports_ols = True
7173 supports_bayes = True
7274
7375 def __init__ (
7476 self ,
7577 data : pd .DataFrame ,
7678 treatment_time : Union [int , float , pd .Timestamp ],
77- formula : str ,
79+ control_units : list [str ],
80+ treated_units : list [str ],
7881 model = None ,
7982 ** kwargs ,
8083 ) -> None :
8184 super ().__init__ (model = model )
8285 self .input_validation (data , treatment_time )
8386 self .treatment_time = treatment_time
84- # set experiment type - usually done in subclasses
85- self .expt_type = "Pre-Post Fit"
87+ self .control_units = control_units
88+ self .treated_units = treated_units
89+ self .expt_type = "SyntheticControl"
8690 # split data into pre and post intervention
8791 self .datapre = data [data .index < self .treatment_time ]
8892 self .datapost = data [data .index >= self .treatment_time ]
8993
90- self .formula = formula
91-
92- # set things up with pre-intervention data
93- y , X = dmatrices (formula , self .datapre )
94- self .outcome_variable_name = y .design_info .column_names [0 ]
95- self ._y_design_info = y .design_info
96- self ._x_design_info = X .design_info
97- self .labels = X .design_info .column_names
98- self .pre_y , self .pre_X = np .asarray (y ), np .asarray (X )
99- # process post-intervention data
100- (new_y , new_x ) = build_design_matrices (
101- [self ._y_design_info , self ._x_design_info ], self .datapost
94+ # split data into the 4 quadrants (pre/post, control/treated) and store as xarray dataarray
95+ # self.datapre_control = self.datapre[self.control_units]
96+ # self.datapre_treated = self.datapre[self.treated_units]
97+ # self.datapost_control = self.datapost[self.control_units]
98+ # self.datapost_treated = self.datapost[self.treated_units]
99+ self .datapre_control = xr .DataArray (
100+ self .datapre [self .control_units ],
101+ dims = ["obs_ind" , "control_units" ],
102+ coords = {
103+ "obs_ind" : self .datapre [self .control_units ].index ,
104+ "control_units" : self .control_units ,
105+ },
106+ )
107+ self .datapre_treated = xr .DataArray (
108+ self .datapre [self .treated_units ],
109+ dims = ["obs_ind" , "treated_units" ],
110+ coords = {
111+ "obs_ind" : self .datapre [self .treated_units ].index ,
112+ "treated_units" : self .treated_units ,
113+ },
114+ )
115+ self .datapost_control = xr .DataArray (
116+ self .datapost [self .control_units ],
117+ dims = ["obs_ind" , "control_units" ],
118+ coords = {
119+ "obs_ind" : self .datapost [self .control_units ].index ,
120+ "control_units" : self .control_units ,
121+ },
122+ )
123+ self .datapost_treated = xr .DataArray (
124+ self .datapost [self .treated_units ],
125+ dims = ["obs_ind" , "treated_units" ],
126+ coords = {
127+ "obs_ind" : self .datapost [self .treated_units ].index ,
128+ "treated_units" : self .treated_units ,
129+ },
102130 )
103- self .post_X = np .asarray (new_x )
104- self .post_y = np .asarray (new_y )
105131
106132 # fit the model to the observed (pre-intervention) data
107133 if isinstance (self .model , PyMCModel ):
108- COORDS = {"coeffs" : self .labels , "obs_indx" : np .arange (self .pre_X .shape [0 ])}
109- self .model .fit (X = self .pre_X , y = self .pre_y , coords = COORDS )
134+ COORDS = {
135+ "control_units" : self .control_units ,
136+ "treated_units" : self .treated_units ,
137+ "obs_indx" : np .arange (self .datapre .shape [0 ]),
138+ }
139+ self .model .fit (
140+ X = self .datapre_control .to_numpy (),
141+ y = self .datapre_treated .to_numpy (),
142+ coords = COORDS ,
143+ )
110144 elif isinstance (self .model , RegressorMixin ):
111- self .model .fit (X = self .pre_X , y = self .pre_y )
145+ self .model .fit (
146+ X = self .datapre_control .to_numpy (), y = self .datapre_treated .to_numpy ()
147+ )
112148 else :
113149 raise ValueError ("Model type not recognized" )
114150
115151 # score the goodness of fit to the pre-intervention data
116- self .score = self .model .score (X = self .pre_X , y = self .pre_y )
152+ self .score = self .model .score (
153+ X = self .datapre_control .to_numpy (), y = self .datapre_treated .to_numpy ()
154+ )
117155
118156 # get the model predictions of the observed (pre-intervention) data
119- self .pre_pred = self .model .predict (X = self .pre_X )
157+ self .pre_pred = self .model .predict (X = self .datapre_control . to_numpy () )
120158
121159 # calculate the counterfactual
122- self .post_pred = self .model .predict (X = self .post_X )
123- self .pre_impact = self .model .calculate_impact (self .pre_y [:, 0 ], self .pre_pred )
160+ self .post_pred = self .model .predict (X = self .datapost_control .to_numpy ())
161+ # TODO: Remove the need for this 'hack' by properly updating the coords when we
162+ # run model.predict
163+ # TEMPORARY HACK: --------------------------------------------------------------
164+ # : set the coords (obs_ind) for self.post_pred to be the same as the datapost
165+ # index. This is needed for xarray to properly do the comparison (-) between
166+ # datapost_treated and self.post_pred
167+ # self.post_pred["posterior_predictive"] = self.post_pred[
168+ # "posterior_predictive"
169+ # ].assign_coords(obs_ind=self.datapost.index)
170+ # ------------------------------------------------------------------------------
171+ self .pre_impact = self .model .calculate_impact (
172+ self .datapre_treated , self .pre_pred
173+ )
174+
124175 self .post_impact = self .model .calculate_impact (
125- self .post_y [:, 0 ] , self .post_pred
176+ self .datapost_treated , self .post_pred
126177 )
127178 self .post_impact_cumulative = self .model .calculate_cumulative_impact (
128179 self .post_impact
@@ -150,7 +201,11 @@ def summary(self, round_to=None) -> None:
150201 Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers
151202 """
152203 print (f"{ self .expt_type :=^80} " )
153- print (f"Formula: { self .formula } " )
204+ print (f"Control units: { self .control_units } " )
205+ if len (self .treated_units ) > 1 :
206+ print (f"Treated units: { self .treated_units } " )
207+ else :
208+ print (f"Treated unit: { self .treated_units [0 ]} " )
154209 self .print_coefficients (round_to )
155210
156211 def _bayesian_plot (
@@ -176,7 +231,9 @@ def _bayesian_plot(
176231 handles = [(h_line , h_patch )]
177232 labels = ["Pre-intervention period" ]
178233
179- (h ,) = ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
234+ (h ,) = ax [0 ].plot (
235+ self .datapre .index , self .datapre_treated , "k." , label = "Observations"
236+ )
180237 handles .append (h )
181238 labels .append ("Observations" )
182239
@@ -190,14 +247,14 @@ def _bayesian_plot(
190247 handles .append ((h_line , h_patch ))
191248 labels .append (counterfactual_label )
192249
193- ax [0 ].plot (self .datapost .index , self .post_y , "k." )
250+ ax [0 ].plot (self .datapost .index , self .datapost_treated , "k." )
194251 # Shaded causal effect
195252 h = ax [0 ].fill_between (
196253 self .datapost .index ,
197254 y1 = az .extract (
198255 self .post_pred , group = "posterior_predictive" , var_names = "mu"
199256 ).mean ("sample" ),
200- y2 = np .squeeze (self .post_y ),
257+ y2 = np .squeeze (self .datapost_treated ),
201258 color = "C0" ,
202259 alpha = 0.25 ,
203260 )
@@ -214,20 +271,20 @@ def _bayesian_plot(
214271 # MIDDLE PLOT -----------------------------------------------
215272 plot_xY (
216273 self .datapre .index ,
217- self .pre_impact ,
274+ self .pre_impact . sel ( treated_units = "actual" ) ,
218275 ax = ax [1 ],
219276 plot_hdi_kwargs = {"color" : "C0" },
220277 )
221278 plot_xY (
222279 self .datapost .index ,
223- self .post_impact ,
280+ self .post_impact . sel ( treated_units = "actual" ) ,
224281 ax = ax [1 ],
225282 plot_hdi_kwargs = {"color" : "C1" },
226283 )
227284 ax [1 ].axhline (y = 0 , c = "k" )
228285 ax [1 ].fill_between (
229286 self .datapost .index ,
230- y1 = self .post_impact .mean (["chain" , "draw" ]),
287+ y1 = self .post_impact .mean (["chain" , "draw" ]). sel ( treated_units = "actual" ) ,
231288 color = "C0" ,
232289 alpha = 0.25 ,
233290 label = "Causal impact" ,
@@ -238,7 +295,7 @@ def _bayesian_plot(
238295 ax [2 ].set (title = "Cumulative Causal Impact" )
239296 plot_xY (
240297 self .datapost .index ,
241- self .post_impact_cumulative ,
298+ self .post_impact_cumulative . sel ( treated_units = "actual" ) ,
242299 ax = ax [2 ],
243300 plot_hdi_kwargs = {"color" : "C1" },
244301 )
@@ -259,15 +316,22 @@ def _bayesian_plot(
259316 fontsize = LEGEND_FONT_SIZE ,
260317 )
261318
262- # code above: same as `PrePostFit._bayesian_plot` -------------------------------
263- # code below: additional for the synthetic control experiment ------------------
264-
265319 plot_predictors = kwargs .get ("plot_predictors" , False )
266320 if plot_predictors :
267321 # plot control units as well
268- ax [0 ].plot (self .datapre .index , self .pre_X , "-" , c = [0.8 , 0.8 , 0.8 ], zorder = 1 )
269322 ax [0 ].plot (
270- self .datapost .index , self .post_X , "-" , c = [0.8 , 0.8 , 0.8 ], zorder = 1
323+ self .datapre .index ,
324+ self .datapre_control ,
325+ "-" ,
326+ c = [0.8 , 0.8 , 0.8 ],
327+ zorder = 1 ,
328+ )
329+ ax [0 ].plot (
330+ self .datapost .index ,
331+ self .datapost_control ,
332+ "-" ,
333+ c = [0.8 , 0.8 , 0.8 ],
334+ zorder = 1 ,
271335 )
272336
273337 return fig , ax
0 commit comments