Skip to content

Commit f915f77

Browse files
committed
add return type hints for plot methods + fix up test assertions
1 parent c080fa9 commit f915f77

File tree

9 files changed

+53
-23
lines changed

9 files changed

+53
-23
lines changed

causalpy/experiments/diff_in_diff.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919
import pandas as pd
20+
from matplotlib import pyplot as plt
2021
from patsy import build_design_matrices, dmatrices
2122

2223
from causalpy.custom_exceptions import (
@@ -204,7 +205,7 @@ def input_validation(self):
204205
coded. Consisting of 0's and 1's only."""
205206
)
206207

207-
def plot(self, round_to=None):
208+
def plot(self, round_to=None) -> tuple[plt.Figure, plt.Axes]:
208209
"""
209210
Plot the results
210211

causalpy/experiments/inverse_propensity_weighting.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
Inverse propensity weighting
1616
"""
1717

18+
from typing import List
19+
1820
import arviz as az
1921
import matplotlib.pyplot as plt
2022
import numpy as np
@@ -238,7 +240,9 @@ def get_ate(self, i, idata, method="doubly_robust"):
238240
ate = trt - ntrt
239241
return [ate, trt, ntrt]
240242

241-
def plot_ate(self, idata=None, method=None, prop_draws=100, ate_draws=300):
243+
def plot_ate(
244+
self, idata=None, method=None, prop_draws=100, ate_draws=300
245+
) -> tuple[plt.Figure, List[plt.Axes]]:
242246
if idata is None:
243247
idata = self.model.idata
244248
if method is None:
@@ -364,7 +368,7 @@ def make_hists(idata, i, axs, method=method):
364368
axs[2].legend()
365369
axs[2].set_title("Average Treatment Effect", fontsize=20)
366370

367-
return fig
371+
return fig, axs
368372

369373
def weighted_percentile(self, data, weights, perc):
370374
"""
@@ -380,7 +384,9 @@ def weighted_percentile(self, data, weights, perc):
380384
) # 'like' a CDF function
381385
return np.interp(perc, cdf, data)
382386

383-
def plot_balance_ecdf(self, covariate, idata=None, weighting_scheme=None):
387+
def plot_balance_ecdf(
388+
self, covariate, idata=None, weighting_scheme=None
389+
) -> tuple[plt.Figure, List[plt.Axes]]:
384390
"""
385391
Plotting function takes a single covariate and shows the
386392
differences in the ECDF between the treatment and control
@@ -451,4 +457,5 @@ def plot_balance_ecdf(self, covariate, idata=None, weighting_scheme=None):
451457
axs[0].set_xlabel("Quantiles")
452458
axs[1].legend()
453459
axs[0].legend()
454-
return fig
460+
# TODO: for some reason ax is type numpy.ndarray, so we need to convert this back to a list to conform to the expected return type.
461+
return fig, list(axs)

causalpy/experiments/prepostfit.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import numpy as np
2121
import pandas as pd
22+
from matplotlib import pyplot as plt
2223
from patsy import build_design_matrices, dmatrices
2324

2425
from causalpy.custom_exceptions import BadIndexException
@@ -107,7 +108,7 @@ def input_validation(self, data, treatment_time):
107108
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
108109
)
109110

110-
def plot(self):
111+
def plot(self) -> tuple[plt.Figure, plt.Axes]:
111112
"""
112113
Plot the results
113114
@@ -204,7 +205,9 @@ class SyntheticControl(PrePostFit):
204205

205206
expt_type = "SyntheticControl"
206207

207-
def plot(self, round_to=None, plot_predictors: bool = False):
208+
def plot(
209+
self, round_to=None, plot_predictors: bool = False
210+
) -> tuple[plt.Figure, plt.Axes]:
208211
"""
209212
Plot the results
210213

causalpy/experiments/prepostnegd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919
import pandas as pd
20+
from matplotlib import pyplot as plt
2021
from patsy import build_design_matrices, dmatrices
2122

2223
from causalpy.custom_exceptions import (
@@ -176,7 +177,7 @@ def _causal_impact_summary_stat(self, round_to) -> str:
176177
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
177178
return f"Causal impact = {causal_impact + ci}"
178179

179-
def plot(self):
180+
def plot(self) -> tuple[plt.Figure, plt.Axes]:
180181
"""
181182
Plot the results
182183

causalpy/experiments/regression_discontinuity.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import warnings # noqa: I001
1919

20+
from matplotlib import pyplot as plt
2021
import numpy as np
2122
import pandas as pd
2223
from patsy import build_design_matrices, dmatrices
@@ -192,7 +193,7 @@ def _is_treated(self, x):
192193
"""
193194
return np.greater_equal(x, self.treatment_threshold)
194195

195-
def plot(self):
196+
def plot(self) -> tuple[plt.Figure, plt.Axes]:
196197
"""
197198
Plot the results
198199

causalpy/experiments/regression_kink.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import warnings # noqa: I001
2020

21+
from matplotlib import pyplot as plt
2122
import numpy as np
2223
import pandas as pd
2324
from patsy import build_design_matrices, dmatrices
@@ -160,7 +161,7 @@ def _is_treated(self, x):
160161
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.""" # noqa: E501
161162
return np.greater_equal(x, self.kink_point)
162163

163-
def plot(self, round_to=None):
164+
def plot(self, round_to=None) -> tuple[plt.Figure, plt.Axes]:
164165
"""
165166
Plot the results
166167

causalpy/plotting.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from typing import List
16+
1517
import arviz as az
1618
import matplotlib.pyplot as plt
1719
import numpy as np
@@ -28,7 +30,9 @@ class BayesianPlotComponent:
2830
"""Plotting component for Bayesian models."""
2931

3032
@staticmethod
31-
def plot_pre_post(results, round_to=None, counterfactual_label=None):
33+
def plot_pre_post(
34+
results, round_to=None, counterfactual_label=None
35+
) -> tuple[plt.Figure, plt.Axes]:
3236
"""Generate plot for pre-post experiment types, such as Interrupted Time Series
3337
and Synthetic Control."""
3438
datapre = results.datapre
@@ -143,7 +147,9 @@ def plot_pre_post(results, round_to=None, counterfactual_label=None):
143147
return fig, ax
144148

145149
@staticmethod
146-
def plot_difference_in_differences(results, round_to=None):
150+
def plot_difference_in_differences(
151+
results, round_to=None
152+
) -> tuple[plt.Figure, plt.Axes]:
147153
"""Generate plot for difference-in-differences"""
148154

149155
def _plot_causal_impact_arrow(results, ax):
@@ -282,7 +288,7 @@ def _plot_causal_impact_arrow(results, ax):
282288
# pass
283289

284290
@staticmethod
285-
def plot_pre_post_negd(results, round_to=None):
291+
def plot_pre_post_negd(results, round_to=None) -> tuple[plt.Figure, List[plt.Axes]]:
286292
"""Generate plot for ANOVA-like experiments with non-equivalent group designs."""
287293
fig, ax = plt.subplots(
288294
2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
@@ -334,7 +340,9 @@ def plot_pre_post_negd(results, round_to=None):
334340
return fig, ax
335341

336342
@staticmethod
337-
def plot_regression_discontinuity(results, round_to=None):
343+
def plot_regression_discontinuity(
344+
results, round_to=None
345+
) -> tuple[plt.Figure, plt.Axes]:
338346
"""Generate plot for regression discontinuity designs."""
339347
fig, ax = plt.subplots()
340348
# Plot raw data
@@ -386,7 +394,7 @@ def plot_regression_discontinuity(results, round_to=None):
386394
return (fig, ax)
387395

388396
@staticmethod
389-
def plot_regression_kink(results, round_to=None):
397+
def plot_regression_kink(results, round_to=None) -> tuple[plt.Figure, plt.Axes]:
390398
"""Generate plot for regression kink designs."""
391399
fig, ax = plt.subplots()
392400
# Plot raw data
@@ -440,7 +448,7 @@ class OLSPlotComponent:
440448
"""Plotting component for OLS models."""
441449

442450
@staticmethod
443-
def plot_pre_post(results, round_to=None):
451+
def plot_pre_post(results, round_to=None) -> tuple[plt.Figure, List[plt.Axes]]:
444452
"""Generate plot for pre-post experiment types, such as Interrupted Time Series
445453
and Synthetic Control."""
446454
counterfactual_label = "Counterfactual"
@@ -509,7 +517,9 @@ def plot_pre_post(results, round_to=None):
509517
return (fig, ax)
510518

511519
@staticmethod
512-
def plot_difference_in_differences(results, round_to=None):
520+
def plot_difference_in_differences(
521+
results, round_to=None
522+
) -> tuple[plt.Figure, plt.Axes]:
513523
"""Generate plot for difference-in-differences"""
514524
fig, ax = plt.subplots()
515525

@@ -585,7 +595,9 @@ def plot_difference_in_differences(results, round_to=None):
585595
return (fig, ax)
586596

587597
@staticmethod
588-
def plot_regression_discontinuity(results, round_to=None) -> tuple:
598+
def plot_regression_discontinuity(
599+
results, round_to=None
600+
) -> tuple[plt.Figure, plt.Axes]:
589601
"""Generate plot for regression discontinuity designs."""
590602
fig, ax = plt.subplots()
591603
# Plot raw data

causalpy/tests/test_misc.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
"""
1717

1818
import arviz as az
19-
import matplotlib as mpl
2019
import pandas as pd
20+
from matplotlib import pyplot as plt
2121

2222
import causalpy as cp
2323

@@ -82,7 +82,11 @@ def test_inverse_prop():
8282
assert isinstance(ate_list, list)
8383
ate_list = result.get_ate(0, result.idata, method="overlap")
8484
assert isinstance(ate_list, list)
85-
fig = result.plot_ate(prop_draws=1, ate_draws=10)
86-
assert isinstance(fig, mpl.figure.Figure)
87-
fig = result.plot_balance_ecdf("age")
88-
assert isinstance(fig, mpl.figure.Figure)
85+
fig, axs = result.plot_ate(prop_draws=1, ate_draws=10)
86+
assert isinstance(fig, plt.Figure)
87+
assert isinstance(axs, list)
88+
assert all(isinstance(ax, plt.Axes) for ax in axs)
89+
fig, axs = result.plot_balance_ecdf("age")
90+
assert isinstance(fig, plt.Figure)
91+
assert isinstance(axs, list)
92+
assert all(isinstance(ax, plt.Axes) for ax in axs)

docs/source/_static/classes.png

90.4 KB
Loading

0 commit comments

Comments
 (0)