9
9
from causalpy .plot_utils import plot_xY
10
10
11
11
LEGEND_FONT_SIZE = 12
12
+ az .style .use ("arviz-darkgrid" )
12
13
13
14
14
15
class ExperimentalDesign :
15
16
"""Base class"""
16
17
17
18
prediction_model = None
19
+ expt_type = None
18
20
19
21
def __init__ (self , prediction_model = None , ** kwargs ):
20
22
if prediction_model is not None :
@@ -345,20 +347,20 @@ def __init__(
345
347
treatment_threshold : float ,
346
348
prediction_model = None ,
347
349
running_variable_name : str = "x" ,
348
- outcome_variable_name = "y" ,
349
350
** kwargs ,
350
351
):
351
352
super ().__init__ (prediction_model = prediction_model , ** kwargs )
353
+ self .expt_type = "Regression Discontinuity"
352
354
self .data = data
353
355
self .formula = formula
354
356
self .running_variable_name = running_variable_name
355
- self .outcome_variable_name = outcome_variable_name
356
357
self .treatment_threshold = treatment_threshold
357
358
y , X = dmatrices (formula , self .data )
358
359
self ._y_design_info = y .design_info
359
360
self ._x_design_info = X .design_info
360
361
self .labels = X .design_info .column_names
361
362
self .y , self .X = np .asarray (y ), np .asarray (X )
363
+ self .outcome_variable_name = y .design_info .column_names [0 ]
362
364
363
365
# TODO: `treated` is a deterministic function of x and treatment_threshold, so this could be a function rather than supplied data
364
366
@@ -445,7 +447,8 @@ def plot(self):
445
447
446
448
def summary (self ):
447
449
"""Print text output summarising the results"""
448
- print ("Difference in Differences experiment" )
450
+
451
+ print (f"{ self .expt_type :=^80} " )
449
452
print (f"Formula: { self .formula } " )
450
453
print (f"Running variable: { self .running_variable_name } " )
451
454
print (f"Threshold on running variable: { self .treatment_threshold } " )
@@ -455,15 +458,17 @@ def summary(self):
455
458
)
456
459
print ("Model coefficients:" )
457
460
coeffs = az .extract (self .prediction_model .idata .posterior , var_names = "beta" )
461
+ # Note: f"{name: <30}" pads the name with spaces so that we have alignment of the stats despite variable names of different lengths
458
462
for name in self .labels :
459
463
coeff_samples = coeffs .sel (coeffs = name )
460
464
print (
461
- f"\t { name } \t \t { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
465
+ f" { name : <30 } { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
462
466
)
463
467
# add coeff for measurement std
464
468
coeff_samples = az .extract (
465
469
self .prediction_model .idata .posterior , var_names = "sigma"
466
470
)
471
+ name = "sigma"
467
472
print (
468
- f"\t sigma \t \t { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
473
+ f" { name : <30 } { coeff_samples .mean ().data :.2f} , 94% HDI [{ coeff_samples .quantile (0.03 ).data :.2f} , { coeff_samples .quantile (1 - 0.03 ).data :.2f} ]"
469
474
)
0 commit comments