@@ -82,6 +82,8 @@ def __init__(
82
82
** kwargs ,
83
83
) -> None :
84
84
super ().__init__ (model = model )
85
+ # rename the index to "obs_ind"
86
+ data .index .name = "obs_ind"
85
87
self .input_validation (data , treatment_time )
86
88
self .treatment_time = treatment_time
87
89
self .control_units = control_units
@@ -93,7 +95,9 @@ def __init__(
93
95
self .datapost = data [data .index >= self .treatment_time ]
94
96
95
97
# split data into the 4 quadrants (pre/post, control/treated) and store as
96
- # xarray DataArray objects
98
+ # xarray DataArray objects.
99
+ # NOTE: if we have renamed/ensured the index is named "obs_ind", then it will
100
+ # make constructing the xarray DataArray objects easier.
97
101
self .datapre_control = xr .DataArray (
98
102
self .datapre [self .control_units ],
99
103
dims = ["obs_ind" , "control_units" ],
@@ -130,7 +134,9 @@ def __init__(
130
134
# fit the model to the observed (pre-intervention) data
131
135
if isinstance (self .model , PyMCModel ):
132
136
COORDS = {
133
- "control_units" : self .control_units ,
137
+ # key must stay as "coeffs" unless we can find a way to auto identify
138
+ # the predictor dimension name
139
+ "coeffs" : self .control_units ,
134
140
"treated_units" : self .treated_units ,
135
141
"obs_ind" : np .arange (self .datapre .shape [0 ]),
136
142
}
@@ -257,20 +263,22 @@ def _bayesian_plot(
257
263
# MIDDLE PLOT -----------------------------------------------
258
264
plot_xY (
259
265
self .datapre .index ,
260
- self .pre_impact .sel (treated_units = "actual" ),
266
+ self .pre_impact .sel (treated_units = self . treated_units [ 0 ] ),
261
267
ax = ax [1 ],
262
268
plot_hdi_kwargs = {"color" : "C0" },
263
269
)
264
270
plot_xY (
265
271
self .datapost .index ,
266
- self .post_impact .sel (treated_units = "actual" ),
272
+ self .post_impact .sel (treated_units = self . treated_units [ 0 ] ),
267
273
ax = ax [1 ],
268
274
plot_hdi_kwargs = {"color" : "C1" },
269
275
)
270
276
ax [1 ].axhline (y = 0 , c = "k" )
271
277
ax [1 ].fill_between (
272
278
self .datapost .index ,
273
- y1 = self .post_impact .mean (["chain" , "draw" ]).sel (treated_units = "actual" ),
279
+ y1 = self .post_impact .mean (["chain" , "draw" ]).sel (
280
+ treated_units = self .treated_units [0 ]
281
+ ),
274
282
color = "C0" ,
275
283
alpha = 0.25 ,
276
284
label = "Causal impact" ,
@@ -281,7 +289,7 @@ def _bayesian_plot(
281
289
ax [2 ].set (title = "Cumulative Causal Impact" )
282
290
plot_xY (
283
291
self .datapost .index ,
284
- self .post_impact_cumulative .sel (treated_units = "actual" ),
292
+ self .post_impact_cumulative .sel (treated_units = self . treated_units [ 0 ] ),
285
293
ax = ax [2 ],
286
294
plot_hdi_kwargs = {"color" : "C1" },
287
295
)
0 commit comments