Skip to content

Commit 876c154

Browse files
committed
bug fixes
1 parent 182aac0 commit 876c154

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

causalpy/experiments/synthetic_control.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def __init__(
8282
**kwargs,
8383
) -> None:
8484
super().__init__(model=model)
85+
# rename the index to "obs_ind"
86+
data.index.name = "obs_ind"
8587
self.input_validation(data, treatment_time)
8688
self.treatment_time = treatment_time
8789
self.control_units = control_units
@@ -93,7 +95,9 @@ def __init__(
9395
self.datapost = data[data.index >= self.treatment_time]
9496

9597
# split data into the 4 quadrants (pre/post, control/treated) and store as
96-
# xarray DataArray objects
98+
# xarray DataArray objects.
99+
# NOTE: if we have renamed/ensured the index is named "obs_ind", then it will
100+
# make constructing the xarray DataArray objects easier.
97101
self.datapre_control = xr.DataArray(
98102
self.datapre[self.control_units],
99103
dims=["obs_ind", "control_units"],
@@ -130,7 +134,9 @@ def __init__(
130134
# fit the model to the observed (pre-intervention) data
131135
if isinstance(self.model, PyMCModel):
132136
COORDS = {
133-
"control_units": self.control_units,
137+
# key must stay as "coeffs" unless we can find a way to auto identify
138+
# the predictor dimension name
139+
"coeffs": self.control_units,
134140
"treated_units": self.treated_units,
135141
"obs_ind": np.arange(self.datapre.shape[0]),
136142
}
@@ -257,20 +263,22 @@ def _bayesian_plot(
257263
# MIDDLE PLOT -----------------------------------------------
258264
plot_xY(
259265
self.datapre.index,
260-
self.pre_impact.sel(treated_units="actual"),
266+
self.pre_impact.sel(treated_units=self.treated_units[0]),
261267
ax=ax[1],
262268
plot_hdi_kwargs={"color": "C0"},
263269
)
264270
plot_xY(
265271
self.datapost.index,
266-
self.post_impact.sel(treated_units="actual"),
272+
self.post_impact.sel(treated_units=self.treated_units[0]),
267273
ax=ax[1],
268274
plot_hdi_kwargs={"color": "C1"},
269275
)
270276
ax[1].axhline(y=0, c="k")
271277
ax[1].fill_between(
272278
self.datapost.index,
273-
y1=self.post_impact.mean(["chain", "draw"]).sel(treated_units="actual"),
279+
y1=self.post_impact.mean(["chain", "draw"]).sel(
280+
treated_units=self.treated_units[0]
281+
),
274282
color="C0",
275283
alpha=0.25,
276284
label="Causal impact",
@@ -281,7 +289,7 @@ def _bayesian_plot(
281289
ax[2].set(title="Cumulative Causal Impact")
282290
plot_xY(
283291
self.datapost.index,
284-
self.post_impact_cumulative.sel(treated_units="actual"),
292+
self.post_impact_cumulative.sel(treated_units=self.treated_units[0]),
285293
ax=ax[2],
286294
plot_hdi_kwargs={"color": "C1"},
287295
)

0 commit comments

Comments
 (0)