@@ -147,13 +147,17 @@ def __init__(
147
147
coords = COORDS ,
148
148
)
149
149
elif isinstance (self .model , RegressorMixin ):
150
- self .model .fit (X = self .datapre_control , y = self .datapre_treated )
150
+ self .model .fit (
151
+ X = self .datapre_control .data ,
152
+ y = self .datapre_treated .isel (treated_units = 0 ).data ,
153
+ )
151
154
else :
152
155
raise ValueError ("Model type not recognized" )
153
156
154
157
# score the goodness of fit to the pre-intervention data
155
158
self .score = self .model .score (
156
- X = self .datapre_control .to_numpy (), y = self .datapre_treated .to_numpy ()
159
+ X = self .datapre_control .to_numpy (),
160
+ y = self .datapre_treated .isel (treated_units = 0 ).to_numpy (),
157
161
)
158
162
159
163
# get the model predictions of the observed (pre-intervention) data
@@ -168,6 +172,7 @@ def __init__(
168
172
self .post_impact = self .model .calculate_impact (
169
173
self .datapost_treated , self .post_pred
170
174
)
175
+
171
176
self .post_impact_cumulative = self .model .calculate_cumulative_impact (
172
177
self .post_impact
173
178
)
@@ -342,8 +347,16 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
342
347
343
348
fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
344
349
345
- ax [0 ].plot (self .datapre .index , self .pre_y , "k." )
346
- ax [0 ].plot (self .datapost .index , self .post_y , "k." )
350
+ ax [0 ].plot (
351
+ self .datapre_treated ["obs_ind" ],
352
+ self .datapre_treated .isel (treated_units = 0 ),
353
+ "k." ,
354
+ )
355
+ ax [0 ].plot (
356
+ self .datapost_treated ["obs_ind" ],
357
+ self .datapost_treated .isel (treated_units = 0 ),
358
+ "k." ,
359
+ )
347
360
348
361
ax [0 ].plot (self .datapre .index , self .pre_pred , c = "k" , label = "model fit" )
349
362
ax [0 ].plot (
@@ -356,8 +369,17 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
356
369
ax [0 ].set (
357
370
title = f"$R^2$ on pre-intervention data = { round_num (self .score , round_to )} "
358
371
)
372
+ # Shaded causal effect
373
+ ax [0 ].fill_between (
374
+ self .datapost .index ,
375
+ y1 = np .squeeze (self .post_pred ),
376
+ y2 = np .squeeze (self .datapost_treated .isel (treated_units = 0 ).data ),
377
+ color = "C0" ,
378
+ alpha = 0.25 ,
379
+ label = "Causal impact" ,
380
+ )
359
381
360
- ax [1 ].plot (self .datapre .index , self .pre_impact , "k ." )
382
+ ax [1 ].plot (self .datapre .index , self .pre_impact , "r ." )
361
383
ax [1 ].plot (
362
384
self .datapost .index ,
363
385
self .post_impact ,
@@ -372,14 +394,6 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
372
394
ax [2 ].set (title = "Cumulative Causal Impact" )
373
395
374
396
# Shaded causal effect
375
- ax [0 ].fill_between (
376
- self .datapost .index ,
377
- y1 = np .squeeze (self .post_pred ),
378
- y2 = np .squeeze (self .post_y ),
379
- color = "C0" ,
380
- alpha = 0.25 ,
381
- label = "Causal impact" ,
382
- )
383
397
ax [1 ].fill_between (
384
398
self .datapost .index ,
385
399
y1 = np .squeeze (self .post_impact ),
0 commit comments