@@ -103,11 +103,15 @@ def __init__(
103
103
104
104
# causal impact pre (ie the residuals of the model fit to observed)
105
105
pre_data = xr .DataArray (self .pre_y [:, 0 ], dims = ["obs_ind" ])
106
- self .pre_impact = pre_data - self .pre_pred ["posterior_predictive" ].y_hat
106
+ self .pre_impact = (
107
+ pre_data - self .pre_pred ["posterior_predictive" ].y_hat
108
+ ).transpose (..., "obs_ind" )
107
109
108
110
# causal impact post (ie the residuals of the model fit to observed)
109
111
post_data = xr .DataArray (self .post_y [:, 0 ], dims = ["obs_ind" ])
110
- self .post_impact = post_data - self .post_pred ["posterior_predictive" ].y_hat
112
+ self .post_impact = (
113
+ post_data - self .post_pred ["posterior_predictive" ].y_hat
114
+ ).transpose (..., "obs_ind" )
111
115
112
116
# cumulative impact post
113
117
self .post_impact_cumulative = self .post_impact .cumsum (dim = "obs_ind" )
@@ -117,9 +121,12 @@ def plot(self):
117
121
"""Plot the results"""
118
122
fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
119
123
124
+ # TOP PLOT --------------------------------------------------
120
125
# pre-intervention period
121
126
plot_xY (
122
- self .datapre .index , self .pre_pred ["posterior_predictive" ].y_hat , ax = ax [0 ]
127
+ self .datapre .index ,
128
+ self .pre_pred ["posterior_predictive" ].y_hat ,
129
+ ax = ax [0 ],
123
130
)
124
131
ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
125
132
# post intervention period
@@ -130,23 +137,6 @@ def plot(self):
130
137
include_label = False ,
131
138
)
132
139
ax [0 ].plot (self .datapost .index , self .post_y , "k." )
133
-
134
- ax [0 ].set (
135
- title = f"""
136
- Pre-intervention Bayesian $R^2$: { self .score .r2 :.3f}
137
- (std = { self .score .r2_std :.3f} )
138
- """
139
- )
140
-
141
- plot_xY (self .datapre .index , self .pre_impact , ax = ax [1 ])
142
- plot_xY (self .datapost .index , self .post_impact , ax = ax [1 ], include_label = False )
143
- ax [1 ].axhline (y = 0 , c = "k" )
144
- ax [1 ].set (title = "Causal Impact" )
145
-
146
- ax [2 ].set (title = "Cumulative Causal Impact" )
147
- plot_xY (self .datapost .index , self .post_impact_cumulative , ax = ax [2 ])
148
- ax [2 ].axhline (y = 0 , c = "k" )
149
-
150
140
# Shaded causal effect
151
141
ax [0 ].fill_between (
152
142
self .datapost .index ,
@@ -158,13 +148,44 @@ def plot(self):
158
148
alpha = 0.25 ,
159
149
label = "Causal impact" ,
160
150
)
151
+ ax [0 ].set (
152
+ title = f"""
153
+ Pre-intervention Bayesian $R^2$: { self .score .r2 :.3f}
154
+ (std = { self .score .r2_std :.3f} )
155
+ """
156
+ )
157
+
158
+ # MIDDLE PLOT -----------------------------------------------
159
+ plot_xY (
160
+ self .datapre .index ,
161
+ self .pre_impact ,
162
+ ax = ax [1 ],
163
+ )
164
+ plot_xY (
165
+ self .datapost .index ,
166
+ self .post_impact ,
167
+ ax = ax [1 ],
168
+ include_label = False ,
169
+ )
170
+ ax [1 ].axhline (y = 0 , c = "k" )
161
171
ax [1 ].fill_between (
162
172
self .datapost .index ,
163
173
y1 = self .post_impact .mean (["chain" , "draw" ]),
164
174
color = "C0" ,
165
175
alpha = 0.25 ,
166
176
label = "Causal impact" ,
167
177
)
178
+ ax [1 ].set (title = "Causal Impact" )
179
+
180
+ # BOTTOM PLOT -----------------------------------------------
181
+
182
+ ax [2 ].set (title = "Cumulative Causal Impact" )
183
+ plot_xY (
184
+ self .datapost .index ,
185
+ self .post_impact_cumulative ,
186
+ ax = ax [2 ],
187
+ )
188
+ ax [2 ].axhline (y = 0 , c = "k" )
168
189
169
190
# Intervention line
170
191
for i in [0 , 1 , 2 ]:
0 commit comments