Skip to content
Merged
98 changes: 67 additions & 31 deletions causalpy/data/simulate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,26 @@
Functions that generate data sets used in examples
"""

from typing import Any

import numpy as np
import pandas as pd
from scipy.stats import dirichlet, gamma, norm, uniform
from statsmodels.nonparametric.smoothers_lowess import lowess

default_lowess_kwargs = {"frac": 0.2, "it": 0}
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
# Default LOWESS keyword arguments; generate_synthetic_control_data falls back
# to these when the caller passes ``lowess_kwargs=None``.
default_lowess_kwargs: dict[str, float] = {"frac": 0.2, "it": 0}
# Seed for the module-level random number generator below.
RANDOM_SEED: int = 8927
# Module-level NumPy generator, seeded with RANDOM_SEED at import time.
rng: np.random.Generator = np.random.default_rng(RANDOM_SEED)


def _smoothed_gaussian_random_walk(
gaussian_random_walk_mu, gaussian_random_walk_sigma, N, lowess_kwargs
):
gaussian_random_walk_mu: float,
gaussian_random_walk_sigma: float,
N: int,
lowess_kwargs: dict[str, Any],
) -> tuple[np.ndarray, np.ndarray]:
"""
Generates Gaussian random walk data and applies LOWESS
Generates Gaussian random walk data and applies LOWESS.

:param gaussian_random_walk_mu:
Mean of the random walk
Expand All @@ -48,12 +53,12 @@ def _smoothed_gaussian_random_walk(


def generate_synthetic_control_data(
N=100,
treatment_time=70,
grw_mu=0.25,
grw_sigma=1,
lowess_kwargs=default_lowess_kwargs,
):
N: int = 100,
treatment_time: int = 70,
grw_mu: float = 0.25,
grw_sigma: float = 1,
lowess_kwargs: dict[str, Any] | None = None,
) -> tuple[pd.DataFrame, np.ndarray]:
"""
Generates data for synthetic control example.

Expand All @@ -73,6 +78,8 @@ def generate_synthetic_control_data(
>>> from causalpy.data.simulate_data import generate_synthetic_control_data
>>> df, weightings_true = generate_synthetic_control_data(treatment_time=70)
"""
if lowess_kwargs is None:
lowess_kwargs = default_lowess_kwargs

# 1. Generate non-treated variables
df = pd.DataFrame(
Expand Down Expand Up @@ -108,8 +115,12 @@ def generate_synthetic_control_data(


def generate_time_series_data(
N=100, treatment_time=70, beta_temp=-1, beta_linear=0.5, beta_intercept=3
):
N: int = 100,
treatment_time: int = 70,
beta_temp: float = -1,
beta_linear: float = 0.5,
beta_intercept: float = 3,
) -> pd.DataFrame:
"""
Generates interrupted time series example data

Expand Down Expand Up @@ -155,7 +166,7 @@ def generate_time_series_data(
return df


def generate_time_series_data_seasonal(treatment_time):
def generate_time_series_data_seasonal(treatment_time: pd.Timestamp) -> pd.DataFrame:
"""
Generates 10 years of monthly data with seasonality
"""
Expand Down Expand Up @@ -183,7 +194,9 @@ def generate_time_series_data_seasonal(treatment_time):
return df


def generate_time_series_data_simple(treatment_time, slope=0.0):
def generate_time_series_data_simple(
treatment_time: pd.Timestamp, slope: float = 0.0
) -> pd.DataFrame:
"""Generate simple interrupted time series data, with no seasonality or temporal
structure.
"""
Expand All @@ -205,7 +218,7 @@ def generate_time_series_data_simple(treatment_time, slope=0.0):
return df


def generate_did():
def generate_did() -> pd.DataFrame:
"""
Generate Difference in Differences data

Expand All @@ -223,8 +236,14 @@ def generate_did():

# local functions
def outcome(
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
):
t: np.ndarray,
control_intercept: float,
treat_intercept_delta: float,
trend: float,
Δ: float,
group: np.ndarray,
post_treatment: np.ndarray,
) -> np.ndarray:
"""Compute the outcome of each unit"""
return (
control_intercept
Expand Down Expand Up @@ -257,8 +276,8 @@ def outcome(


def generate_regression_discontinuity_data(
N=100, true_causal_impact=0.5, true_treatment_threshold=0.0
):
N: int = 100, true_causal_impact: float = 0.5, true_treatment_threshold: float = 0.0
) -> pd.DataFrame:
"""
Generate regression discontinuity example data

Expand All @@ -272,12 +291,12 @@ def generate_regression_discontinuity_data(
... ) # doctest: +SKIP
"""

def is_treated(x):
def is_treated(x: np.ndarray) -> np.ndarray:
    """Check if x was treated.

    :param x:
        Values of the running variable
    :returns:
        Element-wise result of ``x >= true_treatment_threshold``, where the
        threshold is taken from the enclosing function's arguments
    """
    return np.greater_equal(x, true_treatment_threshold)

def impact(x):
"""Assign true_causal_impact to all treaated entries"""
def impact(x: np.ndarray) -> np.ndarray:
    """Assign true_causal_impact to all treated entries.

    :param x:
        Values of the running variable
    :returns:
        Array of length ``len(x)`` that is ``true_causal_impact`` wherever
        ``is_treated(x)`` is true and zero elsewhere
    """
    # Start from "no effect" everywhere, then overwrite the treated entries.
    y = np.zeros(len(x))
    y[is_treated(x)] = true_causal_impact
    return y
Expand All @@ -289,8 +308,11 @@ def impact(x):


def generate_ancova_data(
N=200, pre_treatment_means=np.array([10, 12]), treatment_effect=2, sigma=1
):
N: int = 200,
pre_treatment_means: np.ndarray = np.array([10, 12]),
treatment_effect: float = 2,
sigma: float = 1,
) -> pd.DataFrame:
"""
Generate ANCOVA example data

Expand All @@ -310,7 +332,7 @@ def generate_ancova_data(
return df


def generate_geolift_data():
def generate_geolift_data() -> pd.DataFrame:
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
countries. The treated unit `Denmark` is a weighted combination of the untreated
units. We additionally specify a treatment effect which takes effect after the
Expand Down Expand Up @@ -360,7 +382,7 @@ def generate_geolift_data():
return df


def generate_multicell_geolift_data():
def generate_multicell_geolift_data() -> pd.DataFrame:
"""Generate synthetic data for a geolift example. This will consist of 6 untreated
countries. The treated unit `Denmark` is a weighted combination of the untreated
units. We additionally specify a treatment effect which takes effect after the
Expand Down Expand Up @@ -422,7 +444,9 @@ def generate_multicell_geolift_data():
# -----------------


def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
def generate_seasonality(
n: int = 12, amplitude: float = 1, length_scale: float = 0.5
) -> np.ndarray:
"""Generate monthly seasonality by sampling from a Gaussian process with a
Gaussian kernel, using numpy code"""
# Generate the covariance matrix
Expand All @@ -436,14 +460,26 @@ def generate_seasonality(n=12, amplitude=1, length_scale=0.5):
return seasonality


def periodic_kernel(x1, x2, period=1, length_scale=1, amplitude=1):
def periodic_kernel(
    x1: np.ndarray,
    x2: np.ndarray,
    period: float = 1,
    length_scale: float = 1,
    amplitude: float = 1,
) -> np.ndarray:
    """Generate a periodic kernel for gaussian process.

    Evaluates, element-wise,
    ``amplitude**2 * exp(-2 * sin(pi * |x1 - x2| / period)**2 / length_scale**2)``.

    :param x1:
        First set of input locations
    :param x2:
        Second set of input locations; combined with ``x1`` via ``x1 - x2``
    :param period:
        Divisor applied to ``pi * |x1 - x2|`` inside the sine
    :param length_scale:
        Divisor (squared) applied to the squared sine term
    :param amplitude:
        Scale of the kernel; the result is multiplied by ``amplitude**2``
    :returns:
        Kernel values with the shape of ``x1 - x2``
    """
    return amplitude**2 * np.exp(
        -2 * np.sin(np.pi * np.abs(x1 - x2) / period) ** 2 / length_scale**2
    )


def create_series(n=52, amplitude=1, length_scale=2, n_years=4, intercept=3):
def create_series(
n: int = 52,
amplitude: float = 1,
length_scale: float = 2,
n_years: int = 4,
intercept: float = 3,
) -> np.ndarray:
"""
Returns numpy tile with generated seasonality data repeated over
multiple years
Expand Down
40 changes: 40 additions & 0 deletions causalpy/tests/test_synthetic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,43 @@ def test_generate_geolift_data():
df = generate_geolift_data()
assert isinstance(df, pd.DataFrame)
assert np.all(df >= 0), "Found negative values in dataset"


def test_generate_regression_discontinuity_data():
    """
    Test the generate_regression_discontinuity_data function.

    Checks the returned type, the expected columns, the default row count and
    the dtype of the ``treated`` column, then repeats the size check with
    custom parameters.
    """
    from causalpy.data.simulate_data import generate_regression_discontinuity_data

    df = generate_regression_discontinuity_data()
    assert isinstance(df, pd.DataFrame)
    assert "x" in df.columns
    assert "y" in df.columns
    assert "treated" in df.columns
    assert len(df) == 100  # default N value
    # A NumPy boolean dtype compares equal to the builtin ``bool``, so a single
    # comparison covers both ``bool`` and ``np.bool_``.
    assert df["treated"].dtype == bool

    # Test with custom parameters
    df_custom = generate_regression_discontinuity_data(
        N=50, true_causal_impact=1.0, true_treatment_threshold=0.5
    )
    assert len(df_custom) == 50


def test_generate_synthetic_control_data():
    """
    Test the generate_synthetic_control_data function.

    Exercises both the ``lowess_kwargs=None`` default path and an explicit
    ``lowess_kwargs`` dictionary, checking return types and row counts.
    """
    from causalpy.data.simulate_data import generate_synthetic_control_data

    # Test with default parameters (lowess_kwargs=None)
    df, weightings = generate_synthetic_control_data()
    assert isinstance(df, pd.DataFrame)
    assert isinstance(weightings, np.ndarray)
    assert len(df) == 100  # default N value

    # Test with explicit lowess_kwargs
    df_custom, weightings_custom = generate_synthetic_control_data(
        N=50, lowess_kwargs={"frac": 0.3, "it": 5}
    )
    assert len(df_custom) == 50
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading