@@ -24,6 +24,25 @@ def __init__(self, prediction_model=None, **kwargs):
24
24
if self .prediction_model is None :
25
25
raise ValueError ("fitting_model not set or passed." )
26
26
27
+ def print_coefficients (self ):
28
+ """Prints the model coefficients"""
29
+ print ("Model coefficients:" )
30
+ coeffs = az .extract (self .prediction_model .idata .posterior , var_names = "beta" )
31
+ # Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
32
+ for name in self .labels :
33
+ coeff_samples = coeffs .sel (coeffs = name )
34
+ print (
35
+ f" { name : <30} { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
36
+ )
37
+ # add coeff for measurement std
38
+ coeff_samples = az .extract (
39
+ self .prediction_model .idata .posterior , var_names = "sigma"
40
+ )
41
+ name = "sigma"
42
+ print (
43
+ f" { name : <30} { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
44
+ )
45
+
27
46
28
47
class TimeSeriesExperiment (ExperimentalDesign ):
29
48
"""A class to analyse time series quasi-experiments"""
@@ -147,10 +166,20 @@ def plot(self):
147
166
148
167
return (fig , ax )
149
168
169
+ def summary (self ):
170
+ """Print text output summarising the results"""
171
+
172
+ print (f"{ self .expt_type :=^80} " )
173
+ print (f"Formula: { self .formula } " )
174
+ # TODO: extra experiment specific outputs here
175
+ self .print_coefficients ()
176
+
150
177
151
178
class SyntheticControl (TimeSeriesExperiment ):
152
179
"""A wrapper around the TimeSeriesExperiment class"""
153
180
181
+ expt_type = "Synthetic Control"
182
+
154
183
def plot (self ):
155
184
"""Plot the results"""
156
185
fig , ax = super ().plot ()
@@ -163,7 +192,7 @@ def plot(self):
163
192
class InterruptedTimeSeries (TimeSeriesExperiment ):
164
193
"""A wrapper around the TimeSeriesExperiment class"""
165
194
166
- pass
195
+ expt_type = "Interrupted Time Series"
167
196
168
197
169
198
class DifferenceInDifferences (ExperimentalDesign ):
@@ -185,6 +214,7 @@ def __init__(
185
214
):
186
215
super ().__init__ (prediction_model = prediction_model , ** kwargs )
187
216
self .data = data
217
+ self .expt_type = "Difference in Differences"
188
218
self .formula = formula
189
219
self .time_variable_name = time_variable_name
190
220
y , X = dmatrices (formula , self .data )
@@ -324,6 +354,14 @@ def plot(self):
324
354
ax .legend (fontsize = LEGEND_FONT_SIZE )
325
355
return (fig , ax )
326
356
357
+ def summary (self ):
358
+ """Print text output summarising the results"""
359
+
360
+ print (f"{ self .expt_type :=^80} " )
361
+ print (f"Formula: { self .formula } " )
362
+ # TODO: extra experiment specific outputs here
363
+ self .print_coefficients ()
364
+
327
365
328
366
class RegressionDiscontinuity (ExperimentalDesign ):
329
367
"""
@@ -456,19 +494,4 @@ def summary(self):
456
494
print (
457
495
f"Discontinuity at threshold = { self .discontinuity_at_threshold .mean ():.2f} "
458
496
)
459
- print ("Model coefficients:" )
460
- coeffs = az .extract (self .prediction_model .idata .posterior , var_names = "beta" )
461
- # Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
462
- for name in self .labels :
463
- coeff_samples = coeffs .sel (coeffs = name )
464
- print (
465
- f" { name : <30} { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
466
- )
467
- # add coeff for measurement std
468
- coeff_samples = az .extract (
469
- self .prediction_model .idata .posterior , var_names = "sigma"
470
- )
471
- name = "sigma"
472
- print (
473
- f" { name : <30} { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
474
- )
497
+ self .print_coefficients ()
0 commit comments