Skip to content

Commit b22e443

Browse files
committed
add to all the other notebooks + fix failing tests
1 parent 4ceceb9 commit b22e443

File tree

12 files changed

+2446
-648
lines changed

12 files changed

+2446
-648
lines changed

causalpy/experiments/base.py

Lines changed: 113 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,27 @@
2020

2121
import arviz as az
2222
import matplotlib.pyplot as plt
23+
import numpy as np
2324
import pandas as pd
2425
from sklearn.base import RegressorMixin
2526

2627
from causalpy.pymc_models import PyMCModel
2728
from causalpy.reporting import (
2829
EffectSummary,
2930
_compute_statistics,
31+
_compute_statistics_did_ols,
32+
_compute_statistics_ols,
3033
_detect_experiment_type,
3134
_effect_summary_did,
35+
_effect_summary_rd,
3236
_extract_counterfactual,
3337
_extract_window,
3438
_generate_prose,
39+
_generate_prose_did_ols,
40+
_generate_prose_ols,
3541
_generate_table,
42+
_generate_table_did_ols,
43+
_generate_table_ols,
3644
)
3745
from causalpy.skl_models import create_causalpy_compatible_class
3846

@@ -129,34 +137,34 @@ def effect_summary(
129137
treated_unit: Optional[str] = None,
130138
) -> EffectSummary:
131139
"""
132-
Generate a decision-ready summary of causal effects from posterior draws.
140+
Generate a decision-ready summary of causal effects.
133141
134-
Supports Interrupted Time Series (ITS), Synthetic Control, and
135-
Difference-in-Differences (DiD) experiments. Automatically detects experiment
136-
type and generates appropriate summary.
142+
Supports Interrupted Time Series (ITS), Synthetic Control, Difference-in-Differences (DiD),
143+
and Regression Discontinuity (RD) experiments. Works with both PyMC (Bayesian) and OLS models.
144+
Automatically detects experiment type and model type, generating appropriate summary.
137145
138146
Parameters
139147
----------
140148
window : str, tuple, or slice, default="post"
141-
Time window for analysis (ITS/SC only, ignored for DiD):
149+
Time window for analysis (ITS/SC only, ignored for DiD/RD):
142150
- "post": All post-treatment time points (default)
143151
- (start, end): Tuple of start and end times (handles both datetime and integer indices)
144152
- slice: Python slice object for integer indices
145153
direction : {"increase", "decrease", "two-sided"}, default="increase"
146-
Direction for tail probability calculation:
154+
Direction for tail probability calculation (PyMC only, ignored for OLS):
147155
- "increase": P(effect > 0)
148156
- "decrease": P(effect < 0)
149157
- "two-sided": Two-sided p-value, report 1-p as "probability of effect"
150158
alpha : float, default=0.05
151-
Significance level for HDI intervals (1-alpha confidence level)
159+
Significance level for HDI/CI intervals (1-alpha confidence level)
152160
cumulative : bool, default=True
153-
Whether to include cumulative effect statistics (ITS/SC only, ignored for DiD)
161+
Whether to include cumulative effect statistics (ITS/SC only, ignored for DiD/RD)
154162
relative : bool, default=True
155163
Whether to include relative effect statistics (% change vs counterfactual)
156-
(ITS/SC only, ignored for DiD)
164+
(ITS/SC only, ignored for DiD/RD)
157165
min_effect : float, optional
158-
Region of Practical Equivalence (ROPE) threshold. If provided, reports
159-
P(|effect| > min_effect) for two-sided or P(effect > min_effect) for one-sided.
166+
Region of Practical Equivalence (ROPE) threshold (PyMC only, ignored for OLS).
167+
If provided, reports P(|effect| > min_effect) for two-sided or P(effect > min_effect) for one-sided.
160168
treated_unit : str, optional
161169
For multi-unit experiments (Synthetic Control), specify which treated unit
162170
to analyze. If None and multiple units exist, uses first unit.
@@ -168,38 +176,62 @@ def effect_summary(
168176
169177
Examples
170178
--------
171-
Interrupted Time Series:
179+
Interrupted Time Series (PyMC):
172180
173181
import causalpy as cp
174-
result = cp.InterruptedTimeSeries(...)
182+
result = cp.InterruptedTimeSeries(..., model=cp.pymc_models.LinearRegression(...))
175183
stats = result.effect_summary()
176184
print(stats.table)
177185
print(stats.text)
178186
187+
Interrupted Time Series (OLS):
188+
189+
from sklearn.linear_model import LinearRegression
190+
result = cp.InterruptedTimeSeries(..., model=LinearRegression())
191+
stats = result.effect_summary()
192+
print(stats.table)
193+
179194
Difference-in-Differences:
180195
181196
result = cp.DifferenceInDifferences(...)
182197
stats = result.effect_summary()
183198
print(stats.table)
184-
"""
185-
# Validate model type
186-
if not isinstance(self.model, PyMCModel):
187-
raise ValueError(
188-
"effect_summary currently only supports PyMC models. "
189-
"OLS model support is planned for future release."
190-
)
191199
200+
Regression Discontinuity:
201+
202+
result = cp.RegressionDiscontinuity(...)
203+
stats = result.effect_summary()
204+
print(stats.table)
205+
"""
192206
# Detect experiment type
193207
experiment_type = _detect_experiment_type(self)
194208

195-
if experiment_type == "did":
196-
# Difference-in-Differences: scalar effect, no time dimension
197-
return _effect_summary_did(
209+
# Check if PyMC or OLS model
210+
is_pymc = isinstance(self.model, PyMCModel)
211+
212+
if experiment_type == "rd":
213+
# Regression Discontinuity: scalar effect, no time dimension
214+
return _effect_summary_rd(
198215
self,
199216
direction=direction,
200217
alpha=alpha,
201218
min_effect=min_effect,
202219
)
220+
elif experiment_type == "did":
221+
# Difference-in-Differences: scalar effect, no time dimension
222+
if is_pymc:
223+
return _effect_summary_did(
224+
self,
225+
direction=direction,
226+
alpha=alpha,
227+
min_effect=min_effect,
228+
)
229+
else:
230+
# OLS DiD
231+
stats = _compute_statistics_did_ols(self, alpha=alpha)
232+
table = _generate_table_did_ols(stats)
233+
text = _generate_prose_did_ols(stats, alpha=alpha)
234+
return EffectSummary(table=table, text=text)
203235
else:
204236
# ITS or Synthetic Control: time-series effects
205237
# Extract windowed impact data
@@ -212,29 +244,63 @@ def effect_summary(
212244
self, window_coords, treated_unit=treated_unit
213245
)
214246

215-
# Compute statistics
216-
hdi_prob = 1 - alpha
217-
stats = _compute_statistics(
218-
windowed_impact,
219-
counterfactual,
220-
hdi_prob=hdi_prob,
221-
direction=direction,
222-
cumulative=cumulative,
223-
relative=relative,
224-
min_effect=min_effect,
225-
)
226-
227-
# Generate table
228-
table = _generate_table(stats, cumulative=cumulative, relative=relative)
229-
230-
# Generate prose
231-
text = _generate_prose(
232-
stats,
233-
window_coords,
234-
alpha=alpha,
235-
direction=direction,
236-
cumulative=cumulative,
237-
relative=relative,
238-
)
247+
if is_pymc:
248+
# PyMC model: use posterior draws
249+
hdi_prob = 1 - alpha
250+
stats = _compute_statistics(
251+
windowed_impact,
252+
counterfactual,
253+
hdi_prob=hdi_prob,
254+
direction=direction,
255+
cumulative=cumulative,
256+
relative=relative,
257+
min_effect=min_effect,
258+
)
259+
260+
# Generate table
261+
table = _generate_table(stats, cumulative=cumulative, relative=relative)
262+
263+
# Generate prose
264+
text = _generate_prose(
265+
stats,
266+
window_coords,
267+
alpha=alpha,
268+
direction=direction,
269+
cumulative=cumulative,
270+
relative=relative,
271+
)
272+
else:
273+
# OLS model: use point estimates and CIs
274+
# Convert to numpy arrays if needed
275+
if hasattr(windowed_impact, "values"):
276+
impact_array = windowed_impact.values
277+
else:
278+
impact_array = np.asarray(windowed_impact)
279+
if hasattr(counterfactual, "values"):
280+
counterfactual_array = counterfactual.values
281+
else:
282+
counterfactual_array = np.asarray(counterfactual)
283+
284+
stats = _compute_statistics_ols(
285+
impact_array,
286+
counterfactual_array,
287+
alpha=alpha,
288+
cumulative=cumulative,
289+
relative=relative,
290+
)
291+
292+
# Generate table
293+
table = _generate_table_ols(
294+
stats, cumulative=cumulative, relative=relative
295+
)
296+
297+
# Generate prose
298+
text = _generate_prose_ols(
299+
stats,
300+
window_coords,
301+
alpha=alpha,
302+
cumulative=cumulative,
303+
relative=relative,
304+
)
239305

240306
return EffectSummary(table=table, text=text)

0 commit comments

Comments
 (0)