Skip to content

Commit a148ec3

Browse files
committed
fix bugs
1 parent ed62f0a commit a148ec3

File tree

3 files changed

+31
-17
lines changed

3 files changed

+31
-17
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,13 +147,17 @@ def __init__(
147147
coords=COORDS,
148148
)
149149
elif isinstance(self.model, RegressorMixin):
150-
self.model.fit(X=self.datapre_control, y=self.datapre_treated)
150+
self.model.fit(
151+
X=self.datapre_control.data,
152+
y=self.datapre_treated.isel(treated_units=0).data,
153+
)
151154
else:
152155
raise ValueError("Model type not recognized")
153156

154157
# score the goodness of fit to the pre-intervention data
155158
self.score = self.model.score(
156-
X=self.datapre_control.to_numpy(), y=self.datapre_treated.to_numpy()
159+
X=self.datapre_control.to_numpy(),
160+
y=self.datapre_treated.isel(treated_units=0).to_numpy(),
157161
)
158162

159163
# get the model predictions of the observed (pre-intervention) data
@@ -168,6 +172,7 @@ def __init__(
168172
self.post_impact = self.model.calculate_impact(
169173
self.datapost_treated, self.post_pred
170174
)
175+
171176
self.post_impact_cumulative = self.model.calculate_cumulative_impact(
172177
self.post_impact
173178
)
@@ -342,8 +347,16 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
342347

343348
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
344349

345-
ax[0].plot(self.datapre.index, self.pre_y, "k.")
346-
ax[0].plot(self.datapost.index, self.post_y, "k.")
350+
ax[0].plot(
351+
self.datapre_treated["obs_ind"],
352+
self.datapre_treated.isel(treated_units=0),
353+
"k.",
354+
)
355+
ax[0].plot(
356+
self.datapost_treated["obs_ind"],
357+
self.datapost_treated.isel(treated_units=0),
358+
"k.",
359+
)
347360

348361
ax[0].plot(self.datapre.index, self.pre_pred, c="k", label="model fit")
349362
ax[0].plot(
@@ -356,8 +369,17 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
356369
ax[0].set(
357370
title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
358371
)
372+
# Shaded causal effect
373+
ax[0].fill_between(
374+
self.datapost.index,
375+
y1=np.squeeze(self.post_pred),
376+
y2=np.squeeze(self.datapost_treated.isel(treated_units=0).data),
377+
color="C0",
378+
alpha=0.25,
379+
label="Causal impact",
380+
)
359381

360-
ax[1].plot(self.datapre.index, self.pre_impact, "k.")
382+
ax[1].plot(self.datapre.index, self.pre_impact, "r.")
361383
ax[1].plot(
362384
self.datapost.index,
363385
self.post_impact,
@@ -372,14 +394,6 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
372394
ax[2].set(title="Cumulative Causal Impact")
373395

374396
# Shaded causal effect
375-
ax[0].fill_between(
376-
self.datapost.index,
377-
y1=np.squeeze(self.post_pred),
378-
y2=np.squeeze(self.post_y),
379-
color="C0",
380-
alpha=0.25,
381-
label="Causal impact",
382-
)
383397
ax[1].fill_between(
384398
self.datapost.index,
385399
y1=np.squeeze(self.post_impact),

causalpy/skl_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class ScikitLearnAdaptor:
2828

2929
def calculate_impact(self, y_true, y_pred):
3030
"""Calculate the causal impact of the intervention."""
31-
return y_true - np.squeeze(y_pred)
31+
return y_true - y_pred
3232

3333
def calculate_cumulative_impact(self, impact):
3434
"""Calculate the cumulative impact intervention."""

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

0 commit comments

Comments
 (0)