@@ -147,13 +147,17 @@ def __init__(
147147 coords = COORDS ,
148148 )
149149 elif isinstance (self .model , RegressorMixin ):
150- self .model .fit (X = self .datapre_control , y = self .datapre_treated )
150+ self .model .fit (
151+ X = self .datapre_control .data ,
152+ y = self .datapre_treated .isel (treated_units = 0 ).data ,
153+ )
151154 else :
152155 raise ValueError ("Model type not recognized" )
153156
154157 # score the goodness of fit to the pre-intervention data
155158 self .score = self .model .score (
156- X = self .datapre_control .to_numpy (), y = self .datapre_treated .to_numpy ()
159+ X = self .datapre_control .to_numpy (),
160+ y = self .datapre_treated .isel (treated_units = 0 ).to_numpy (),
157161 )
158162
159163 # get the model predictions of the observed (pre-intervention) data
@@ -168,6 +172,7 @@ def __init__(
168172 self .post_impact = self .model .calculate_impact (
169173 self .datapost_treated , self .post_pred
170174 )
175+
171176 self .post_impact_cumulative = self .model .calculate_cumulative_impact (
172177 self .post_impact
173178 )
@@ -342,8 +347,16 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
342347
343348 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
344349
345- ax [0 ].plot (self .datapre .index , self .pre_y , "k." )
346- ax [0 ].plot (self .datapost .index , self .post_y , "k." )
350+ ax [0 ].plot (
351+ self .datapre_treated ["obs_ind" ],
352+ self .datapre_treated .isel (treated_units = 0 ),
353+ "k." ,
354+ )
355+ ax [0 ].plot (
356+ self .datapost_treated ["obs_ind" ],
357+ self .datapost_treated .isel (treated_units = 0 ),
358+ "k." ,
359+ )
347360
348361 ax [0 ].plot (self .datapre .index , self .pre_pred , c = "k" , label = "model fit" )
349362 ax [0 ].plot (
@@ -356,8 +369,17 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
356369 ax [0 ].set (
357370 title = f"$R^2$ on pre-intervention data = { round_num (self .score , round_to )} "
358371 )
372+ # Shaded causal effect
373+ ax [0 ].fill_between (
374+ self .datapost .index ,
375+ y1 = np .squeeze (self .post_pred ),
376+ y2 = np .squeeze (self .datapost_treated .isel (treated_units = 0 ).data ),
377+ color = "C0" ,
378+ alpha = 0.25 ,
379+ label = "Causal impact" ,
380+ )
359381
360- ax [1 ].plot (self .datapre .index , self .pre_impact , "k ." )
382+ ax [1 ].plot (self .datapre .index , self .pre_impact , "r ." )
361383 ax [1 ].plot (
362384 self .datapost .index ,
363385 self .post_impact ,
@@ -372,14 +394,6 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, List[plt.Axes]
372394 ax [2 ].set (title = "Cumulative Causal Impact" )
373395
374396 # Shaded causal effect
375- ax [0 ].fill_between (
376- self .datapost .index ,
377- y1 = np .squeeze (self .post_pred ),
378- y2 = np .squeeze (self .post_y ),
379- color = "C0" ,
380- alpha = 0.25 ,
381- label = "Causal impact" ,
382- )
383397 ax [1 ].fill_between (
384398 self .datapost .index ,
385399 y1 = np .squeeze (self .post_impact ),
0 commit comments