Skip to content

Commit 6bee429

Browse files
Jason Preszler (jpreszler)
authored and committed
Issue 129: increase docstring coverage, now at 86%
1 parent d7a12cb commit 6bee429

13 files changed

+274
-15
lines changed

causalpy/custom_exceptions.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,26 @@
1+
"""
2+
Custom Exceptions for CausalPy.
3+
"""
4+
5+
16
class BadIndexException(Exception):
27
"""Custom exception used when we have a mismatch in types between the dataframe
38
index and an event, typically a treatment or intervention."""
49

5-
def __init__(self, message):
10+
def __init__(self, message: str):
611
self.message = message
712

813

914
class FormulaException(Exception):
1015
"""Exception raised given when there is some error in a user-provided model
1116
formula"""
1217

13-
def __init__(self, message):
18+
def __init__(self, message: str):
1419
self.message = message
1520

1621

1722
class DataException(Exception):
1823
"""Exception raised given when there is some error in user-provided dataframe"""
1924

20-
def __init__(self, message):
25+
def __init__(self, message: str):
2126
self.message = message

causalpy/data/simulate_data.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
def _smoothed_gaussian_random_walk(
1212
gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs
1313
):
14+
"""
15+
Generates Gaussian random walk data and applies LOWESS
16+
"""
1417
x = np.arange(N)
1518
y = norm(gaussian_random_walk_mu, gaussian_random_walk_sigma).rvs(N).cumsum()
1619
filtered = lowess(y, x, **lowess_kwargs)

causalpy/plot_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
Plotting utility functions.
3+
"""
4+
15
from typing import Any, Dict, Optional, Tuple, Union
26

37
import arviz as az

causalpy/pymc_experiments.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
"""
2+
Experiment routines for PyMC models.
3+
4+
Includes:
5+
1. ExperimentalDesign base class
6+
2. Pre-Post Fit
7+
3. Synthetic Control
8+
4. Difference in differences
9+
5. Regression Discontinuity
10+
"""
111
import warnings
212
from typing import Optional, Union
313

@@ -36,7 +46,7 @@ def idata(self):
3646
"""Access to the InferenceData object"""
3747
return self.model.idata
3848

39-
def print_coefficients(self):
49+
def print_coefficients(self) -> None:
4050
"""Prints the model coefficients"""
4151
print("Model coefficients:")
4252
coeffs = az.extract(self.idata.posterior, var_names="beta")
@@ -236,7 +246,7 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
236246

237247
return (fig, ax)
238248

239-
def summary(self):
249+
def summary(self) -> None:
240250
"""Print text output summarising the results"""
241251

242252
print(f"{self.expt_type:=^80}")
@@ -524,13 +534,14 @@ def _plot_causal_impact_arrow(self, ax):
524534
va="center",
525535
)
526536

527-
def _causal_impact_summary_stat(self):
537+
def _causal_impact_summary_stat(self) -> str:
538+
"""Computes the mean and 94% credible interval bounds for the causal impact."""
528539
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
529540
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
530541
causal_impact = f"{self.causal_impact.mean():.2f}, "
531542
return f"Causal impact = {causal_impact + ci}"
532543

533-
def summary(self):
544+
def summary(self) -> None:
534545
"""Print text output summarising the results"""
535546

536547
print(f"{self.expt_type:=^80}")
@@ -716,7 +727,7 @@ def plot(self):
716727
)
717728
return (fig, ax)
718729

719-
def summary(self):
730+
def summary(self) -> None:
720731
"""Print text output summarising the results"""
721732

722733
print(f"{self.expt_type:=^80}")
@@ -795,7 +806,7 @@ def __init__(
795806

796807
# ================================================================
797808

798-
def _input_validation(self):
809+
def _input_validation(self) -> None:
799810
"""Validate the input data and model formula for correctness"""
800811
if not _series_has_2_levels(self.data[self.group_variable_name]):
801812
raise DataException(
@@ -856,13 +867,14 @@ def plot(self):
856867
ax[1].set(title="Estimated treatment effect")
857868
return fig, ax
858869

859-
def _causal_impact_summary_stat(self):
870+
def _causal_impact_summary_stat(self) -> str:
871+
"""Computes the mean and 94% credible interval bounds for the causal impact."""
860872
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
861873
ci = r"$CI_{94\%}$" + f"[{percentiles[0]:.2f}, {percentiles[1]:.2f}]"
862874
causal_impact = f"{self.causal_impact.mean():.2f}, "
863875
return f"Causal impact = {causal_impact + ci}"
864876

865-
def summary(self):
877+
def summary(self) -> None:
866878
"""Print text output summarising the results"""
867879

868880
print(f"{self.expt_type:=^80}")

causalpy/pymc_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def build_model(self, X, y, coords) -> None:
4040
raise NotImplementedError("This method must be implemented by a subclass")
4141

4242
def _data_setter(self, X) -> None:
43+
"""Set data for the model."""
4344
with self.model:
4445
pm.set_data({"X": X})
4546

causalpy/skl_experiments.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
"""
2+
Experiments for Scikit-Learn models
3+
"""
14
import warnings
25
from typing import Optional
36

@@ -78,6 +81,7 @@ def __init__(
7881
self.post_impact_cumulative = np.cumsum(self.post_impact)
7982

8083
def plot(self, counterfactual_label="Counterfactual", **kwargs):
84+
"""Plot experiment results"""
8185
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
8286

8387
ax[0].plot(self.datapre.index, self.pre_y, "k.")
@@ -140,9 +144,11 @@ def plot(self, counterfactual_label="Counterfactual", **kwargs):
140144
return (fig, ax)
141145

142146
def get_coeffs(self):
147+
"""Returns model coefficients"""
143148
return np.squeeze(self.model.coef_)
144149

145150
def plot_coeffs(self):
151+
"""Plots coefficient bar plot"""
146152
df = pd.DataFrame(
147153
{"predictor variable": self.labels, "ols_coef": self.get_coeffs()}
148154
)
@@ -463,6 +469,7 @@ def _is_treated(self, x):
463469
return np.greater_equal(x, self.treatment_threshold)
464470

465471
def plot(self):
472+
"""Plot results"""
466473
fig, ax = plt.subplots()
467474
# Plot raw data
468475
sns.scatterplot(

causalpy/skl_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""
2+
Scikit-Learn Models
3+
4+
Includes:
5+
1. Weighted Proportion
6+
"""
17
from functools import partial
28

39
import numpy as np
@@ -18,9 +24,11 @@ class WeightedProportion(LinearModel, RegressorMixin):
1824
"""
1925

2026
def loss(self, W, X, y):
27+
"""Compute root mean squared loss with data X, weights W, and predictor y"""
2128
return np.sqrt(np.mean((y - np.dot(X, W.T)) ** 2))
2229

2330
def fit(self, X, y):
31+
"""Fit model on data X with predictor y"""
2432
w_start = [1 / X.shape[1]] * X.shape[1]
2533
coef_ = fmin_slsqp(
2634
partial(self.loss, X=X, y=y),
@@ -34,4 +42,5 @@ def fit(self, X, y):
3442
return self
3543

3644
def predict(self, X):
45+
"""Predict results for data X"""
3746
return np.dot(X, self.coef_.T)

causalpy/tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@
44

55
@pytest.fixture(scope="session")
66
def rng() -> np.random.Generator:
7+
"""Random number generator that can persist through a pytest session"""
78
seed: int = sum(map(ord, "causalpy"))
89
return np.random.default_rng(seed=seed)

causalpy/tests/test_data_loading.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
@pytest.mark.parametrize("dataset_name", tests)
2121
def test_data_loading(dataset_name):
22+
"""
23+
Checks that test data can be loaded into data frames and that there are no
24+
missing values in any column.
25+
"""
2226
df = cp.load_data(dataset_name)
2327
assert isinstance(df, pd.DataFrame)
2428
# Check that there are no missing values in any column

0 commit comments

Comments
 (0)