Skip to content

Commit f35c1fe

Browse files
committed
first attempt
1 parent ee8ff1c commit f35c1fe

File tree

7 files changed

+1897
-420
lines changed

7 files changed

+1897
-420
lines changed

causalpy/experiments/base.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,24 @@
1616
"""
1717

1818
from abc import abstractmethod
19+
from typing import Literal, Optional, Union
1920

2021
import arviz as az
2122
import matplotlib.pyplot as plt
2223
import pandas as pd
2324
from sklearn.base import RegressorMixin
2425

2526
from causalpy.pymc_models import PyMCModel
27+
from causalpy.reporting import (
28+
EffectSummary,
29+
_compute_statistics,
30+
_detect_experiment_type,
31+
_effect_summary_did,
32+
_extract_counterfactual,
33+
_extract_window,
34+
_generate_prose,
35+
_generate_table,
36+
)
2637
from causalpy.skl_models import create_causalpy_compatible_class
2738

2839

@@ -106,3 +117,121 @@ def get_plot_data_bayesian(self, *args, **kwargs):
106117
def get_plot_data_ols(self, *args, **kwargs):
107118
"""Abstract method for recovering plot data."""
108119
raise NotImplementedError("get_plot_data_ols method not yet implemented")
120+
121+
def effect_summary(
122+
self,
123+
window: Union[Literal["post"], tuple, slice] = "post",
124+
direction: Literal["increase", "decrease", "two-sided"] = "increase",
125+
alpha: float = 0.05,
126+
cumulative: bool = True,
127+
relative: bool = True,
128+
min_effect: Optional[float] = None,
129+
treated_unit: Optional[str] = None,
130+
) -> EffectSummary:
131+
"""
132+
Generate a decision-ready summary of causal effects from posterior draws.
133+
134+
Supports Interrupted Time Series (ITS), Synthetic Control, and
135+
Difference-in-Differences (DiD) experiments. Automatically detects experiment
136+
type and generates appropriate summary.
137+
138+
Parameters
139+
----------
140+
window : str, tuple, or slice, default="post"
141+
Time window for analysis (ITS/SC only, ignored for DiD):
142+
- "post": All post-treatment time points (default)
143+
- (start, end): Tuple of start and end times (handles both datetime and integer indices)
144+
- slice: Python slice object for integer indices
145+
direction : {"increase", "decrease", "two-sided"}, default="increase"
146+
Direction for tail probability calculation:
147+
- "increase": P(effect > 0)
148+
- "decrease": P(effect < 0)
149+
- "two-sided": Two-sided p-value, report 1-p as "probability of effect"
150+
alpha : float, default=0.05
151+
Significance level for HDI intervals (1-alpha confidence level)
152+
cumulative : bool, default=True
153+
Whether to include cumulative effect statistics (ITS/SC only, ignored for DiD)
154+
relative : bool, default=True
155+
Whether to include relative effect statistics (% change vs counterfactual)
156+
(ITS/SC only, ignored for DiD)
157+
min_effect : float, optional
158+
Region of Practical Equivalence (ROPE) threshold. If provided, reports
159+
P(|effect| > min_effect) for two-sided or P(effect > min_effect) for one-sided.
160+
treated_unit : str, optional
161+
For multi-unit experiments (Synthetic Control), specify which treated unit
162+
to analyze. If None and multiple units exist, uses first unit.
163+
164+
Returns
165+
-------
166+
EffectSummary
167+
Object with .table (DataFrame) and .text (str) attributes
168+
169+
Examples
170+
--------
171+
>>> import causalpy as cp
172+
>>> # Interrupted Time Series
173+
>>> result = cp.InterruptedTimeSeries(...)
174+
>>> stats = result.effect_summary()
175+
>>> print(stats.table)
176+
>>> print(stats.text)
177+
>>> # Difference-in-Differences
178+
>>> result = cp.DifferenceInDifferences(...)
179+
>>> stats = result.effect_summary()
180+
>>> print(stats.table)
181+
"""
182+
# Validate model type
183+
if not isinstance(self.model, PyMCModel):
184+
raise ValueError(
185+
"effect_summary currently only supports PyMC models. "
186+
"OLS model support is planned for future release."
187+
)
188+
189+
# Detect experiment type
190+
experiment_type = _detect_experiment_type(self)
191+
192+
if experiment_type == "did":
193+
# Difference-in-Differences: scalar effect, no time dimension
194+
return _effect_summary_did(
195+
self,
196+
direction=direction,
197+
alpha=alpha,
198+
min_effect=min_effect,
199+
)
200+
else:
201+
# ITS or Synthetic Control: time-series effects
202+
# Extract windowed impact data
203+
windowed_impact, window_coords = _extract_window(
204+
self, window, treated_unit=treated_unit
205+
)
206+
207+
# Extract counterfactual for relative effects
208+
counterfactual = _extract_counterfactual(
209+
self, window_coords, treated_unit=treated_unit
210+
)
211+
212+
# Compute statistics
213+
hdi_prob = 1 - alpha
214+
stats = _compute_statistics(
215+
windowed_impact,
216+
counterfactual,
217+
hdi_prob=hdi_prob,
218+
direction=direction,
219+
cumulative=cumulative,
220+
relative=relative,
221+
min_effect=min_effect,
222+
)
223+
224+
# Generate table
225+
table = _generate_table(stats, cumulative=cumulative, relative=relative)
226+
227+
# Generate prose
228+
text = _generate_prose(
229+
stats,
230+
window_coords,
231+
alpha=alpha,
232+
direction=direction,
233+
cumulative=cumulative,
234+
relative=relative,
235+
)
236+
237+
return EffectSummary(table=table, text=text)

0 commit comments

Comments
 (0)