Skip to content

Commit 6dcb347

Browse files
committed
Merge branch 'reporting' of https://github.com/pymc-labs/CausalPy into reporting
2 parents e94496b + d54d28a commit 6dcb347

20 files changed

+697
-318
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,10 @@ repos:
4848
additional_dependencies:
4949
# Support pyproject.toml configuration
5050
- tomli
51+
- repo: https://github.com/pre-commit/mirrors-mypy
52+
rev: v1.18.2
53+
hooks:
54+
- id: mypy
55+
args: [--ignore-missing-imports]
56+
files: ^causalpy/
57+
additional_dependencies: [numpy>=1.20, pandas-stubs]

AGENTS.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,13 @@
3737
- **Formulas**: Use patsy for formula parsing (via `dmatrices()`)
3838
- **Custom exceptions**: Use project-specific exceptions from `causalpy.custom_exceptions`: `FormulaException`, `DataException`, `BadIndexException`
3939
- **File organization**: Experiments in `causalpy/experiments/`, PyMC models in `causalpy/pymc_models.py`, scikit-learn models in `causalpy/skl_models.py`
40+
41+
## Type Checking
42+
43+
- **Tool**: MyPy
44+
- **Configuration**: Integrated as a pre-commit hook.
45+
- **Scope**: Checks Python files within the `causalpy/` directory.
46+
- **Settings**:
47+
- `ignore-missing-imports`: Enabled to allow for gradual adoption of type hints without requiring all third-party libraries to have stubs.
48+
- `additional_dependencies`: Includes `numpy` and `pandas-stubs` to provide type information for these libraries.
49+
- **Execution**: Run automatically via `pre-commit run --all-files` or on commit.

causalpy/data/datasets.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,28 @@
4343
}
4444

4545

46-
def _get_data_home() -> pathlib.PosixPath:
46+
def _get_data_home() -> pathlib.Path:
4747
"""Return the path of the data directory"""
4848
return pathlib.Path(cp.__file__).parents[1] / "causalpy" / "data"
4949

5050

51-
def load_data(dataset: str = None) -> pd.DataFrame:
52-
"""Loads the requested dataset and returns a pandas DataFrame.
51+
def load_data(dataset: str | None = None) -> pd.DataFrame:
52+
"""Load the requested dataset and return a pandas DataFrame.
5353
54-
:param dataset: The desired dataset to load
54+
Parameters
55+
----------
56+
dataset : str, optional
57+
The desired dataset to load. If None, raises ValueError.
58+
59+
Returns
60+
-------
61+
pd.DataFrame
62+
The loaded dataset as a pandas DataFrame.
63+
64+
Raises
65+
------
66+
ValueError
67+
If the requested dataset is not found.
5568
"""
5669

5770
if dataset in DATASETS:

causalpy/data/simulate_data.py

Lines changed: 72 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@
2020
from scipy.stats import dirichlet, gamma, norm, uniform
2121
from statsmodels.nonparametric.smoothers_lowess import lowess
2222

23-
default_lowess_kwargs = {"frac": 0.2, "it": 0}
24-
RANDOM_SEED = 8927
25-
rng = np.random.default_rng(RANDOM_SEED)
23+
default_lowess_kwargs: dict[str, float | int] = {"frac": 0.2, "it": 0}
24+
RANDOM_SEED: int = 8927
25+
rng: np.random.Generator = np.random.default_rng(RANDOM_SEED)
2626

2727

2828
def _smoothed_gaussian_random_walk(
29-
gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs
30-
):
29+
gaussian_random_walk_mu: float,
30+
gaussian_random_walk_sigma: float,
31+
N: int,
32+
lowess_kwargs: dict,
33+
) -> tuple[np.ndarray, np.ndarray]:
3134
"""
32-
Generates Gaussian random walk data and applies LOWESS
35+
Generates Gaussian random walk data and applies LOWESS.
3336
3437
:param gaussian_random_walk_mu:
3538
Mean of the random walk
@@ -48,12 +51,12 @@ def _smoothed_gaussian_random_walk(
4851

4952

5053
def generate_synthetic_control_data(
51-
N=100,
52-
treatment_time=70,
53-
grw_mu=0.25,
54-
grw_sigma=1,
55-
lowess_kwargs=default_lowess_kwargs,
56-
):
54+
N: int = 100,
55+
treatment_time: int = 70,
56+
grw_mu: float = 0.25,
57+
grw_sigma: float = 1,
58+
lowess_kwargs: dict = default_lowess_kwargs,
59+
) -> tuple[pd.DataFrame, np.ndarray]:
5760
"""
5861
Generates data for synthetic control example.
5962
@@ -73,7 +76,6 @@ def generate_synthetic_control_data(
7376
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
7477
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
7578
"""
76-
7779
# 1. Generate non-treated variables
7880
df = pd.DataFrame(
7981
{
@@ -108,8 +110,12 @@ def generate_synthetic_control_data(
108110

109111

110112
def generate_time_series_data(
111-
N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3
112-
):
113+
N: int = 100,
114+
treatment_time: int = 70,
115+
beta_temp: float = -1,
116+
beta_linear: float = 0.5,
117+
beta_intercept: float = 3,
118+
) -> pd.DataFrame:
113119
"""
114120
Generates interrupted time series example data
115121
@@ -155,7 +161,9 @@ def generate_time_series_data(
155161
return df
156162

157163

158-
def generate_time_series_data_seasonal(treatment_time):
164+
def generate_time_series_data_seasonal(
165+
treatment_time: pd.Timestamp,
166+
) -> pd.DataFrame:
159167
"""
160168
Generates 10 years of monthly data with seasonality
161169
"""
@@ -169,11 +177,13 @@ def generate_time_series_data_seasonal(treatment_time):
169177
t=df.index,
170178
).set_index("date", drop=True)
171179
month_effect = np.array([11, 13, 12, 15, 19, 23, 21, 28, 20, 17, 15, 12])
172-
df["y"] = 0.2 * df["t"] + 2 * month_effect[df.month.values - 1]
180+
df["y"] = 0.2 * df["t"] + 2 * month_effect[np.asarray(df.month.values) - 1]
173181

174182
N = df.shape[0]
175183
idx = np.arange(N)[df.index > treatment_time]
176-
df["causal effect"] = 100 * gamma(10).pdf(np.arange(0, N, 1) - np.min(idx))
184+
df["causal effect"] = 100 * gamma(10).pdf(
185+
np.array(np.arange(0, N, 1)) - int(np.min(idx))
186+
)
177187

178188
df["y"] += df["causal effect"]
179189
df["y"] += norm(0, 2).rvs(N)
@@ -183,7 +193,9 @@ def generate_time_series_data_seasonal(treatment_time):
183193
return df
184194

185195

186-
def generate_time_series_data_simple(treatment_time, slope=0.0):
196+
def generate_time_series_data_simple(
197+
treatment_time: pd.Timestamp, slope: float = 0.0
198+
) -> pd.DataFrame:
187199
"""Generate simple interrupted time series data, with no seasonality or temporal
188200
structure.
189201
"""
@@ -205,7 +217,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
205217
return df
206218

207219

208-
def generate_did():
220+
def generate_did() -> pd.DataFrame:
209221
"""
210222
Generate Difference in Differences data
211223
@@ -223,8 +235,14 @@ def generate_did():
223235

224236
# local functions
225237
def outcome(
226-
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
227-
):
238+
t: np.ndarray,
239+
control_intercept: float,
240+
treat_intercept_delta: float,
241+
trend: float,
242+
Δ: float,
243+
group: np.ndarray,
244+
post_treatment: np.ndarray,
245+
) -> np.ndarray:
228246
"""Compute the outcome of each unit"""
229247
return (
230248
control_intercept
@@ -244,21 +262,21 @@ def outcome(
244262
df["post_treatment"] = df["t"] > intervention_time
245263

246264
df["y"] = outcome(
247-
df["t"],
265+
np.asarray(df["t"]),
248266
control_intercept,
249267
treat_intercept_delta,
250268
trend,
251269
Δ,
252-
df["group"],
253-
df["post_treatment"],
270+
np.asarray(df["group"]),
271+
np.asarray(df["post_treatment"]),
254272
)
255273
df["y"] += rng.normal(0, 0.1, df.shape[0])
256274
return df
257275

258276

259277
def generate_regression_discontinuity_data(
260-
N=100, true_causal_impact=0.5, true_treatment_threshold=0.0
261-
):
278+
N: int = 100, true_causal_impact: float = 0.5, true_treatment_threshold: float = 0.0
279+
) -> pd.DataFrame:
262280
"""
263281
Generate regression discontinuity example data
264282
@@ -272,12 +290,12 @@ def generate_regression_discontinuity_data(
272290
... ) # doctest: +SKIP
273291
"""
274292

275-
def is_treated(x):
293+
def is_treated(x: np.ndarray) -> np.ndarray:
276294
"""Check if x was treated"""
277295
return np.greater_equal(x, true_treatment_threshold)
278296

279-
def impact(x):
280-
"""Assign true_causal_impact to all treaated entries"""
297+
def impact(x: np.ndarray) -> np.ndarray:
298+
"""Assign true_causal_impact to all treated entries"""
281299
y = np.zeros(len(x))
282300
y[is_treated(x)] = true_causal_impact
283301
return y
@@ -289,8 +307,11 @@ def impact(x):
289307

290308

291309
def generate_ancova_data(
292-
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
293-
):
310+
N: int = 200,
311+
pre_treatment_means: np.ndarray = np.array([10, 12]),
312+
treatment_effect: int = 2,
313+
sigma: int = 1,
314+
) -> pd.DataFrame:
294315
"""
295316
Generate ANCOVA example data
296317
@@ -310,7 +331,7 @@ def generate_ancova_data(
310331
return df
311332

312333

313-
def generate_geolift_data():
334+
def generate_geolift_data() -> pd.DataFrame:
314335
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
315336
countries. The treated unit `Denmark` is a weighted combination of the untreated
316337
units. We additionally specify a treatment effect which takes effect after the
@@ -360,7 +381,7 @@ def generate_geolift_data():
360381
return df
361382

362383

363-
def generate_multicell_geolift_data():
384+
def generate_multicell_geolift_data() -> pd.DataFrame:
364385
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
365386
countries. The treated unit `Denmark` is a weighted combination of the untreated
366387
units. We additionally specify a treatment effect which takes effect after the
@@ -422,7 +443,9 @@ def generate_multicell_geolift_data():
422443
# -----------------
423444

424445

425-
def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
446+
def generate_seasonality(
447+
n: int = 12, amplitude: int = 1, length_scale: float = 0.5
448+
) -> np.ndarray:
426449
"""Generate monthly seasonality by sampling from a Gaussian process with a
427450
Gaussian kernel, using numpy code"""
428451
# Generate the covariance matrix
@@ -436,14 +459,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
436459
return seasonality
437460

438461

439-
def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1):
462+
def periodic_kernel(
463+
x1: np.ndarray,
464+
x2: np.ndarray,
465+
period: int = 1,
466+
length_scale: float = 1.0,
467+
amplitude: int = 1,
468+
) -> np.ndarray:
440469
"""Generate a periodic kernel for gaussian process"""
441470
return amplitude**2 * np.exp(
442471
-2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2
443472
)
444473

445474

446-
def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3):
475+
def create_series(
476+
n: int = 52,
477+
amplitude: int = 1,
478+
length_scale: int = 2,
479+
n_years: int = 4,
480+
intercept: int = 3,
481+
) -> np.ndarray:
447482
"""
448483
Returns numpy tile with generated seasonality data repeated over
449484
multiple years

0 commit comments

Comments
 (0)