Skip to content

Commit 225a498

Browse files
juanitorduz and Copilot authored
Add type hints to all code base (#557)
* init * rm file * update badge * Update causalpy/experiments/regression_kink.py Co-authored-by: Copilot <[email protected]> * update docstrings * fix: apply ruff formatting after rebase --------- Co-authored-by: Copilot <[email protected]>
1 parent a9ddbfa commit 225a498

16 files changed

+533
-285
lines changed

causalpy/data/datasets.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,22 @@ def _get_data_home() -> pathlib.Path:
4949

5050

5151
def load_data(dataset: str | None = None) -> pd.DataFrame:
52-
"""Loads the requested dataset and returns a pandas DataFrame.
52+
"""Load the requested dataset and return a pandas DataFrame.
5353
54-
:param dataset: The desired dataset to load
54+
Parameters
55+
----------
56+
dataset : str, optional
57+
The desired dataset to load. If None, raises ValueError.
58+
59+
Returns
60+
-------
61+
pd.DataFrame
62+
The loaded dataset as a pandas DataFrame.
63+
64+
Raises
65+
------
66+
ValueError
67+
If the requested dataset is not found.
5568
"""
5669

5770
if dataset in DATASETS:

causalpy/data/simulate_data.py

Lines changed: 17 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,6 @@
1515
Functions that generate data sets used in examples
1616
"""
1717

18-
from typing import Any
19-
2018
import numpy as np
2119
import pandas as pd
2220
from scipy.stats import dirichlet, gamma, norm, uniform
@@ -31,7 +29,7 @@ def _smoothed_gaussian_random_walk(
3129
gaussian_random_walk_mu: float,
3230
gaussian_random_walk_sigma: float,
3331
N: int,
34-
lowess_kwargs: dict[str, Any],
32+
lowess_kwargs: dict,
3533
) -> tuple[np.ndarray, np.ndarray]:
3634
"""
3735
Generates Gaussian random walk data and applies LOWESS.
@@ -57,7 +55,7 @@ def generate_synthetic_control_data(
5755
treatment_time: int = 70,
5856
grw_mu: float = 0.25,
5957
grw_sigma: float = 1,
60-
lowess_kwargs: dict[str, Any] | None = None,
58+
lowess_kwargs: dict = default_lowess_kwargs,
6159
) -> tuple[pd.DataFrame, np.ndarray]:
6260
"""
6361
Generates data for synthetic control example.
@@ -78,9 +76,6 @@ def generate_synthetic_control_data(
7876
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
7977
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
8078
"""
81-
if lowess_kwargs is None:
82-
lowess_kwargs = default_lowess_kwargs
83-
8479
# 1. Generate non-treated variables
8580
df = pd.DataFrame(
8681
{
@@ -166,7 +161,9 @@ def generate_time_series_data(
166161
return df
167162

168163

169-
def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
164+
def generate_time_series_data_seasonal(
165+
treatment_time: pd.Timestamp,
166+
) -> pd.DataFrame:
170167
"""
171168
Generates 10 years of monthly data with seasonality
172169
"""
@@ -184,7 +181,9 @@ def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataF
184181

185182
N = df.shape[0]
186183
idx = np.arange(N)[df.index > treatment_time]
187-
df["causal effect"] = 100 * gamma(10).pdf(np.arange(0, N, 1) - np.min(idx))
184+
df["causal effect"] = 100 * gamma(10).pdf(
185+
np.array(np.arange(0, N, 1)) - int(np.min(idx))
186+
)
188187

189188
df["y"] += df["causal effect"]
190189
df["y"] += norm(0, 2).rvs(N)
@@ -310,8 +309,8 @@ def impact(x: np.ndarray) -> np.ndarray:
310309
def generate_ancova_data(
311310
N: int = 200,
312311
pre_treatment_means: np.ndarray = np.array([10, 12]),
313-
treatment_effect: float = 2,
314-
sigma: float = 1,
312+
treatment_effect: int = 2,
313+
sigma: int = 1,
315314
) -> pd.DataFrame:
316315
"""
317316
Generate ANCOVA example data
@@ -445,7 +444,7 @@ def generate_multicell_geolift_data() -> pd.DataFrame:
445444

446445

447446
def generate_seasonality(
448-
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
447+
n: int = 12, amplitude: int = 1, length_scale: float = 0.5
449448
) -> np.ndarray:
450449
"""Generate monthly seasonality by sampling from a Gaussian process with a
451450
Gaussian kernel, using numpy code"""
@@ -463,9 +462,9 @@ def generate_seasonality(
463462
def periodic_kernel(
464463
x1: np.ndarray,
465464
x2: np.ndarray,
466-
period: float = 1,
467-
length_scale: float = 1,
468-
amplitude: float = 1,
465+
period: int = 1,
466+
length_scale: float = 1.0,
467+
amplitude: int = 1,
469468
) -> np.ndarray:
470469
"""Generate a periodic kernel for gaussian process"""
471470
return amplitude**2 * np.exp(
@@ -475,10 +474,10 @@ def periodic_kernel(
475474

476475
def create_series(
477476
n: int = 52,
478-
amplitude: float = 1,
479-
length_scale: float = 2,
477+
amplitude: int = 1,
478+
length_scale: int = 2,
480479
n_years: int = 4,
481-
intercept: float = 3,
480+
intercept: int = 3,
482481
) -> np.ndarray:
483482
"""
484483
Returns numpy tile with generated seasonality data repeated over

causalpy/experiments/base.py

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
"""
1717

1818
from abc import abstractmethod
19+
from typing import Any, Union
1920

2021
import arviz as az
2122
import matplotlib.pyplot as plt
@@ -29,10 +30,12 @@
2930
class BaseExperiment:
3031
"""Base class for quasi experimental designs."""
3132

33+
labels: list[str]
34+
3235
supports_bayes: bool
3336
supports_ols: bool
3437

35-
def __init__(self, model=None):
38+
def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
3639
# Ensure we've made any provided Scikit Learn model (as identified as being type
3740
# RegressorMixin) compatible with CausalPy by appending our custom methods.
3841
if isinstance(model, RegressorMixin):
@@ -50,16 +53,26 @@ def __init__(self, model=None):
5053
if self.model is None:
5154
raise ValueError("model not set or passed.")
5255

56+
def fit(self, *args: Any, **kwargs: Any) -> None:
57+
raise NotImplementedError("fit method not implemented")
58+
5359
@property
54-
def idata(self):
60+
def idata(self) -> az.InferenceData:
5561
"""Return the InferenceData object of the model. Only relevant for PyMC models."""
5662
return self.model.idata
5763

58-
def print_coefficients(self, round_to=None):
59-
"""Ask the model to print its coefficients."""
64+
def print_coefficients(self, round_to: int | None = None) -> None:
65+
"""Ask the model to print its coefficients.
66+
67+
Parameters
68+
----------
69+
round_to : int, optional
70+
Number of significant figures to round to. Defaults to None,
71+
in which case 2 significant figures are used.
72+
"""
6073
self.model.print_coefficients(self.labels, round_to)
6174

62-
def plot(self, *args, **kwargs) -> tuple:
75+
def plot(self, *args: Any, **kwargs: Any) -> tuple:
6376
"""Plot the model.
6477
6578
Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
@@ -75,16 +88,16 @@ def plot(self, *args, **kwargs) -> tuple:
7588
raise ValueError("Unsupported model type")
7689

7790
@abstractmethod
78-
def _bayesian_plot(self, *args, **kwargs):
91+
def _bayesian_plot(self, *args: Any, **kwargs: Any) -> tuple:
7992
"""Abstract method for plotting the model."""
8093
raise NotImplementedError("_bayesian_plot method not yet implemented")
8194

8295
@abstractmethod
83-
def _ols_plot(self, *args, **kwargs):
96+
def _ols_plot(self, *args: Any, **kwargs: Any) -> tuple:
8497
"""Abstract method for plotting the model."""
8598
raise NotImplementedError("_ols_plot method not yet implemented")
8699

87-
def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
100+
def get_plot_data(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
88101
"""Recover the data of an experiment along with the prediction and causal impact information.
89102
90103
Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
@@ -98,11 +111,11 @@ def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
98111
raise ValueError("Unsupported model type")
99112

100113
@abstractmethod
101-
def get_plot_data_bayesian(self, *args, **kwargs):
114+
def get_plot_data_bayesian(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
102115
"""Abstract method for recovering plot data."""
103116
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")
104117

105118
@abstractmethod
106-
def get_plot_data_ols(self, *args, **kwargs):
119+
def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
107120
"""Abstract method for recovering plot data."""
108121
raise NotImplementedError("get_plot_data_ols method not yet implemented")

causalpy/experiments/diff_in_diff.py

Lines changed: 38 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515
Difference in differences
1616
"""
1717

18+
from typing import Union
19+
1820
import arviz as az
1921
import numpy as np
2022
import pandas as pd
@@ -47,20 +49,24 @@ class DifferenceInDifferences(BaseExperiment):
4749
4850
.. note::
4951
50-
There is no pre/post intervention data distinction for DiD, we fit all the
51-
data available.
52-
:param data:
53-
A pandas dataframe
54-
:param formula:
55-
A statistical model formula
56-
:param time_variable_name:
57-
Name of the data column for the time variable
58-
:param group_variable_name:
59-
Name of the data column for the group variable
60-
:param post_treatment_variable_name:
61-
Name of the data column indicating post-treatment period (default: "post_treatment")
62-
:param model:
63-
A PyMC model for difference in differences
52+
There is no pre/post intervention data distinction for DiD, we fit
53+
all the data available.
54+
55+
Parameters
56+
----------
57+
data : pd.DataFrame
58+
A pandas dataframe.
59+
formula : str
60+
A statistical model formula.
61+
time_variable_name : str
62+
Name of the data column for the time variable.
63+
group_variable_name : str
64+
Name of the data column for the group variable.
65+
post_treatment_variable_name : str, optional
66+
Name of the data column indicating post-treatment period.
67+
Defaults to "post_treatment".
68+
model : PyMCModel or RegressorMixin, optional
69+
A PyMC or scikit-learn model for difference in differences. Defaults to None.
6470
6571
Example
6672
--------
@@ -92,8 +98,8 @@ def __init__(
9298
time_variable_name: str,
9399
group_variable_name: str,
94100
post_treatment_variable_name: str = "post_treatment",
95-
model=None,
96-
**kwargs,
101+
model: Union[PyMCModel, RegressorMixin] | None = None,
102+
**kwargs: dict,
97103
) -> None:
98104
super().__init__(model=model)
99105
self.causal_impact: xr.DataArray | float | None
@@ -234,14 +240,14 @@ def __init__(
234240
f"{self.group_variable_name}:{self.post_treatment_variable_name}"
235241
)
236242
matched_key = next((k for k in coef_map if interaction_term in k), None)
237-
att = coef_map.get(matched_key)
243+
att = coef_map.get(matched_key) if matched_key is not None else None
238244
self.causal_impact = att
239245
else:
240246
raise ValueError("Model type not recognized")
241247

242248
return
243249

244-
def input_validation(self):
250+
def input_validation(self) -> None:
245251
# Validate formula structure and interaction terms
246252
self._validate_formula_interaction_terms()
247253

@@ -269,7 +275,7 @@ def input_validation(self):
269275
coded. Consisting of 0's and 1's only."""
270276
)
271277

272-
def _validate_formula_interaction_terms(self):
278+
def _validate_formula_interaction_terms(self) -> None:
273279
"""
274280
Validate that the formula contains at most one interaction term and no three-way or higher-order interactions.
275281
Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables.
@@ -299,7 +305,7 @@ def _validate_formula_interaction_terms(self):
299305
"Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect."
300306
)
301307

302-
def summary(self, round_to=None) -> None:
308+
def summary(self, round_to: int | None = 2) -> None:
303309
"""Print summary of main results and model coefficients.
304310
305311
:param round_to:
@@ -311,11 +317,13 @@ def summary(self, round_to=None) -> None:
311317
print(self._causal_impact_summary_stat(round_to))
312318
self.print_coefficients(round_to)
313319

314-
def _causal_impact_summary_stat(self, round_to=None) -> str:
320+
def _causal_impact_summary_stat(self, round_to: int | None = None) -> str:
315321
"""Computes the mean and 94% credible interval bounds for the causal impact."""
316322
return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}"
317323

318-
def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
324+
def _bayesian_plot(
325+
self, round_to: int | None = None, **kwargs: dict
326+
) -> tuple[plt.Figure, plt.Axes]:
319327
"""
320328
Plot the results
321329
@@ -463,9 +471,10 @@ def _plot_causal_impact_arrow(results, ax):
463471
)
464472
return fig, ax
465473

466-
def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
474+
def _ols_plot(
475+
self, round_to: int | None = 2, **kwargs: dict
476+
) -> tuple[plt.Figure, plt.Axes]:
467477
"""Generate plot for difference-in-differences"""
468-
round_to = kwargs.get("round_to")
469478
fig, ax = plt.subplots()
470479

471480
# Plot raw data
@@ -528,11 +537,15 @@ def _ols_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
528537
va="center",
529538
)
530539
# formatting
540+
# In OLS context, causal_impact should be a float, but mypy doesn't know this
541+
causal_impact_value = (
542+
float(self.causal_impact) if self.causal_impact is not None else 0.0
543+
)
531544
ax.set(
532545
xlim=[-0.05, 1.1],
533546
xticks=[0, 1],
534547
xticklabels=["pre", "post"],
535-
title=f"Causal impact = {round_num(self.causal_impact, round_to)}",
548+
title=f"Causal impact = {round_num(causal_impact_value, round_to)}",
536549
)
537550
ax.legend(fontsize=LEGEND_FONT_SIZE)
538551
return fig, ax

0 commit comments

Comments
 (0)