 from patsy import build_design_matrices, dmatrices
 from sklearn.linear_model import LinearRegression as sk_lin_reg

-from causalpy.custom_exceptions import BadIndexException  # NOQA
-from causalpy.custom_exceptions import DataException, FormulaException
+from causalpy.custom_exceptions import (
+    BadIndexException,  # NOQA
+    DataException,
+    FormulaException,
+)
 from causalpy.plot_utils import plot_xY
-from causalpy.utils import _is_variable_dummy_coded
+from causalpy.utils import _is_variable_dummy_coded, round_num

 LEGEND_FONT_SIZE = 12
 az.style.use("arviz-darkgrid")
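The new round_num import is what enables the configurable rounding in the hunks below: every hard-coded :.2f / :.3f format spec is replaced by a call of the form round_num(value, round_to). As a rough sketch only (an assumption about its behaviour inferred from how it is called in this diff, not CausalPy's actual implementation in causalpy/utils.py), such a significant-figures helper could look like this:

from math import floor, log10

def round_num(n, round_to=None):
    # Sketch of a significant-figures rounding helper; the default of 2
    # significant figures when round_to is None is an assumption.
    sig_figs = 2 if round_to is None else round_to
    if n == 0:
        return 0.0
    return round(n, sig_figs - int(floor(log10(abs(n)))) - 1)

For example, round_num(3.14159, 3) gives 3.14 and round_num(12345.0, 2) gives 12000.0.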
@@ -228,7 +231,7 @@ def _input_validation(self, data, treatment_time):
                 "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp."  # noqa: E501
             )

-    def plot(self, counterfactual_label="Counterfactual", **kwargs):
+    def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
         """
         Plot the results
         """
@@ -275,8 +278,8 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):

         ax[0].set(
             title=f"""
-            Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
-            (std = {self.score.r2_std:.3f})
+            Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
+            (std = {round_num(self.score.r2_std, round_to)})
             """
         )

@@ -580,7 +583,7 @@ def _input_validation(self):
                 coded. Consisting of 0's and 1's only."""
             )

-    def plot(self):
+    def plot(self, round_to=None):
         """Plot the results.
         Creating the combined mean + HDI legend entries is a bit involved.
         """
@@ -658,7 +661,7 @@ def plot(self):
         # formatting
         ax.set(
             xticks=self.x_pred_treatment[self.time_variable_name].values,
-            title=self._causal_impact_summary_stat(),
+            title=self._causal_impact_summary_stat(round_to),
         )
         ax.legend(
             handles=(h_tuple for h_tuple in handles),
@@ -711,11 +714,14 @@ def _plot_causal_impact_arrow(self, ax):
             va="center",
         )

-    def _causal_impact_summary_stat(self) -> str:
+    def _causal_impact_summary_stat(self, round_to=None) -> str:
         """Computes the mean and 94% credible interval bounds for the causal impact."""
         percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
-        ci = "$CI_{94\\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
-        causal_impact = f"{self.causal_impact.mean():.2f}, "
+        ci = (
+            "$CI_{94\\%}$"
+            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
+        )
+        causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
         return f"Causal impact = {causal_impact + ci}"

     def summary(self) -> None:
@@ -893,7 +899,7 @@ def _is_treated(self, x):
         """
         return np.greater_equal(x, self.treatment_threshold)

-    def plot(self):
+    def plot(self, round_to=None):
         """
         Plot the results
         """
@@ -918,12 +924,15 @@ def plot(self):
         labels = ["Posterior mean"]

         # create strings to compose title
-        title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
+        title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
         r2 = f"Bayesian $R^2$ on all data = {title_info}"
         percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
-        ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
+        ci = (
+            r"$CI_{94\%}$"
+            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
+        )
         discon = f"""
-            Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f},
+            Discontinuity at threshold = {round_num(self.discontinuity_at_threshold.mean(), round_to)},
             """
         ax.set(title=r2 + "\n" + discon + ci)
         # Intervention line
@@ -1104,7 +1113,7 @@ def _is_treated(self, x):
         """Returns ``True`` if `x` is greater than or equal to the treatment threshold."""  # noqa: E501
         return np.greater_equal(x, self.kink_point)

-    def plot(self):
+    def plot(self, round_to=None):
         """
         Plot the results
         """
@@ -1129,12 +1138,15 @@ def plot(self):
         labels = ["Posterior mean"]

         # create strings to compose title
-        title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
+        title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
         r2 = f"Bayesian $R^2$ on all data = {title_info}"
         percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
-        ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
+        ci = (
+            r"$CI_{94\%}$"
+            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
+        )
         grad_change = f"""
-            Change in gradient = {self.gradient_change.mean():.2f},
+            Change in gradient = {round_num(self.gradient_change.mean(), round_to)},
             """
         ax.set(title=r2 + "\n" + grad_change + ci)
         # Intervention line
@@ -1292,7 +1304,7 @@ def _input_validation(self) -> None:
                 """
             )

-    def plot(self):
+    def plot(self, round_to=None):
         """Plot the results"""
         fig, ax = plt.subplots(
             2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
@@ -1339,18 +1351,21 @@ def plot(self):
         )

         # Plot estimated caual impact / treatment effect
-        az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1])
+        az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1], round_to=round_to)
         ax[1].set(title="Estimated treatment effect")
         return fig, ax

-    def _causal_impact_summary_stat(self) -> str:
+    def _causal_impact_summary_stat(self, round_to) -> str:
         """Computes the mean and 94% credible interval bounds for the causal impact."""
         percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
-        ci = r"$CI_{94%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
+        ci = (
+            r"$CI_{94%}$"
+            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
+        )
         causal_impact = f"{self.causal_impact.mean():.2f}, "
         return f"Causal impact = {causal_impact + ci}"

-    def summary(self) -> None:
+    def summary(self, round_to=None) -> None:
         """
         Print text output summarising the results
         """
@@ -1359,7 +1374,7 @@ def summary(self) -> None:
         print(f"Formula: {self.formula}")
         print("\nResults:")
         # TODO: extra experiment specific outputs here
-        print(self._causal_impact_summary_stat())
+        print(self._causal_impact_summary_stat(round_to))
         self.print_coefficients()

     def _get_treatment_effect_coeff(self) -> str:
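Taken together, the hunks above thread an optional round_to keyword from the public plot() and summary() methods down into the title and summary strings (and through to az.plot_posterior, which accepts the same keyword). A usage sketch, assuming `result` is an already-fitted experiment object from causalpy.pymc_experiments whose methods were changed in this commit (the variable name and the value 3 are illustrative only):

# `result` is assumed to be a fitted experiment whose plot()/summary()
# gained round_to above; this is not a complete, runnable session on its own.
fig, ax = result.plot(round_to=3)   # report R^2 and credible intervals to 3 sig. figs
result.summary(round_to=3)          # same rounding in the printed summary (last hunks)
fig, ax = result.plot()             # round_to=None keeps the helper's default behaviour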