|
16 | 16 | """ |
17 | 17 |
|
18 | 18 | from abc import abstractmethod |
19 | | -from typing import Any, Union |
| 19 | +from typing import Any, Literal, Union |
20 | 20 |
|
21 | 21 | import arviz as az |
22 | 22 | import matplotlib.pyplot as plt |
| 23 | +import numpy as np |
23 | 24 | import pandas as pd |
24 | 25 | from sklearn.base import RegressorMixin |
25 | 26 |
|
26 | 27 | from causalpy.pymc_models import PyMCModel |
| 28 | +from causalpy.reporting import ( |
| 29 | + EffectSummary, |
| 30 | + _compute_statistics, |
| 31 | + _compute_statistics_did_ols, |
| 32 | + _compute_statistics_ols, |
| 33 | + _detect_experiment_type, |
| 34 | + _effect_summary_did, |
| 35 | + _effect_summary_rd, |
| 36 | + _effect_summary_rkink, |
| 37 | + _extract_counterfactual, |
| 38 | + _extract_window, |
| 39 | + _generate_prose, |
| 40 | + _generate_prose_did_ols, |
| 41 | + _generate_prose_ols, |
| 42 | + _generate_table, |
| 43 | + _generate_table_did_ols, |
| 44 | + _generate_table_ols, |
| 45 | +) |
27 | 46 | from causalpy.skl_models import create_causalpy_compatible_class |
28 | 47 |
|
29 | 48 |
|
@@ -119,3 +138,161 @@ def get_plot_data_bayesian(self, *args: Any, **kwargs: Any) -> pd.DataFrame: |
def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
    """Recover plot data for OLS models (abstract hook).

    This base implementation accepts arbitrary positional and keyword
    arguments and always raises ``NotImplementedError``.
    """
    message = "get_plot_data_ols method not yet implemented"
    raise NotImplementedError(message)
| 141 | + |
def effect_summary(
    self,
    window: Union[Literal["post"], tuple, slice] = "post",
    direction: Literal["increase", "decrease", "two-sided"] = "increase",
    alpha: float = 0.05,
    cumulative: bool = True,
    relative: bool = True,
    min_effect: float | None = None,
    treated_unit: str | None = None,
) -> EffectSummary:
    """
    Generate a decision-ready summary of causal effects.

    Supports Interrupted Time Series (ITS), Synthetic Control (SC),
    Difference-in-Differences (DiD), Regression Discontinuity (RD) and
    Regression Kink (RKink) experiments. Works with both PyMC (Bayesian) and
    OLS models. The experiment type and model type are detected automatically
    and the matching summary is generated.

    Parameters
    ----------
    window : str, tuple, or slice, default="post"
        Time window for analysis (ITS/SC only, ignored for DiD/RD/RKink):
        - "post": All post-treatment time points (default)
        - (start, end): Tuple of start and end times (handles both datetime
          and integer indices)
        - slice: Python slice object for integer indices
    direction : {"increase", "decrease", "two-sided"}, default="increase"
        Direction for tail probability calculation. Used for PyMC models; the
        OLS ITS/SC and OLS DiD paths do not use it. It is forwarded to the
        RD and RKink summaries regardless of model type.
        - "increase": P(effect > 0)
        - "decrease": P(effect < 0)
        - "two-sided": Two-sided p-value, report 1-p as "probability of effect"
    alpha : float, default=0.05
        Significance level for HDI/CI intervals (1-alpha confidence level).
        Must satisfy ``0 <= alpha < 1``.
    cumulative : bool, default=True
        Whether to include cumulative effect statistics
        (ITS/SC only, ignored for DiD/RD/RKink)
    relative : bool, default=True
        Whether to include relative effect statistics (% change vs
        counterfactual) (ITS/SC only, ignored for DiD/RD/RKink)
    min_effect : float, optional
        Region of Practical Equivalence (ROPE) threshold. Not used by the OLS
        ITS/SC and OLS DiD paths. If provided, reports
        P(|effect| > min_effect) for two-sided or P(effect > min_effect) for
        one-sided.
    treated_unit : str, optional
        For multi-unit experiments (Synthetic Control), specify which treated
        unit to analyze. If None and multiple units exist, uses first unit.

    Returns
    -------
    EffectSummary
        Object with .table (DataFrame) and .text (str) attributes

    Raises
    ------
    ValueError
        If ``alpha`` is outside the interval ``[0, 1)``.
    """
    # Fail fast on an impossible significance level: 1 - alpha is used as the
    # interval probability below, so it must stay within (0, 1]. The chained
    # comparison also rejects NaN.
    if not 0 <= alpha < 1:
        raise ValueError(f"alpha must satisfy 0 <= alpha < 1, got {alpha!r}")

    experiment_type = _detect_experiment_type(self)
    is_pymc = isinstance(self.model, PyMCModel)

    # --- Scalar-effect designs: no time dimension, so window/cumulative/
    # --- relative are not forwarded.
    if experiment_type == "rd":
        return _effect_summary_rd(
            self,
            direction=direction,
            alpha=alpha,
            min_effect=min_effect,
        )

    if experiment_type == "rkink":
        # Effect is the gradient change at the kink point.
        return _effect_summary_rkink(
            self,
            direction=direction,
            alpha=alpha,
            min_effect=min_effect,
        )

    if experiment_type == "did":
        if is_pymc:
            return _effect_summary_did(
                self,
                direction=direction,
                alpha=alpha,
                min_effect=min_effect,
            )
        # OLS DiD: dedicated statistics/table/prose helpers.
        stats = _compute_statistics_did_ols(self, alpha=alpha)
        table = _generate_table_did_ols(stats)
        text = _generate_prose_did_ols(stats, alpha=alpha)
        return EffectSummary(table=table, text=text)

    # --- ITS or Synthetic Control: time-series effects.
    windowed_impact, window_coords = _extract_window(
        self, window, treated_unit=treated_unit
    )
    # The counterfactual over the same window feeds the relative effects.
    counterfactual = _extract_counterfactual(
        self, window_coords, treated_unit=treated_unit
    )

    if is_pymc:
        # PyMC model: statistics are computed from the windowed impact and
        # counterfactual with a (1 - alpha) HDI.
        stats = _compute_statistics(
            windowed_impact,
            counterfactual,
            hdi_prob=1 - alpha,
            direction=direction,
            cumulative=cumulative,
            relative=relative,
            min_effect=min_effect,
        )
        table = _generate_table(stats, cumulative=cumulative, relative=relative)
        text = _generate_prose(
            stats,
            window_coords,
            alpha=alpha,
            direction=direction,
            cumulative=cumulative,
            relative=relative,
        )
        return EffectSummary(table=table, text=text)

    # OLS model: the helper takes plain arrays, so unwrap objects exposing
    # ``.values`` and coerce everything else with numpy.
    if hasattr(windowed_impact, "values"):
        impact_array = windowed_impact.values
    else:
        impact_array = np.asarray(windowed_impact)
    if hasattr(counterfactual, "values"):
        counterfactual_array = counterfactual.values
    else:
        counterfactual_array = np.asarray(counterfactual)

    stats = _compute_statistics_ols(
        impact_array,
        counterfactual_array,
        alpha=alpha,
        cumulative=cumulative,
        relative=relative,
    )
    table = _generate_table_ols(stats, cumulative=cumulative, relative=relative)
    text = _generate_prose_ols(
        stats,
        window_coords,
        alpha=alpha,
        cumulative=cumulative,
        relative=relative,
    )
    return EffectSummary(table=table, text=text)
0 commit comments