|
16 | 16 | """ |
17 | 17 |
|
18 | 18 | from abc import abstractmethod |
| 19 | +from typing import Literal, Optional, Union |
19 | 20 |
|
20 | 21 | import arviz as az |
21 | 22 | import matplotlib.pyplot as plt |
22 | 23 | import pandas as pd |
23 | 24 | from sklearn.base import RegressorMixin |
24 | 25 |
|
25 | 26 | from causalpy.pymc_models import PyMCModel |
| 27 | +from causalpy.reporting import ( |
| 28 | + EffectSummary, |
| 29 | + _compute_statistics, |
| 30 | + _detect_experiment_type, |
| 31 | + _effect_summary_did, |
| 32 | + _extract_counterfactual, |
| 33 | + _extract_window, |
| 34 | + _generate_prose, |
| 35 | + _generate_table, |
| 36 | +) |
26 | 37 | from causalpy.skl_models import create_causalpy_compatible_class |
27 | 38 |
|
28 | 39 |
|
@@ -106,3 +117,121 @@ def get_plot_data_bayesian(self, *args, **kwargs): |
106 | 117 | def get_plot_data_ols(self, *args, **kwargs): |
107 | 118 | """Abstract method for recovering plot data.""" |
108 | 119 | raise NotImplementedError("get_plot_data_ols method not yet implemented") |
| 120 | + |
| 121 | + def effect_summary( |
| 122 | + self, |
| 123 | + window: Union[Literal["post"], tuple, slice] = "post", |
| 124 | + direction: Literal["increase", "decrease", "two-sided"] = "increase", |
| 125 | + alpha: float = 0.05, |
| 126 | + cumulative: bool = True, |
| 127 | + relative: bool = True, |
| 128 | + min_effect: Optional[float] = None, |
| 129 | + treated_unit: Optional[str] = None, |
| 130 | + ) -> EffectSummary: |
| 131 | + """ |
| 132 | + Generate a decision-ready summary of causal effects from posterior draws. |
| 133 | +
|
| 134 | + Supports Interrupted Time Series (ITS), Synthetic Control, and |
| 135 | + Difference-in-Differences (DiD) experiments. Automatically detects experiment |
| 136 | + type and generates appropriate summary. |
| 137 | +
|
| 138 | + Parameters |
| 139 | + ---------- |
| 140 | + window : str, tuple, or slice, default="post" |
| 141 | + Time window for analysis (ITS/SC only, ignored for DiD): |
| 142 | + - "post": All post-treatment time points (default) |
| 143 | + - (start, end): Tuple of start and end times (handles both datetime and integer indices) |
| 144 | + - slice: Python slice object for integer indices |
| 145 | + direction : {"increase", "decrease", "two-sided"}, default="increase" |
| 146 | + Direction for tail probability calculation: |
| 147 | + - "increase": P(effect > 0) |
| 148 | + - "decrease": P(effect < 0) |
| 149 | + - "two-sided": Two-sided p-value, report 1-p as "probability of effect" |
| 150 | + alpha : float, default=0.05 |
| 151 | + Significance level for HDI intervals (1-alpha confidence level) |
| 152 | + cumulative : bool, default=True |
| 153 | + Whether to include cumulative effect statistics (ITS/SC only, ignored for DiD) |
| 154 | + relative : bool, default=True |
| 155 | + Whether to include relative effect statistics (% change vs counterfactual) |
| 156 | + (ITS/SC only, ignored for DiD) |
| 157 | + min_effect : float, optional |
| 158 | + Region of Practical Equivalence (ROPE) threshold. If provided, reports |
| 159 | + P(|effect| > min_effect) for two-sided or P(effect > min_effect) for one-sided. |
| 160 | + treated_unit : str, optional |
| 161 | + For multi-unit experiments (Synthetic Control), specify which treated unit |
| 162 | + to analyze. If None and multiple units exist, uses first unit. |
| 163 | +
|
| 164 | + Returns |
| 165 | + ------- |
| 166 | + EffectSummary |
| 167 | + Object with .table (DataFrame) and .text (str) attributes |
| 168 | +
|
| 169 | + Examples |
| 170 | + -------- |
| 171 | + >>> import causalpy as cp |
| 172 | + >>> # Interrupted Time Series |
| 173 | + >>> result = cp.InterruptedTimeSeries(...) |
| 174 | + >>> stats = result.effect_summary() |
| 175 | + >>> print(stats.table) |
| 176 | + >>> print(stats.text) |
| 177 | + >>> # Difference-in-Differences |
| 178 | + >>> result = cp.DifferenceInDifferences(...) |
| 179 | + >>> stats = result.effect_summary() |
| 180 | + >>> print(stats.table) |
| 181 | + """ |
| 182 | + # Validate model type |
| 183 | + if not isinstance(self.model, PyMCModel): |
| 184 | + raise ValueError( |
| 185 | + "effect_summary currently only supports PyMC models. " |
| 186 | + "OLS model support is planned for future release." |
| 187 | + ) |
| 188 | + |
| 189 | + # Detect experiment type |
| 190 | + experiment_type = _detect_experiment_type(self) |
| 191 | + |
| 192 | + if experiment_type == "did": |
| 193 | + # Difference-in-Differences: scalar effect, no time dimension |
| 194 | + return _effect_summary_did( |
| 195 | + self, |
| 196 | + direction=direction, |
| 197 | + alpha=alpha, |
| 198 | + min_effect=min_effect, |
| 199 | + ) |
| 200 | + else: |
| 201 | + # ITS or Synthetic Control: time-series effects |
| 202 | + # Extract windowed impact data |
| 203 | + windowed_impact, window_coords = _extract_window( |
| 204 | + self, window, treated_unit=treated_unit |
| 205 | + ) |
| 206 | + |
| 207 | + # Extract counterfactual for relative effects |
| 208 | + counterfactual = _extract_counterfactual( |
| 209 | + self, window_coords, treated_unit=treated_unit |
| 210 | + ) |
| 211 | + |
| 212 | + # Compute statistics |
| 213 | + hdi_prob = 1 - alpha |
| 214 | + stats = _compute_statistics( |
| 215 | + windowed_impact, |
| 216 | + counterfactual, |
| 217 | + hdi_prob=hdi_prob, |
| 218 | + direction=direction, |
| 219 | + cumulative=cumulative, |
| 220 | + relative=relative, |
| 221 | + min_effect=min_effect, |
| 222 | + ) |
| 223 | + |
| 224 | + # Generate table |
| 225 | + table = _generate_table(stats, cumulative=cumulative, relative=relative) |
| 226 | + |
| 227 | + # Generate prose |
| 228 | + text = _generate_prose( |
| 229 | + stats, |
| 230 | + window_coords, |
| 231 | + alpha=alpha, |
| 232 | + direction=direction, |
| 233 | + cumulative=cumulative, |
| 234 | + relative=relative, |
| 235 | + ) |
| 236 | + |
| 237 | + return EffectSummary(table=table, text=text) |
0 commit comments