@@ -74,7 +74,7 @@ def __init__(
74
74
# cumulative impact post
75
75
self .post_impact_cumulative = np .cumsum (self .post_impact )
76
76
77
- def plot (self ):
77
+ def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
78
78
fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
79
79
80
80
ax [0 ].plot (self .datapre .index , self .pre_y , "k." )
@@ -84,7 +84,7 @@ def plot(self):
84
84
ax [0 ].plot (
85
85
self .datapost .index ,
86
86
self .post_pred ,
87
- label = "counterfactual" ,
87
+ label = counterfactual_label ,
88
88
ls = ":" ,
89
89
c = "k" ,
90
90
)
@@ -95,7 +95,7 @@ def plot(self):
95
95
self .datapost .index ,
96
96
self .post_impact ,
97
97
"k." ,
98
- label = "counterfactual" ,
98
+ label = counterfactual_label ,
99
99
)
100
100
ax [1 ].axhline (y = 0 , c = "k" )
101
101
ax [1 ].set (title = "Causal Impact" )
@@ -151,12 +151,18 @@ def plot_coeffs(self):
151
151
)
152
152
153
153
154
+ class InterruptedTimeSeries (PrePostFit ):
155
+ """Interrupted time series analysis"""
156
+
157
+ expt_type = "Interrupted Time Series"
158
+
159
+
154
160
class SyntheticControl (PrePostFit ):
155
161
"""A wrapper around the PrePostFit class"""
156
162
157
- def plot (self , plot_predictors = False ):
163
+ def plot (self , plot_predictors = False , ** kwargs ):
158
164
"""Plot the results"""
159
- fig , ax = super ().plot ()
165
+ fig , ax = super ().plot (counterfactual_label = "Synthetic control" , ** kwargs )
160
166
if plot_predictors :
161
167
# plot control units as well
162
168
ax [0 ].plot (self .datapre .index , self .pre_X , "-" , c = [0.8 , 0.8 , 0.8 ], zorder = 1 )
0 commit comments