Merged
Changes from 13 commits
Commits
23 commits
82acb15
first stab at an AGENTS.md file
drbenvincent Nov 11, 2025
ee8ff1c
Add "code structure and style" section
drbenvincent Nov 11, 2025
f35c1fe
first attempt
drbenvincent Nov 11, 2025
dbe9a89
fix failing test
drbenvincent Nov 11, 2025
b75b6d5
update notebooks
drbenvincent Nov 11, 2025
4ceceb9
fix failing doctest
drbenvincent Nov 11, 2025
b22e443
add to all the other notebooks + fix failing tests
drbenvincent Nov 12, 2025
2638420
remove uninformative examples from docstring
drbenvincent Nov 12, 2025
d485665
fix notation for sckit-learn #555
drbenvincent Nov 12, 2025
f42be32
Add experiment reporting for more experiments/notebooks
drbenvincent Nov 12, 2025
d2d1813
refactor for better code quality
drbenvincent Nov 12, 2025
b832192
Add reporting statistics documentation and glossary updates
drbenvincent Nov 12, 2025
6323f39
increase code coverage
drbenvincent Nov 12, 2025
76eb637
Merge branch 'main' into reporting
juanitorduz Nov 12, 2025
d54d28a
fix pre-commit
juanitorduz Nov 12, 2025
e94496b
Clarify effect summary description in notebooks
drbenvincent Nov 12, 2025
6dcb347
Merge branch 'reporting' of https://github.com/pymc-labs/CausalPy int…
drbenvincent Nov 12, 2025
d4a657d
fix plot style error. Seaborn was taking over
drbenvincent Nov 12, 2025
54d31a1
fixes to banking notebook
drbenvincent Nov 12, 2025
764c460
fixes to drinking notebook
drbenvincent Nov 12, 2025
8f53047
fix glossary terms
drbenvincent Nov 12, 2025
dba0975
update AGENTS.md in terms of glossary links and cross references
drbenvincent Nov 12, 2025
ebee053
Add 'decrease' direction to ROPE probability and update docs
drbenvincent Nov 12, 2025
39 changes: 39 additions & 0 deletions AGENTS.md
@@ -0,0 +1,39 @@
# AGENTS

## Testing preferences

- Write all Python tests as `pytest` style functions, not unittest classes
- Use descriptive function names starting with `test_`
- Prefer fixtures over setup/teardown methods
- Use assert statements directly, not self.assertEqual
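
A minimal sketch of a test written to the preferences above. The helper `mean_effect` and the fixture `impact_values` are hypothetical names invented for this illustration; they are not part of CausalPy.

```python
import pytest


def mean_effect(values):
    """Hypothetical helper under test: the average of a sequence of effects."""
    return sum(values) / len(values)


@pytest.fixture
def impact_values():
    # A fixture takes the place of a unittest-style setUp method:
    # every test that requests it receives a fresh list.
    return [1.0, 2.0, 3.0]


def test_mean_effect_returns_average_of_values(impact_values):
    # Plain assert statement rather than self.assertEqual
    assert mean_effect(impact_values) == 2.0


def test_mean_effect_raises_on_empty_input():
    # Dividing by len([]) == 0 raises ZeroDivisionError
    with pytest.raises(ZeroDivisionError):
        mean_effect([])
```

Both tests are plain functions with descriptive `test_` names, so `python -m pytest` collects them without any class scaffolding.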

## Testing approach

- Never create throwaway test scripts or ad hoc verification files
- If you need to test functionality, write a proper test in the test suite
- All tests go in the `causalpy/tests/` directory following the project structure
- Tests should be runnable with the rest of the suite (`python -m pytest`)
- Even for quick verification, write it as a real test that provides ongoing value
- Preference should be given to integration tests, but unit tests are acceptable for core functionality to maintain high code coverage.
- Tests should remain quick to run. Tests involving MCMC sampling with PyMC should use custom `sample_kwargs` to minimize the computational load.
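
To make the last point concrete, here is a rough sketch of what a reduced `sample_kwargs` dictionary might look like and how much work it saves. The keys are standard `pm.sample` arguments; the specific values, the constant names, and the helper `sampler_iterations` are assumptions made for this illustration.

```python
# Reduced sampler settings for tests (values are illustrative assumptions).
FAST_SAMPLE_KWARGS = {
    "tune": 20,
    "draws": 20,
    "chains": 2,
    "progressbar": False,
    "random_seed": 42,
}

# A heavier configuration, shown only for comparison.
HEAVY_SAMPLE_KWARGS = {"tune": 1000, "draws": 1000, "chains": 4}


def sampler_iterations(sample_kwargs):
    """Total iterations (tuning plus kept draws) summed over all chains."""
    per_chain = sample_kwargs["tune"] + sample_kwargs["draws"]
    return per_chain * sample_kwargs["chains"]


print(sampler_iterations(FAST_SAMPLE_KWARGS))   # 80
print(sampler_iterations(HEAVY_SAMPLE_KWARGS))  # 8000
```

In a real test the dictionary would presumably be passed to the model constructor through its `sample_kwargs` argument. Such a short run only checks that the code path works, not that the sampler has converged.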

## Documentation

- **Structure**: Notebooks (how-to examples) go in `docs/source/notebooks/`, knowledgebase (educational content) goes in `docs/source/knowledgebase/`
- **Notebook naming**: Use pattern `{method}_{model}.ipynb` (e.g., `did_pymc.ipynb`, `rd_skl.ipynb`), organized by causal method
- **MyST directives**: Use `:::{note}` and other MyST features for callouts and formatting
- **Glossary linking**: Use the Sphinx `:term:` role to link to glossary terms (defined in `glossary.rst`), typically on first mention in a file
- **Citations**: Use `references.bib` for citations, cite sources in example notebooks where possible. Include reference section at bottom of notebooks using `:::{bibliography}` directive with `:filter: docname in docnames`
- **API documentation**: Auto-generated from docstrings via Sphinx autodoc, no manual API docs needed
- **Build**: Use `make html` to build documentation
- **Doctest**: Use `make doctest` to test that Python examples in doctests work

## Code structure and style

- **Experiment classes**: All experiment classes inherit from `BaseExperiment` in `causalpy/experiments/`. Must declare `supports_ols` and `supports_bayes` class attributes. Only implement abstract methods for supported model types (e.g., if only Bayesian is supported, implement `_bayesian_plot()` and `get_plot_data_bayesian()`; if only OLS is supported, implement `_ols_plot()` and `get_plot_data_ols()`)
- **Model-agnostic design**: Experiment classes should work with both PyMC and scikit-learn models. Use `isinstance(self.model, PyMCModel)` vs `isinstance(self.model, RegressorMixin)` to dispatch to appropriate implementations
- **Model classes**: PyMC models inherit from `PyMCModel` (extends `pm.Model`). Scikit-learn models use `RegressorMixin` and are made compatible via `create_causalpy_compatible_class()`. Common interface: `fit()`, `predict()`, `score()`, `calculate_impact()`, `print_coefficients()`
- **Data handling**: PyMC models use `xarray.DataArray` with coords (keys like "coeffs", "obs_ind", "treated_units"). Scikit-learn models use numpy arrays. Data index should be named "obs_ind"
- **Formulas**: Use patsy for formula parsing (via `dmatrices()`)
- **Custom exceptions**: Use project-specific exceptions from `causalpy.custom_exceptions`: `FormulaException`, `DataException`, `BadIndexException`
- **File organization**: Experiments in `causalpy/experiments/`, PyMC models in `causalpy/pymc_models.py`, scikit-learn models in `causalpy/skl_models.py`
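
The experiment-class and model-agnostic rules above can be sketched roughly as follows. `PyMCModel` and `RegressorMixin` are empty stand-ins defined locally so the snippet runs without PyMC or scikit-learn, and `BayesOnlyExperiment` and the `plot()` dispatcher are hypothetical simplifications, not the actual CausalPy implementation.

```python
class PyMCModel:
    """Stand-in for causalpy.pymc_models.PyMCModel (illustration only)."""


class RegressorMixin:
    """Stand-in for sklearn.base.RegressorMixin (illustration only)."""


class BaseExperiment:
    """Simplified base class: subclasses declare which model types they support."""

    supports_ols: bool
    supports_bayes: bool

    def __init__(self, model=None):
        # Reject model types the experiment has not declared support for.
        if isinstance(model, PyMCModel) and not self.supports_bayes:
            raise ValueError("Bayesian models are not supported by this experiment.")
        if isinstance(model, RegressorMixin) and not self.supports_ols:
            raise ValueError("OLS models are not supported by this experiment.")
        self.model = model

    def plot(self):
        # Dispatch on the model type, as described in "Model-agnostic design".
        if isinstance(self.model, PyMCModel):
            return self._bayesian_plot()
        if isinstance(self.model, RegressorMixin):
            return self._ols_plot()
        raise ValueError("Unsupported model type.")

    def _bayesian_plot(self):
        raise NotImplementedError("_bayesian_plot method not yet implemented")

    def _ols_plot(self):
        raise NotImplementedError("_ols_plot method not yet implemented")


class BayesOnlyExperiment(BaseExperiment):
    """Hypothetical experiment that only supports Bayesian models."""

    supports_ols = False
    supports_bayes = True

    def _bayesian_plot(self):
        return "bayesian plot"


print(BayesOnlyExperiment(PyMCModel()).plot())  # bayesian plot
```

Because `supports_ols` is `False`, the subclass never needs to implement `_ols_plot()`; passing a scikit-learn style model fails early in `__init__` instead.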
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
@@ -17,6 +17,9 @@ We appreciate being notified of problems with the existing CausalPy code. We pre

Please verify that your issue is not being currently addressed by other issues or pull requests by using the GitHub search tool to look for key words in the project issue tracker.

## Use of agents
PRs with agent-generated code are fine, but don't spam us with code you don't understand. See [AGENTS.md](./AGENTS.md) for how we use LLMs in this repo.

## Contributing code via pull requests

While issue reporting is valuable, we strongly encourage users who are inclined to do so to submit patches for new or existing issues via pull requests. This is particularly the case for simple fixes, such as typos or tweaks to documentation, which do not require a heavy investment of time and attention.
180 changes: 179 additions & 1 deletion causalpy/experiments/base.py
@@ -16,13 +16,33 @@
"""

from abc import abstractmethod
from typing import Literal, Optional, Union

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.base import RegressorMixin

from causalpy.pymc_models import PyMCModel
from causalpy.reporting import (
    EffectSummary,
    _compute_statistics,
    _compute_statistics_did_ols,
    _compute_statistics_ols,
    _detect_experiment_type,
    _effect_summary_did,
    _effect_summary_rd,
    _effect_summary_rkink,
    _extract_counterfactual,
    _extract_window,
    _generate_prose,
    _generate_prose_did_ols,
    _generate_prose_ols,
    _generate_table,
    _generate_table_did_ols,
    _generate_table_ols,
)
from causalpy.skl_models import create_causalpy_compatible_class


@@ -33,7 +53,7 @@ class BaseExperiment:
    supports_ols: bool

    def __init__(self, model=None):
-        # Ensure we've made any provided Scikit Learn model (as identified as being type
+        # Ensure we've made any provided scikit-learn model (as identified as being type
        # RegressorMixin) compatible with CausalPy by appending our custom methods.
        if isinstance(model, RegressorMixin):
            model = create_causalpy_compatible_class(model)
@@ -106,3 +126,161 @@ def get_plot_data_bayesian(self, *args, **kwargs):
    def get_plot_data_ols(self, *args, **kwargs):
        """Abstract method for recovering plot data."""
        raise NotImplementedError("get_plot_data_ols method not yet implemented")

    def effect_summary(
        self,
        window: Union[Literal["post"], tuple, slice] = "post",
        direction: Literal["increase", "decrease", "two-sided"] = "increase",
        alpha: float = 0.05,
        cumulative: bool = True,
        relative: bool = True,
        min_effect: Optional[float] = None,
        treated_unit: Optional[str] = None,
    ) -> EffectSummary:
        """
        Generate a decision-ready summary of causal effects.

        Supports Interrupted Time Series (ITS), Synthetic Control (SC),
        Difference-in-Differences (DiD), Regression Discontinuity (RD), and
        Regression Kink experiments. PyMC (Bayesian) and OLS models are both
        handled where the experiment supports them. The experiment type and
        model type are detected automatically and the appropriate summary is
        generated.

        Parameters
        ----------
        window : str, tuple, or slice, default="post"
            Time window for analysis (ITS/SC only, ignored for DiD/RD):
            - "post": All post-treatment time points (default)
            - (start, end): Tuple of start and end times (handles both datetime and integer indices)
            - slice: Python slice object for integer indices
        direction : {"increase", "decrease", "two-sided"}, default="increase"
            Direction for tail probability calculation (PyMC only, ignored for OLS):
            - "increase": P(effect > 0)
            - "decrease": P(effect < 0)
            - "two-sided": Two-sided p-value, report 1-p as "probability of effect"
        alpha : float, default=0.05
            Significance level for the HDI/CI; intervals are reported at the
            1 - alpha level
        cumulative : bool, default=True
            Whether to include cumulative effect statistics (ITS/SC only, ignored for DiD/RD)
        relative : bool, default=True
            Whether to include relative effect statistics (% change vs counterfactual)
            (ITS/SC only, ignored for DiD/RD)
        min_effect : float, optional
            Region of Practical Equivalence (ROPE) threshold (PyMC only, ignored for OLS).
            If provided, reports P(|effect| > min_effect) for two-sided or P(effect > min_effect) for one-sided.
        treated_unit : str, optional
            For multi-unit experiments (Synthetic Control), specify which treated unit
            to analyze. If None and multiple units exist, uses the first unit.

        Returns
        -------
        EffectSummary
            Object with .table (DataFrame) and .text (str) attributes
        """
        # Detect experiment type
        experiment_type = _detect_experiment_type(self)

        # Check if PyMC or OLS model
        is_pymc = isinstance(self.model, PyMCModel)

        if experiment_type == "rd":
            # Regression Discontinuity: scalar effect, no time dimension
            return _effect_summary_rd(
                self,
                direction=direction,
                alpha=alpha,
                min_effect=min_effect,
            )
        elif experiment_type == "rkink":
            # Regression Kink: scalar effect (gradient change at kink point)
            return _effect_summary_rkink(
                self,
                direction=direction,
                alpha=alpha,
                min_effect=min_effect,
            )
        elif experiment_type == "did":
            # Difference-in-Differences: scalar effect, no time dimension
            if is_pymc:
                return _effect_summary_did(
                    self,
                    direction=direction,
                    alpha=alpha,
                    min_effect=min_effect,
                )
            else:
                # OLS DiD
                stats = _compute_statistics_did_ols(self, alpha=alpha)
                table = _generate_table_did_ols(stats)
                text = _generate_prose_did_ols(stats, alpha=alpha)
                return EffectSummary(table=table, text=text)
        else:
            # ITS or Synthetic Control: time-series effects
            # Extract windowed impact data
            windowed_impact, window_coords = _extract_window(
                self, window, treated_unit=treated_unit
            )

            # Extract counterfactual for relative effects
            counterfactual = _extract_counterfactual(
                self, window_coords, treated_unit=treated_unit
            )

            if is_pymc:
                # PyMC model: use posterior draws
                hdi_prob = 1 - alpha
                stats = _compute_statistics(
                    windowed_impact,
                    counterfactual,
                    hdi_prob=hdi_prob,
                    direction=direction,
                    cumulative=cumulative,
                    relative=relative,
                    min_effect=min_effect,
                )

                # Generate table
                table = _generate_table(stats, cumulative=cumulative, relative=relative)

                # Generate prose
                text = _generate_prose(
                    stats,
                    window_coords,
                    alpha=alpha,
                    direction=direction,
                    cumulative=cumulative,
                    relative=relative,
                )
            else:
                # OLS model: use point estimates and CIs
                # Convert to numpy arrays if needed
                if hasattr(windowed_impact, "values"):
                    impact_array = windowed_impact.values
                else:
                    impact_array = np.asarray(windowed_impact)
                if hasattr(counterfactual, "values"):
                    counterfactual_array = counterfactual.values
                else:
                    counterfactual_array = np.asarray(counterfactual)

                stats = _compute_statistics_ols(
                    impact_array,
                    counterfactual_array,
                    alpha=alpha,
                    cumulative=cumulative,
                    relative=relative,
                )

                # Generate table
                table = _generate_table_ols(
                    stats, cumulative=cumulative, relative=relative
                )

                # Generate prose
                text = _generate_prose_ols(
                    stats,
                    window_coords,
                    alpha=alpha,
                    cumulative=cumulative,
                    relative=relative,
                )

            return EffectSummary(table=table, text=text)
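
The `direction` and `min_effect` semantics documented in the docstring above can be illustrated on a plain list of posterior draws. This is only a sketch under stated assumptions: `effect_probability` is a hypothetical name, the two-sided p-value is taken as twice the smaller tail probability, and the `"decrease"` case with a ROPE threshold is assumed to be P(effect < -min_effect). The private helpers in `causalpy.reporting` may compute these quantities differently.

```python
def effect_probability(draws, direction="increase", min_effect=None):
    """Probability of an effect, estimated from posterior draws (hypothetical helper)."""
    draws = list(draws)
    if not draws:
        raise ValueError("draws must not be empty")
    n = len(draws)

    # Without a ROPE threshold the comparison point is zero.
    threshold = 0.0 if min_effect is None else float(min_effect)
    p_increase = sum(d > threshold for d in draws) / n
    p_decrease = sum(d < -threshold for d in draws) / n  # assumed mirror of "increase"

    if direction == "increase":
        return p_increase
    if direction == "decrease":
        return p_decrease
    if direction == "two-sided":
        if min_effect is not None:
            # P(|effect| > min_effect)
            return sum(abs(d) > threshold for d in draws) / n
        # Assumed two-sided p-value: twice the smaller tail; report 1 - p.
        p_value = 2 * min(p_increase, p_decrease)
        return 1 - p_value
    raise ValueError(f"Unknown direction: {direction!r}")


posterior_draws = [-1.0, 0.5, 1.0, 2.0, 3.0]
print(effect_probability(posterior_draws, direction="increase"))  # 0.8
print(effect_probability(posterior_draws, direction="decrease"))  # 0.2
```

With a threshold such as `min_effect=1.5`, only draws beyond the ROPE count, so the reported probability can be much lower than the plain tail probability.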