Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,10 @@ repos:
additional_dependencies:
# Support pyproject.toml configuration
- tomli
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.18.2
hooks:
- id: mypy
args: [--ignore-missing-imports]
files: ^causalpy/
additional_dependencies: [numpy>=1.20, pandas-stubs]
10 changes: 10 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,13 @@
- **Formulas**: Use patsy for formula parsing (via `dmatrices()`)
- **Custom exceptions**: Use project-specific exceptions from `causalpy.custom_exceptions`: `FormulaException`, `DataException`, `BadIndexException`
- **File organization**: Experiments in `causalpy/experiments/`, PyMC models in `causalpy/pymc_models.py`, scikit-learn models in `causalpy/skl_models.py`

## Type Checking

- **Tool**: MyPy
- **Configuration**: Integrated as a pre-commit hook.
- **Scope**: Checks Python files within the `causalpy/` directory.
- **Settings**:
- `--ignore-missing-imports`: Enabled so that imports of third-party libraries without type stubs are not reported as errors, which allows type hints to be adopted gradually.
- `additional_dependencies`: Installs `numpy>=1.20` and `pandas-stubs` into the hook environment to provide type information for these libraries.
- **Execution**: Runs automatically on commit via the pre-commit hook; to check the whole repository manually, run `pre-commit run --all-files`.
21 changes: 17 additions & 4 deletions causalpy/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,28 @@
}


def _get_data_home() -> pathlib.PosixPath:
def _get_data_home() -> pathlib.Path:
"""Return the path of the data directory"""
return pathlib.Path(cp.__file__).parents[1] / "causalpy" / "data"


def load_data(dataset: str = None) -> pd.DataFrame:
"""Loads the requested dataset and returns a pandas DataFrame.
def load_data(dataset: str | None = None) -> pd.DataFrame:
"""Load the requested dataset and return a pandas DataFrame.

:param dataset: The desired dataset to load
Parameters
----------
dataset : str, optional
The desired dataset to load. If None, raises ValueError.

Returns
-------
pd.DataFrame
The loaded dataset as a pandas DataFrame.

Raises
------
ValueError
If the requested dataset is not found.
"""

if dataset in DATASETS:
Expand Down
43 changes: 21 additions & 22 deletions causalpy/data/simulate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
Functions that generate data sets used in examples
"""

from typing import Any

import numpy as np
import pandas as pd
from scipy.stats import dirichlet, gamma, norm, uniform
Expand All @@ -31,7 +29,7 @@ def _smoothed_gaussian_random_walk(
gaussian_random_walk_mu: float,
gaussian_random_walk_sigma: float,
N: int,
lowess_kwargs: dict[str, Any],
lowess_kwargs: dict,
) -> tuple[np.ndarray, np.ndarray]:
"""
Generates Gaussian random walk data and applies LOWESS.
Expand All @@ -57,7 +55,7 @@ def generate_synthetic_control_data(
treatment_time: int = 70,
grw_mu: float = 0.25,
grw_sigma: float = 1,
lowess_kwargs: dict[str, Any] | None = None,
lowess_kwargs: dict = default_lowess_kwargs,
) -> tuple[pd.DataFrame, np.ndarray]:
"""
Generates data for synthetic control example.
Expand All @@ -78,9 +76,6 @@ def generate_synthetic_control_data(
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
"""
if lowess_kwargs is None:
lowess_kwargs = default_lowess_kwargs

# 1. Generate non-treated variables
df = pd.DataFrame(
{
Expand Down Expand Up @@ -166,7 +161,9 @@ def generate_time_series_data(
return df


def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
def generate_time_series_data_seasonal(
treatment_time: pd.Timestamp,
) -> pd.DataFrame:
"""
Generates 10 years of monthly data with seasonality
"""
Expand All @@ -180,11 +177,13 @@ def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataF
t=df.index,
).set_index("date", drop=True)
month_effect = np.array([11, 13, 12, 15, 19, 23, 21, 28, 20, 17, 15, 12])
df["y"] = 0.2 * df["t"] + 2 * month_effect[df.month.values - 1]
df["y"] = 0.2 * df["t"] + 2 * month_effect[np.asarray(df.month.values) - 1]

N = df.shape[0]
idx = np.arange(N)[df.index > treatment_time]
df["causal effect"] = 100 * gamma(10).pdf(np.arange(0, N, 1) - np.min(idx))
df["causal effect"] = 100 * gamma(10).pdf(
np.array(np.arange(0, N, 1)) - int(np.min(idx))
)

df["y"] += df["causal effect"]
df["y"] += norm(0, 2).rvs(N)
Expand Down Expand Up @@ -263,13 +262,13 @@ def outcome(
df["post_treatment"] = df["t"] > intervention_time

df["y"] = outcome(
df["t"],
np.asarray(df["t"]),
control_intercept,
treat_intercept_delta,
trend,
Δ,
df["group"],
df["post_treatment"],
np.asarray(df["group"]),
np.asarray(df["post_treatment"]),
)
df["y"] += rng.normal(0, 0.1, df.shape[0])
return df
Expand Down Expand Up @@ -310,8 +309,8 @@ def impact(x: np.ndarray) -> np.ndarray:
def generate_ancova_data(
N: int = 200,
pre_treatment_means: np.ndarray = np.array([10, 12]),
treatment_effect: float = 2,
sigma: float = 1,
treatment_effect: int = 2,
sigma: int = 1,
) -> pd.DataFrame:
"""
Generate ANCOVA example data
Expand Down Expand Up @@ -445,7 +444,7 @@ def generate_multicell_geolift_data() -> pd.DataFrame:


def generate_seasonality(
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
n: int = 12, amplitude: int = 1, length_scale: float = 0.5
) -> np.ndarray:
"""Generate monthly seasonality by sampling from a Gaussian process with a
Gaussian kernel, using numpy code"""
Expand All @@ -463,9 +462,9 @@ def generate_seasonality(
def periodic_kernel(
x1: np.ndarray,
x2: np.ndarray,
period: float = 1,
length_scale: float = 1,
amplitude: float = 1,
period: int = 1,
length_scale: float = 1.0,
amplitude: int = 1,
) -> np.ndarray:
"""Generate a periodic kernel for gaussian process"""
return amplitude**2 * np.exp(
Expand All @@ -475,10 +474,10 @@ def periodic_kernel(

def create_series(
n: int = 52,
amplitude: float = 1,
length_scale: float = 2,
amplitude: int = 1,
length_scale: int = 2,
n_years: int = 4,
intercept: float = 3,
intercept: int = 3,
) -> np.ndarray:
"""
Returns numpy tile with generated seasonality data repeated over
Expand Down
33 changes: 23 additions & 10 deletions causalpy/experiments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""

from abc import abstractmethod
from typing import Any, Union

import arviz as az
import matplotlib.pyplot as plt
Expand All @@ -29,10 +30,12 @@
class BaseExperiment:
"""Base class for quasi experimental designs."""

labels: list[str]

supports_bayes: bool
supports_ols: bool

def __init__(self, model=None):
def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
# Ensure we've made any provided Scikit Learn model (as identified as being type
# RegressorMixin) compatible with CausalPy by appending our custom methods.
if isinstance(model, RegressorMixin):
Expand All @@ -50,16 +53,26 @@ def __init__(self, model=None):
if self.model is None:
raise ValueError("model not set or passed.")

def fit(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("fit method not implemented")

@property
def idata(self):
def idata(self) -> az.InferenceData:
"""Return the InferenceData object of the model. Only relevant for PyMC models."""
return self.model.idata

def print_coefficients(self, round_to=None):
"""Ask the model to print its coefficients."""
def print_coefficients(self, round_to: int | None = None) -> None:
"""Ask the model to print its coefficients.

Parameters
----------
round_to : int, optional
Number of significant figures to round to. Defaults to None,
in which case 2 significant figures are used.
"""
self.model.print_coefficients(self.labels, round_to)

def plot(self, *args, **kwargs) -> tuple:
def plot(self, *args: Any, **kwargs: Any) -> tuple:
"""Plot the model.

Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
Expand All @@ -75,16 +88,16 @@ def plot(self, *args, **kwargs) -> tuple:
raise ValueError("Unsupported model type")

@abstractmethod
def _bayesian_plot(self, *args, **kwargs):
def _bayesian_plot(self, *args: Any, **kwargs: Any) -> tuple:
"""Abstract method for plotting the model."""
raise NotImplementedError("_bayesian_plot method not yet implemented")

@abstractmethod
def _ols_plot(self, *args, **kwargs):
def _ols_plot(self, *args: Any, **kwargs: Any) -> tuple:
"""Abstract method for plotting the model."""
raise NotImplementedError("_ols_plot method not yet implemented")

def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
def get_plot_data(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
"""Recover the data of an experiment along with the prediction and causal impact information.

Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
Expand All @@ -98,11 +111,11 @@ def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
raise ValueError("Unsupported model type")

@abstractmethod
def get_plot_data_bayesian(self, *args, **kwargs):
def get_plot_data_bayesian(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
"""Abstract method for recovering plot data."""
raise NotImplementedError("get_plot_data_bayesian method not yet implemented")

@abstractmethod
def get_plot_data_ols(self, *args, **kwargs):
def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
"""Abstract method for recovering plot data."""
raise NotImplementedError("get_plot_data_ols method not yet implemented")
Loading