
Commit dcb913f

add round_to kwarg to plot methods

1 parent: c3c542d

File tree: 4 files changed (+93, -37 lines)

    causalpy/pymc_experiments.py
    causalpy/skl_experiments.py
    causalpy/utils.py
    docs/source/_static/interrogate_badge.svg

causalpy/pymc_experiments.py

Lines changed: 40 additions & 25 deletions

@@ -23,10 +23,13 @@
 from patsy import build_design_matrices, dmatrices
 from sklearn.linear_model import LinearRegression as sk_lin_reg
 
-from causalpy.custom_exceptions import BadIndexException  # NOQA
-from causalpy.custom_exceptions import DataException, FormulaException
+from causalpy.custom_exceptions import (
+    BadIndexException,  # NOQA
+    DataException,
+    FormulaException,
+)
 from causalpy.plot_utils import plot_xY
-from causalpy.utils import _is_variable_dummy_coded
+from causalpy.utils import _is_variable_dummy_coded, round_num
 
 LEGEND_FONT_SIZE = 12
 az.style.use("arviz-darkgrid")
@@ -228,7 +231,7 @@ def _input_validation(self, data, treatment_time):
                 "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp."  # noqa: E501
             )
 
-    def plot(self, counterfactual_label="Counterfactual", **kwargs):
+    def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
         """
         Plot the results
         """
@@ -275,8 +278,8 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
 
         ax[0].set(
             title=f"""
-            Pre-intervention Bayesian $R^2$: {self.score.r2:.3f}
-            (std = {self.score.r2_std:.3f})
+            Pre-intervention Bayesian $R^2$: {round_num(self.score.r2, round_to)}
+            (std = {round_num(self.score.r2_std, round_to)})
             """
         )
 
@@ -580,7 +583,7 @@ def _input_validation(self):
                 coded. Consisting of 0's and 1's only."""
             )
 
-    def plot(self):
+    def plot(self, round_to=None):
         """Plot the results.
         Creating the combined mean + HDI legend entries is a bit involved.
         """
@@ -658,7 +661,7 @@ def plot(self):
         # formatting
         ax.set(
             xticks=self.x_pred_treatment[self.time_variable_name].values,
-            title=self._causal_impact_summary_stat(),
+            title=self._causal_impact_summary_stat(round_to),
         )
         ax.legend(
             handles=(h_tuple for h_tuple in handles),
@@ -711,11 +714,14 @@ def _plot_causal_impact_arrow(self, ax):
             va="center",
         )
 
-    def _causal_impact_summary_stat(self) -> str:
+    def _causal_impact_summary_stat(self, round_to=None) -> str:
         """Computes the mean and 94% credible interval bounds for the causal impact."""
         percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
-        ci = "$CI_{94\\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
-        causal_impact = f"{self.causal_impact.mean():.2f}, "
+        ci = (
+            "$CI_{94\\%}$"
+            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
+        )
+        causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
         return f"Causal impact = {causal_impact + ci}"
 
     def summary(self) -> None:
@@ -893,7 +899,7 @@ def _is_treated(self, x):
         """
         return np.greater_equal(x, self.treatment_threshold)
 
-    def plot(self):
+    def plot(self, round_to=None):
         """
         Plot the results
         """
@@ -918,12 +924,15 @@ def plot(self):
         labels = ["Posterior mean"]
 
         # create strings to compose title
-        title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
+        title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
         r2 = f"Bayesian $R^2$ on all data = {title_info}"
         percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
-        ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
+        ci = (
+            r"$CI_{94\%}$"
+            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
+        )
         discon = f"""
-            Discontinuity at threshold = {self.discontinuity_at_threshold.mean():.2f},
+            Discontinuity at threshold = {round_num(self.discontinuity_at_threshold.mean(), round_to)},
             """
         ax.set(title=r2 + "\n" + discon + ci)
         # Intervention line
@@ -1104,7 +1113,7 @@ def _is_treated(self, x):
         """Returns ``True`` if `x` is greater than or equal to the treatment threshold."""  # noqa: E501
         return np.greater_equal(x, self.kink_point)
 
-    def plot(self):
+    def plot(self, round_to=None):
         """
         Plot the results
         """
@@ -1129,12 +1138,15 @@ def plot(self):
         labels = ["Posterior mean"]
 
         # create strings to compose title
-        title_info = f"{self.score.r2:.3f} (std = {self.score.r2_std:.3f})"
+        title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
         r2 = f"Bayesian $R^2$ on all data = {title_info}"
         percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
-        ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
+        ci = (
+            r"$CI_{94\%}$"
+            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
+        )
         grad_change = f"""
-            Change in gradient = {self.gradient_change.mean():.2f},
+            Change in gradient = {round_num(self.gradient_change.mean(), round_to)},
             """
         ax.set(title=r2 + "\n" + grad_change + ci)
         # Intervention line
@@ -1292,7 +1304,7 @@ def _input_validation(self) -> None:
             """
             )
 
-    def plot(self):
+    def plot(self, round_to=None):
         """Plot the results"""
         fig, ax = plt.subplots(
             2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
@@ -1339,18 +1351,21 @@ def plot(self):
         )
 
         # Plot estimated caual impact / treatment effect
-        az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1])
+        az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1], round_to=round_to)
         ax[1].set(title="Estimated treatment effect")
         return fig, ax
 
-    def _causal_impact_summary_stat(self) -> str:
+    def _causal_impact_summary_stat(self, round_to) -> str:
         """Computes the mean and 94% credible interval bounds for the causal impact."""
         percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
-        ci = r"$CI_{94%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
+        ci = (
+            r"$CI_{94%}$"
+            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
+        )
         causal_impact = f"{self.causal_impact.mean():.2f}, "
         return f"Causal impact = {causal_impact + ci}"
 
-    def summary(self) -> None:
+    def summary(self, round_to=None) -> None:
         """
         Print text output summarising the results
         """
@@ -1359,7 +1374,7 @@ def summary(self) -> None:
         print(f"Formula: {self.formula}")
         print("\nResults:")
         # TODO: extra experiment specific outputs here
-        print(self._causal_impact_summary_stat())
+        print(self._causal_impact_summary_stat(round_to))
         self.print_coefficients()
 
     def _get_treatment_effect_coeff(self) -> str:
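
Not part of the diff: a minimal sketch of how the new `round_to` keyword might be called on one of the updated PyMC experiment classes. The dataset name, formula, and model construction below follow the usual CausalPy documentation pattern and are illustrative assumptions, not taken from this commit.

```python
import causalpy as cp

# Assumed example data and formula (illustrative, mirrors the CausalPy docs)
df = cp.load_data("did")
result = cp.pymc_experiments.DifferenceInDifferences(
    df,
    formula="y ~ 1 + group*post_treatment",
    time_variable_name="t",
    group_variable_name="group",
    model=cp.pymc_models.LinearRegression(),
)

# round_to sets the number of significant figures used for the statistics
# reported in the plot title; None falls back to round_num()'s default
# (at least 2 significant figures).
fig, ax = result.plot(round_to=2)
```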

causalpy/skl_experiments.py

Lines changed: 15 additions & 9 deletions

@@ -17,6 +17,8 @@
 import seaborn as sns
 from patsy import build_design_matrices, dmatrices
 
+from causalpy.utils import round_num
+
 LEGEND_FONT_SIZE = 12
 
 
@@ -113,7 +115,7 @@ def __init__(
         # cumulative impact post
         self.post_impact_cumulative = np.cumsum(self.post_impact)
 
-    def plot(self, counterfactual_label="Counterfactual", **kwargs):
+    def plot(self, counterfactual_label="Counterfactual", round_to=None, **kwargs):
         """Plot experiment results"""
         fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
 
@@ -128,7 +130,9 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
             ls=":",
             c="k",
         )
-        ax[0].set(title=f"$R^2$ on pre-intervention data = {self.score:.3f}")
+        ax[0].set(
+            title=f"$R^2$ on pre-intervention data = {round_num(self.score, round_to)}"
+        )
 
         ax[1].plot(self.datapre.index, self.pre_impact, "k.")
         ax[1].plot(
@@ -258,9 +262,11 @@ class SyntheticControl(PrePostFit):
    ...     )
    """
 
-    def plot(self, plot_predictors=False, **kwargs):
+    def plot(self, plot_predictors=False, round_to=None, **kwargs):
         """Plot the results"""
-        fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
+        fig, ax = super().plot(
+            counterfactual_label="Synthetic control", round_to=round_to, **kwargs
+        )
         if plot_predictors:
             # plot control units as well
             ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)
@@ -397,7 +403,7 @@ def __init__(
         # TODO: THIS IS NOT YET CORRECT
         self.causal_impact = self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
 
-    def plot(self):
+    def plot(self, round_to=None):
         """Plot results"""
         fig, ax = plt.subplots()
 
@@ -462,7 +468,7 @@ def plot(self):
             xlim=[-0.05, 1.1],
             xticks=[0, 1],
             xticklabels=["pre", "post"],
-            title=f"Causal impact = {self.causal_impact[0]:.2f}",
+            title=f"Causal impact = {round_num(self.causal_impact[0], round_to)}",
         )
         ax.legend(fontsize=LEGEND_FONT_SIZE)
         return (fig, ax)
@@ -607,7 +613,7 @@ def _is_treated(self, x):
         """
         return np.greater_equal(x, self.treatment_threshold)
 
-    def plot(self):
+    def plot(self, round_to=None):
         """Plot results"""
         fig, ax = plt.subplots()
         # Plot raw data
@@ -627,8 +633,8 @@ def plot(self):
             label="model fit",
         )
         # create strings to compose title
-        r2 = f"$R^2$ on all data = {self.score:.3f}"
-        discon = f"Discontinuity at threshold = {self.discontinuity_at_threshold:.2f}"
+        r2 = f"$R^2$ on all data = {round_num(self.score, round_to)}"
+        discon = f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold, round_to)}"
         ax.set(title=r2 + "\n" + discon)
         # Intervention line
         ax.axvline(
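
Similarly, a hedged sketch of the same keyword on one of the scikit-learn based experiments; the dataset name, formula, and threshold are illustrative assumptions and not part of this commit.

```python
import causalpy as cp
from sklearn.linear_model import LinearRegression

# Assumed example dataset, formula, and threshold (illustrative only)
df = cp.load_data("rd")
result = cp.skl_experiments.RegressionDiscontinuity(
    df,
    formula="y ~ 1 + x + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)

# The title statistics (R^2, discontinuity) are now rendered via round_num()
# instead of fixed ".3f" / ".2f" formats.
fig, ax = result.plot(round_to=3)
```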

causalpy/utils.py

Lines changed: 35 additions & 0 deletions

@@ -1,6 +1,7 @@
 """
 Utility functions
 """
+import numpy as np
 import pandas as pd
 
 
@@ -13,3 +14,37 @@ def _is_variable_dummy_coded(series: pd.Series) -> bool:
 def _series_has_2_levels(series: pd.Series) -> bool:
     """Check that the variable in the provided Series has 2 levels"""
     return len(pd.Categorical(series).categories) == 2
+
+
+def round_num(n, round_to):
+    """
+    Return a string representing a number with `round_to` significant figures.
+
+    Parameters
+    ----------
+    n : float
+        number to round
+    round_to : int
+        number of significant figures
+    """
+    sig_figs = _format_sig_figs(n, round_to)
+    return f"{n:.{sig_figs}g}"
+
+
+def _format_sig_figs(value, default=None):
+    """Get a default number of significant figures.
+
+    Gives the integer part or `default`, whichever is bigger.
+
+    Examples
+    --------
+    0.1234 --> 0.12
+    1.234 --> 1.2
+    12.34 --> 12
+    123.4 --> 123
+    """
+    if default is None:
+        default = 2
+    if value == 0:
+        return 1
+    return max(int(np.log10(np.abs(value))) + 1, default)
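
To make the rounding behaviour concrete, here is a short usage sketch of `round_num` (not part of the commit); the printed values follow directly from the code above. With `round_to=None`, `_format_sig_figs` falls back to 2 significant figures, and an explicit `round_to` acts as a minimum that a larger integer part can override.

```python
from causalpy.utils import round_num

print(round_num(0.1234, None))  # "0.12"  (default of 2 significant figures)
print(round_num(1.234, None))   # "1.2"
print(round_num(12.34, None))   # "12"
print(round_num(123.4, None))   # "123"   (integer part has 3 digits, so 3 sig figs)

print(round_num(3.14159, 3))    # "3.14"
print(round_num(123.4, 2))      # "123"   (integer part wins over round_to)
print(round_num(0.0, 5))        # "0"     (zero is special-cased to 1 sig fig)
```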

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
