Skip to content

Commit 1f4d17e

Browse files
committed
tidy up + fixes
1 parent 3bbabee commit 1f4d17e

File tree

2 files changed

+24
-28
lines changed

2 files changed

+24
-28
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,15 @@ def __init__(
8585
self.input_validation(data, treatment_time)
8686
self.treatment_time = treatment_time
8787
self.control_units = control_units
88+
self.labels = control_units
8889
self.treated_units = treated_units
8990
self.expt_type = "SyntheticControl"
9091
# split data into pre and post intervention
9192
self.datapre = data[data.index < self.treatment_time]
9293
self.datapost = data[data.index >= self.treatment_time]
9394

94-
# split data into the 4 quadrants (pre/post, control/treated) and store as xarray dataarray
95-
# self.datapre_control = self.datapre[self.control_units]
96-
# self.datapre_treated = self.datapre[self.treated_units]
97-
# self.datapost_control = self.datapost[self.control_units]
98-
# self.datapost_treated = self.datapost[self.treated_units]
95+
# split data into the 4 quadrants (pre/post, control/treated) and store as
96+
# xarray DataArray objects
9997
self.datapre_control = xr.DataArray(
10098
self.datapre[self.control_units],
10199
dims=["obs_ind", "control_units"],
@@ -137,14 +135,12 @@ def __init__(
137135
"obs_ind": np.arange(self.datapre.shape[0]),
138136
}
139137
self.model.fit(
140-
X=self.datapre_control.to_numpy(),
141-
y=self.datapre_treated.to_numpy(),
138+
X=self.datapre_control,
139+
y=self.datapre_treated,
142140
coords=COORDS,
143141
)
144142
elif isinstance(self.model, RegressorMixin):
145-
self.model.fit(
146-
X=self.datapre_control.to_numpy(), y=self.datapre_treated.to_numpy()
147-
)
143+
self.model.fit(X=self.datapre_control, y=self.datapre_treated)
148144
else:
149145
raise ValueError("Model type not recognized")
150146

@@ -154,20 +150,10 @@ def __init__(
154150
)
155151

156152
# get the model predictions of the observed (pre-intervention) data
157-
self.pre_pred = self.model.predict(X=self.datapre_control.to_numpy())
153+
self.pre_pred = self.model.predict(X=self.datapre_control)
158154

159155
# calculate the counterfactual
160-
self.post_pred = self.model.predict(X=self.datapost_control.to_numpy())
161-
# TODO: Remove the need for this 'hack' by properly updating the coords when we
162-
# run model.predict
163-
# TEMPORARY HACK: --------------------------------------------------------------
164-
# : set the coords (obs_ind) for self.post_pred to be the same as the datapost
165-
# index. This is needed for xarray to properly do the comparison (-) between
166-
# datapre_treated and self.post_pred
167-
# self.post_pred["posterior_predictive"] = self.post_pred[
168-
# "posterior_predictive"
169-
# ].assign_coords(obs_ind=self.datapost.index)
170-
# ------------------------------------------------------------------------------
156+
self.post_pred = self.model.predict(X=self.datapost_control)
171157
self.pre_impact = self.model.calculate_impact(
172158
self.datapre_treated, self.pre_pred
173159
)

causalpy/pymc_models.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,13 +135,20 @@ def predict(self, X):
135135
random_seed = self.sample_kwargs.get("random_seed", None)
136136
self._data_setter(X)
137137
with self: # sample with new input data
138-
post_pred = pm.sample_posterior_predictive(
138+
pp = pm.sample_posterior_predictive(
139139
self.idata,
140140
var_names=["y_hat", "mu"],
141141
progressbar=False,
142142
random_seed=random_seed,
143143
)
144-
return post_pred
144+
145+
# TODO: This is a bit of a hack. Maybe it could be done properly in _data_setter?
146+
if isinstance(X, xr.DataArray):
147+
pp["posterior_predictive"] = pp["posterior_predictive"].assign_coords(
148+
obs_ind=X.obs_ind
149+
)
150+
151+
return pp
145152

146153
def score(self, X, y) -> pd.Series:
147154
"""Score the Bayesian :math:`R^2` given inputs ``X`` and outputs ``y``.
@@ -161,10 +168,13 @@ def score(self, X, y) -> pd.Series:
161168
return r2_score(y.flatten(), mu)
162169

163170
def calculate_impact(
164-
self, y_true: xr.DataArray, y_pred: az.InferenceData
171+
self, y_true: xr.DataArray | np.ndarray, y_pred: az.InferenceData
165172
) -> xr.DataArray:
173+
if isinstance(y_true, np.ndarray):
174+
y_true = xr.DataArray(y_true, dims=["obs_ind"])
175+
166176
impact = y_true - y_pred["posterior_predictive"]["y_hat"]
167-
return impact.transpose(..., "treated_units", "obs_ind")
177+
return impact.transpose(..., "obs_ind")
168178

169179
def calculate_cumulative_impact(self, impact):
170180
return impact.cumsum(dim="obs_ind")
@@ -269,9 +279,9 @@ def build_model(self, X, y, coords):
269279
with self:
270280
self.add_coords(coords)
271281
n_predictors = X.shape[1]
272-
X = pm.Data("X", X, dims=["obs_ind", "control_units"])
282+
X = pm.Data("X", X, dims=["obs_ind", "coeffs"])
273283
y = pm.Data("y", y[:, 0], dims="obs_ind")
274-
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="control_units")
284+
beta = pm.Dirichlet("beta", a=np.ones(n_predictors), dims="coeffs")
275285
sigma = pm.HalfNormal("sigma", 1)
276286
mu = pm.Deterministic("mu", pm.math.dot(X, beta), dims="obs_ind")
277287
pm.Normal("y_hat", mu, sigma, observed=y, dims="obs_ind")

0 commit comments

Comments
 (0)