Commit 7bbff4f
Commit message: initial efforts
1 parent: a39e015

File tree: 3 files changed, +143 -294 lines


causalpy/experiments/synthetic_control.py

Lines changed: 107 additions & 43 deletions
@@ -20,8 +20,8 @@
 import arviz as az
 import numpy as np
 import pandas as pd
+import xarray as xr
 from matplotlib import pyplot as plt
-from patsy import build_design_matrices, dmatrices
 from sklearn.base import RegressorMixin
 
 from causalpy.custom_exceptions import BadIndexException
@@ -41,8 +41,10 @@ class SyntheticControl(BaseExperiment):
         A pandas dataframe
     :param treatment_time:
         The time when treatment occurred, should be in reference to the data index
-    :param formula:
-        A statistical model formula
+    :param control_units:
+        A list of control units to be used in the experiment
+    :param treated_units:
+        A list of treated units to be used in the experiment
     :param model:
         A PyMC model
 
@@ -55,7 +57,8 @@ class SyntheticControl(BaseExperiment):
     >>> result = cp.SyntheticControl(
     ...     df,
     ...     treatment_time,
-    ...     formula="actual ~ 0 + a + b + c + d + e + f + g",
+    ...     control_units=["a", "b", "c", "d", "e", "f", "g"],
+    ...     treated_units=["actual"],
     ...     model=cp.pymc_models.WeightedSumFitter(
     ...         sample_kwargs={
     ...             "target_accept": 0.95,
@@ -66,63 +69,111 @@ class SyntheticControl(BaseExperiment):
     ... )
     """
 
-    expt_type = "SyntheticControl"
     supports_ols = True
     supports_bayes = True
 
     def __init__(
        self,
        data: pd.DataFrame,
        treatment_time: Union[int, float, pd.Timestamp],
-       formula: str,
+       control_units: list[str],
+       treated_units: list[str],
        model=None,
        **kwargs,
     ) -> None:
        super().__init__(model=model)
        self.input_validation(data, treatment_time)
        self.treatment_time = treatment_time
-       # set experiment type - usually done in subclasses
-       self.expt_type = "Pre-Post Fit"
+       self.control_units = control_units
+       self.treated_units = treated_units
+       self.expt_type = "SyntheticControl"
        # split data in to pre and post intervention
        self.datapre = data[data.index < self.treatment_time]
        self.datapost = data[data.index >= self.treatment_time]
 
-       self.formula = formula
-
-       # set things up with pre-intervention data
-       y, X = dmatrices(formula, self.datapre)
-       self.outcome_variable_name = y.design_info.column_names[0]
-       self._y_design_info = y.design_info
-       self._x_design_info = X.design_info
-       self.labels = X.design_info.column_names
-       self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
-       # process post-intervention data
-       (new_y, new_x) = build_design_matrices(
-           [self._y_design_info, self._x_design_info], self.datapost
+       # split data into the 4 quadrants (pre/post, control/treated) and store as xarray dataarray
+       # self.datapre_control = self.datapre[self.control_units]
+       # self.datapre_treated = self.datapre[self.treated_units]
+       # self.datapost_control = self.datapost[self.control_units]
+       # self.datapost_treated = self.datapost[self.treated_units]
+       self.datapre_control = xr.DataArray(
+           self.datapre[self.control_units],
+           dims=["obs_ind", "control_units"],
+           coords={
+               "obs_ind": self.datapre[self.control_units].index,
+               "control_units": self.control_units,
+           },
+       )
+       self.datapre_treated = xr.DataArray(
+           self.datapre[self.treated_units],
+           dims=["obs_ind", "treated_units"],
+           coords={
+               "obs_ind": self.datapre[self.treated_units].index,
+               "treated_units": self.treated_units,
+           },
+       )
+       self.datapost_control = xr.DataArray(
+           self.datapost[self.control_units],
+           dims=["obs_ind", "control_units"],
+           coords={
+               "obs_ind": self.datapost[self.control_units].index,
+               "control_units": self.control_units,
+           },
+       )
+       self.datapost_treated = xr.DataArray(
+           self.datapost[self.treated_units],
+           dims=["obs_ind", "treated_units"],
+           coords={
+               "obs_ind": self.datapost[self.treated_units].index,
+               "treated_units": self.treated_units,
+           },
        )
-       self.post_X = np.asarray(new_x)
-       self.post_y = np.asarray(new_y)
 
        # fit the model to the observed (pre-intervention) data
        if isinstance(self.model, PyMCModel):
-           COORDS = {"coeffs": self.labels, "obs_indx": np.arange(self.pre_X.shape[0])}
-           self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
+           COORDS = {
+               "control_units": self.control_units,
+               "treated_units": self.treated_units,
+               "obs_indx": np.arange(self.datapre.shape[0]),
+           }
+           self.model.fit(
+               X=self.datapre_control.to_numpy(),
+               y=self.datapre_treated.to_numpy(),
+               coords=COORDS,
+           )
        elif isinstance(self.model, RegressorMixin):
-           self.model.fit(X=self.pre_X, y=self.pre_y)
+           self.model.fit(
+               X=self.datapre_control.to_numpy(), y=self.datapre_treated.to_numpy()
+           )
        else:
            raise ValueError("Model type not recognized")
 
        # score the goodness of fit to the pre-intervention data
-       self.score = self.model.score(X=self.pre_X, y=self.pre_y)
+       self.score = self.model.score(
+           X=self.datapre_control.to_numpy(), y=self.datapre_treated.to_numpy()
+       )
 
        # get the model predictions of the observed (pre-intervention) data
-       self.pre_pred = self.model.predict(X=self.pre_X)
+       self.pre_pred = self.model.predict(X=self.datapre_control.to_numpy())
 
        # calculate the counterfactual
-       self.post_pred = self.model.predict(X=self.post_X)
-       self.pre_impact = self.model.calculate_impact(self.pre_y[:, 0], self.pre_pred)
+       self.post_pred = self.model.predict(X=self.datapost_control.to_numpy())
+       # TODO: Remove the need for this 'hack' by properly updating the coords when we
+       # run model.predict
+       # TEMPORARY HACK: --------------------------------------------------------------
+       # set the coords (obs_ind) for self.post_pred to be the same as the datapost
+       # index. This is needed for xarray to properly do the comparison (-) between
+       # datapre_treated and self.post_pred
+       # self.post_pred["posterior_predictive"] = self.post_pred[
+       #     "posterior_predictive"
+       # ].assign_coords(obs_ind=self.datapost.index)
+       # ------------------------------------------------------------------------------
+       self.pre_impact = self.model.calculate_impact(
+           self.datapre_treated, self.pre_pred
+       )
 
        self.post_impact = self.model.calculate_impact(
-           self.post_y[:, 0], self.post_pred
+           self.datapost_treated, self.post_pred
        )
        self.post_impact_cumulative = self.model.calculate_cumulative_impact(
            self.post_impact
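A minimal sketch of the quadrant split introduced in the hunk above, on toy data. The DataFrame, its column names ("a", "b" as controls, "actual" as treated) and the treatment time are assumed for illustration only; they are not taken from the package's tests.

import numpy as np
import pandas as pd
import xarray as xr

rng = np.random.default_rng(0)
index = pd.date_range("2020-01-01", periods=10, freq="D")
df = pd.DataFrame(
    {
        "a": rng.normal(size=10),
        "b": rng.normal(size=10),
        "actual": rng.normal(size=10),
    },
    index=index,
)
treatment_time = pd.Timestamp("2020-01-06")
control_units, treated_units = ["a", "b"], ["actual"]

# pre-intervention rows, control columns, wrapped as a 2-D labelled array
datapre = df[df.index < treatment_time]
datapre_control = xr.DataArray(
    datapre[control_units],
    dims=["obs_ind", "control_units"],
    coords={"obs_ind": datapre.index, "control_units": control_units},
)
print(datapre_control.dims)   # ('obs_ind', 'control_units')
print(datapre_control.shape)  # (5, 2)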
@@ -150,7 +201,11 @@ def summary(self, round_to=None) -> None:
            Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers
        """
        print(f"{self.expt_type:=^80}")
-       print(f"Formula: {self.formula}")
+       print(f"Control units: {self.control_units}")
+       if len(self.treated_units) > 1:
+           print(f"Treated units: {self.treated_units}")
+       else:
+           print(f"Treated unit: {self.treated_units[0]}")
        self.print_coefficients(round_to)
 
    def _bayesian_plot(
@@ -176,7 +231,9 @@ def _bayesian_plot(
        handles = [(h_line, h_patch)]
        labels = ["Pre-intervention period"]
 
-       (h,) = ax[0].plot(self.datapre.index, self.pre_y, "k.", label="Observations")
+       (h,) = ax[0].plot(
+           self.datapre.index, self.datapre_treated, "k.", label="Observations"
+       )
        handles.append(h)
        labels.append("Observations")
 
@@ -190,14 +247,14 @@ def _bayesian_plot(
        handles.append((h_line, h_patch))
        labels.append(counterfactual_label)
 
-       ax[0].plot(self.datapost.index, self.post_y, "k.")
+       ax[0].plot(self.datapost.index, self.datapost_treated, "k.")
        # Shaded causal effect
        h = ax[0].fill_between(
            self.datapost.index,
            y1=az.extract(
                self.post_pred, group="posterior_predictive", var_names="mu"
            ).mean("sample"),
-           y2=np.squeeze(self.post_y),
+           y2=np.squeeze(self.datapost_treated),
            color="C0",
            alpha=0.25,
        )
@@ -214,20 +271,20 @@ def _bayesian_plot(
        # MIDDLE PLOT -----------------------------------------------
        plot_xY(
            self.datapre.index,
-           self.pre_impact,
+           self.pre_impact.sel(treated_units="actual"),
            ax=ax[1],
            plot_hdi_kwargs={"color": "C0"},
        )
        plot_xY(
            self.datapost.index,
-           self.post_impact,
+           self.post_impact.sel(treated_units="actual"),
            ax=ax[1],
            plot_hdi_kwargs={"color": "C1"},
        )
        ax[1].axhline(y=0, c="k")
        ax[1].fill_between(
            self.datapost.index,
-           y1=self.post_impact.mean(["chain", "draw"]),
+           y1=self.post_impact.mean(["chain", "draw"]).sel(treated_units="actual"),
            color="C0",
            alpha=0.25,
            label="Causal impact",
@@ -238,7 +295,7 @@ def _bayesian_plot(
        ax[2].set(title="Cumulative Causal Impact")
        plot_xY(
            self.datapost.index,
-           self.post_impact_cumulative,
+           self.post_impact_cumulative.sel(treated_units="actual"),
            ax=ax[2],
            plot_hdi_kwargs={"color": "C1"},
        )
@@ -259,15 +316,22 @@ def _bayesian_plot(
            fontsize=LEGEND_FONT_SIZE,
        )
 
-       # code above: same as `PrePostFit._bayesian_plot` -------------------------------
-       # code below: additional for the synthetic control experiment ------------------
-
        plot_predictors = kwargs.get("plot_predictors", False)
        if plot_predictors:
            # plot control units as well
-           ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)
            ax[0].plot(
-               self.datapost.index, self.post_X, "-", c=[0.8, 0.8, 0.8], zorder=1
+               self.datapre.index,
+               self.datapre_control,
+               "-",
+               c=[0.8, 0.8, 0.8],
+               zorder=1,
+           )
+           ax[0].plot(
+               self.datapost.index,
+               self.datapost_control,
+               "-",
+               c=[0.8, 0.8, 0.8],
+               zorder=1,
            )
 
        return fig, ax
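The plotting hunks above select one treated unit with the hardcoded coordinate label "actual". A minimal sketch of what that selection does, using assumed toy dimensions rather than real posterior output:

import numpy as np
import xarray as xr

# impact with the dims produced by calculate_impact in this commit
impact = xr.DataArray(
    np.zeros((2, 10, 1, 5)),
    dims=["chain", "draw", "treated_units", "obs_ind"],
    coords={"treated_units": ["actual"], "obs_ind": np.arange(5)},
)
single_unit = impact.sel(treated_units="actual")   # selecting a scalar label drops the dim
mean_impact = single_unit.mean(["chain", "draw"])  # 1-D series over obs_ind, ready to plot
print(mean_impact.dims)  # ('obs_ind',)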

causalpy/pymc_models.py

Lines changed: 9 additions & 11 deletions
@@ -89,6 +89,7 @@ def _data_setter(self, X) -> None:
        prediction.
        """
        with self:
+           # TODO: update coords
            pm.set_data({"X": X})
 
    def fit(self, X, y, coords: Optional[Dict[str, Any]] = None) -> None:
@@ -150,10 +151,11 @@ def score(self, X, y) -> pd.Series:
        # Note: First argument must be a 1D array
        return r2_score(y.flatten(), mu)
 
-   def calculate_impact(self, y_true, y_pred):
-       pre_data = xr.DataArray(y_true, dims=["obs_ind"])
-       impact = pre_data - y_pred["posterior_predictive"]["y_hat"]
-       return impact.transpose(..., "obs_ind")
+   def calculate_impact(
+       self, y_true: xr.DataArray, y_pred: az.InferenceData
+   ) -> xr.DataArray:
+       impact = y_true - y_pred["posterior_predictive"]["y_hat"]
+       return impact.transpose(..., "treated_units", "obs_ind")
 
    def calculate_cumulative_impact(self, impact):
        return impact.cumsum(dim="obs_ind")
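The new calculate_impact leans on xarray broadcasting: y_true carries (obs_ind, treated_units) dims while y_hat from the posterior predictive carries (chain, draw, obs_ind), so their difference carries all four. A standalone sketch with assumed toy shapes, not the package's test code:

import numpy as np
import xarray as xr

obs_ind = np.arange(5)
y_true = xr.DataArray(
    np.ones((5, 1)),
    dims=["obs_ind", "treated_units"],
    coords={"obs_ind": obs_ind, "treated_units": ["actual"]},
)
y_hat = xr.DataArray(
    np.zeros((2, 10, 5)),
    dims=["chain", "draw", "obs_ind"],
    coords={"obs_ind": obs_ind},
)
# broadcasting aligns on obs_ind and keeps chain, draw, treated_units
impact = (y_true - y_hat).transpose(..., "treated_units", "obs_ind")
print(impact.dims)  # ('chain', 'draw', 'treated_units', 'obs_ind')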
@@ -255,17 +257,13 @@ def build_model(self, X, y, coords):
        """
        Defines the PyMC model
        """
+       print(coords)
        with self:
            self.add_coords(coords)
            n_predictors = X.shape[1]
-           X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
+           X = pm.Data("X", X, dims=["obs_ind", "control_units"])
            y = pm.Data("y", y[:, 0], dims="obs_ind")
-           # TODO: There we should allow user-specified priors here
-           beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
-           # beta = pm.Dirichlet(
-           #     name="beta", a=(1 / n_predictors) * np.ones(n_predictors),
-           #     dims="coeffs"
-           # )
+           beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="control_units")
            sigma = pm.HalfNormal("sigma", 1)
            mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
            pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
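For context, a self-contained sketch of the weighted-sum structure defined in the hunk above: Dirichlet weights over the control units, their linear combination as mu, and a Normal likelihood on the treated series. The toy data and unit names are assumed for illustration; they are not from the repository.

import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))              # 50 time points, 3 control units
true_beta = np.array([0.2, 0.5, 0.3])     # weights summing to 1
y = X @ true_beta + rng.normal(0, 0.1, size=50)

coords = {"obs_ind": np.arange(50), "control_units": ["a", "b", "c"]}
with pm.Model(coords=coords) as model:
    X_ = pm.Data("X", X, dims=["obs_ind", "control_units"])
    # simplex-constrained weights over the control units
    beta = pm.Dirichlet("beta", a=np.ones(3), dims="control_units")
    sigma = pm.HalfNormal("sigma", 1)
    mu = pm.Deterministic("mu", pm.math.dot(X_, beta), dims="obs_ind")
    pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")
    idata = pm.sample(1000, tune=1000, target_accept=0.95, random_seed=0)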
