Skip to content

Commit 7dea992

Browse files
Add reporting layer (#554)
* first stab at an AGENTS.md file
* Add "code structure and style" section
* first attempt
* fix failing test
* update notebooks
* fix failing doctest
* add to all the other notebooks + fix failing tests
* remove uninformative examples from docstring
* fix notation for scikit-learn #555
* Add experiment reporting for more experiments/notebooks
* refactor for better code quality
* Add reporting statistics documentation and glossary updates
* increase code coverage
* fix pre-commit
* Clarify effect summary description in notebooks
* fix plot style error. Seaborn was taking over
* fixes to banking notebook
* fixes to drinking notebook
* fix glossary terms
* update AGENTS.md in terms of glossary links and cross references
* Add 'decrease' direction to ROPE probability and update docs

Extended the `_compute_rope_probability` function to support a 'decrease' direction, returning the probability that the effect is less than -min_effect. Added corresponding unit test. Updated documentation to clarify ROPE calculation for all directions and expanded reporting statistics and usage examples.

---------

Co-authored-by: Juan Orduz <[email protected]>
1 parent 42bfcda commit 7dea992

24 files changed

+7338
-1215
lines changed

AGENTS.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
- **Structure**: Notebooks (how-to examples) go in `docs/source/notebooks/`, knowledgebase (educational content) goes in `docs/source/knowledgebase/`
2323
- **Notebook naming**: Use pattern `{method}_{model}.ipynb` (e.g., `did_pymc.ipynb`, `rd_skl.ipynb`), organized by causal method
2424
- **MyST directives**: Use `:::{note}` and other MyST features for callouts and formatting
25-
- **Glossary linking**: Use Sphinx `:term:` directives to link to glossary terms (defined in `glossary.rst`), typically on first mention in a file
25+
- **Glossary linking**: Link to glossary terms (defined in `glossary.rst`) on first mention in a file:
26+
- In Markdown files (`.md`, `.ipynb`): Use MyST syntax `` {term}`glossary term` ``
27+
- In RST files (`.rst`): Use Sphinx syntax `` :term:`glossary term` ``
28+
- **Cross-references**: For other cross-references in Markdown files, use MyST role syntax with curly braces (e.g., `` {doc}`path/to/doc` ``, `` {ref}`label-name` ``)
2629
- **Citations**: Use `references.bib` for citations, cite sources in example notebooks where possible. Include reference section at bottom of notebooks using `:::{bibliography}` directive with `:filter: docname in docnames`
2730
- **API documentation**: Auto-generated from docstrings via Sphinx autodoc, no manual API docs needed
2831
- **Build**: Use `make html` to build documentation

causalpy/experiments/base.py

Lines changed: 178 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,33 @@
1616
"""
1717

1818
from abc import abstractmethod
19-
from typing import Any, Union
19+
from typing import Any, Literal, Union
2020

2121
import arviz as az
2222
import matplotlib.pyplot as plt
23+
import numpy as np
2324
import pandas as pd
2425
from sklearn.base import RegressorMixin
2526

2627
from causalpy.pymc_models import PyMCModel
28+
from causalpy.reporting import (
29+
EffectSummary,
30+
_compute_statistics,
31+
_compute_statistics_did_ols,
32+
_compute_statistics_ols,
33+
_detect_experiment_type,
34+
_effect_summary_did,
35+
_effect_summary_rd,
36+
_effect_summary_rkink,
37+
_extract_counterfactual,
38+
_extract_window,
39+
_generate_prose,
40+
_generate_prose_did_ols,
41+
_generate_prose_ols,
42+
_generate_table,
43+
_generate_table_did_ols,
44+
_generate_table_ols,
45+
)
2746
from causalpy.skl_models import create_causalpy_compatible_class
2847

2948

@@ -119,3 +138,161 @@ def get_plot_data_bayesian(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
119138
def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
    """Abstract hook for recovering plot data.

    This base implementation does no work: it accepts any positional and
    keyword arguments and unconditionally signals that the method has not
    been implemented.

    Raises
    ------
    NotImplementedError
        Always raised by this implementation.
    """
    message = "get_plot_data_ols method not yet implemented"
    raise NotImplementedError(message)
141+
142+
def effect_summary(
    self,
    window: Union[Literal["post"], tuple, slice] = "post",
    direction: Literal["increase", "decrease", "two-sided"] = "increase",
    alpha: float = 0.05,
    cumulative: bool = True,
    relative: bool = True,
    min_effect: float | None = None,
    treated_unit: str | None = None,
) -> EffectSummary:
    """
    Build a decision-ready summary (table plus prose) of the causal effect.

    The experiment type is detected from ``self`` and the model type from
    ``self.model``; the call is then routed to the matching reporting
    helpers. Regression Discontinuity (RD), Regression Kink and
    Difference-in-Differences (DiD) experiments yield a scalar-effect
    summary. Every other experiment type (Interrupted Time Series and
    Synthetic Control) is summarised over a time window. Both PyMC
    (Bayesian) and OLS models are handled.

    Parameters
    ----------
    window : str, tuple, or slice, default="post"
        Time window to analyse. Only used on the time-series path
        (ITS/SC); not consulted for RD, Regression Kink or DiD:

        - "post": all post-treatment time points (default)
        - (start, end): tuple of start and end times (datetime or integer
          indices)
        - slice: Python slice object for integer indices
    direction : {"increase", "decrease", "two-sided"}, default="increase"
        Direction of the tail probability. Only forwarded to the PyMC
        helpers; the OLS helpers never receive it:

        - "increase": P(effect > 0)
        - "decrease": P(effect < 0)
        - "two-sided": two-sided p-value, reported as 1-p
          ("probability of effect")
    alpha : float, default=0.05
        Significance level for HDI/CI intervals. On the PyMC time-series
        path the HDI probability is ``1 - alpha``.
    cumulative : bool, default=True
        Whether to include cumulative effect statistics. Only used on the
        time-series path (ITS/SC).
    relative : bool, default=True
        Whether to include relative effect statistics (% change versus the
        counterfactual). Only used on the time-series path (ITS/SC).
    min_effect : float, optional
        Region of Practical Equivalence (ROPE) threshold. Only forwarded to
        the PyMC helpers. When given, the summary reports
        P(|effect| > min_effect) for "two-sided",
        P(effect > min_effect) for "increase" and
        P(effect < -min_effect) for "decrease".
    treated_unit : str, optional
        For multi-unit experiments (Synthetic Control), the treated unit to
        analyse. If None and multiple units exist, the first unit is used.

    Returns
    -------
    EffectSummary
        Object with ``.table`` (DataFrame) and ``.text`` (str) attributes.
    """

    def _to_ndarray(obj):
        # Objects exposing ``.values`` are unwrapped through that attribute;
        # anything else is coerced with ``np.asarray``.
        return obj.values if hasattr(obj, "values") else np.asarray(obj)

    kind = _detect_experiment_type(self)
    bayesian = isinstance(self.model, PyMCModel)

    # Keyword bundle shared by every scalar-effect (no time dimension) helper.
    scalar_options = {
        "direction": direction,
        "alpha": alpha,
        "min_effect": min_effect,
    }

    # Regression Discontinuity: scalar effect at the threshold.
    if kind == "rd":
        return _effect_summary_rd(self, **scalar_options)

    # Regression Kink: scalar effect (gradient change at the kink point).
    if kind == "rkink":
        return _effect_summary_rkink(self, **scalar_options)

    # Difference-in-Differences: scalar effect, model-specific reporting.
    if kind == "did":
        if bayesian:
            return _effect_summary_did(self, **scalar_options)
        did_stats = _compute_statistics_did_ols(self, alpha=alpha)
        did_table = _generate_table_did_ols(did_stats)
        did_text = _generate_prose_did_ols(did_stats, alpha=alpha)
        return EffectSummary(table=did_table, text=did_text)

    # Remaining experiment types (ITS / Synthetic Control) carry a time
    # dimension: restrict the impact to the requested window, then pull the
    # matching counterfactual, which relative effects are measured against.
    impact, coords = _extract_window(self, window, treated_unit=treated_unit)
    baseline = _extract_counterfactual(self, coords, treated_unit=treated_unit)

    # Flags shared by the statistics, table and prose helpers below.
    inclusion = {"cumulative": cumulative, "relative": relative}

    if bayesian:
        # PyMC model: statistics come from the posterior draws.
        posterior_stats = _compute_statistics(
            impact,
            baseline,
            hdi_prob=1 - alpha,
            direction=direction,
            min_effect=min_effect,
            **inclusion,
        )
        summary_table = _generate_table(posterior_stats, **inclusion)
        summary_text = _generate_prose(
            posterior_stats,
            coords,
            alpha=alpha,
            direction=direction,
            **inclusion,
        )
        return EffectSummary(table=summary_table, text=summary_text)

    # OLS model: point estimates and confidence intervals on plain arrays.
    impact_values = _to_ndarray(impact)
    baseline_values = _to_ndarray(baseline)
    ols_stats = _compute_statistics_ols(
        impact_values,
        baseline_values,
        alpha=alpha,
        **inclusion,
    )
    summary_table = _generate_table_ols(ols_stats, **inclusion)
    summary_text = _generate_prose_ols(
        ols_stats,
        coords,
        alpha=alpha,
        **inclusion,
    )
    return EffectSummary(table=summary_table, text=summary_text)

0 commit comments

Comments
 (0)