@@ -134,7 +134,7 @@ def _input_validation(self, data, treatment_time):
134
134
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
135
135
)
136
136
137
- def plot (self ):
137
+ def plot (self , counterfactual_label = "Counterfactual" , ** kwargs ):
138
138
"""Plot the results"""
139
139
fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
140
140
@@ -161,7 +161,7 @@ def plot(self):
161
161
plot_hdi_kwargs = {"color" : "C1" },
162
162
)
163
163
handles .append ((h_line , h_patch ))
164
- labels .append ("Synthetic control" )
164
+ labels .append (counterfactual_label )
165
165
166
166
ax [0 ].plot (self .datapost .index , self .post_y , "k." )
167
167
# Shaded causal effect
@@ -243,14 +243,20 @@ def summary(self):
243
243
self .print_coefficients ()
244
244
245
245
246
+ class InterruptedTimeSeries (PrePostFit ):
247
+ """Interrupted time series analysis"""
248
+
249
+ expt_type = "Interrupted Time Series"
250
+
251
+
246
252
class SyntheticControl (PrePostFit ):
247
253
"""A wrapper around the PrePostFit class"""
248
254
249
255
expt_type = "Synthetic Control"
250
256
251
- def plot (self , plot_predictors = False ):
257
+ def plot (self , plot_predictors = False , ** kwargs ):
252
258
"""Plot the results"""
253
- fig , ax = super ().plot ()
259
+ fig , ax = super ().plot (counterfactual_label = "Synthetic control" , ** kwargs )
254
260
if plot_predictors :
255
261
# plot control units as well
256
262
ax [0 ].plot (self .datapre .index , self .pre_X , "-" , c = [0.8 , 0.8 , 0.8 ], zorder = 1 )
0 commit comments