2020
2121import arviz as az
2222import matplotlib .pyplot as plt
23+ import numpy as np
2324import pandas as pd
2425from sklearn .base import RegressorMixin
2526
2627from causalpy .pymc_models import PyMCModel
2728from causalpy .reporting import (
2829 EffectSummary ,
2930 _compute_statistics ,
31+ _compute_statistics_did_ols ,
32+ _compute_statistics_ols ,
3033 _detect_experiment_type ,
3134 _effect_summary_did ,
35+ _effect_summary_rd ,
3236 _extract_counterfactual ,
3337 _extract_window ,
3438 _generate_prose ,
39+ _generate_prose_did_ols ,
40+ _generate_prose_ols ,
3541 _generate_table ,
42+ _generate_table_did_ols ,
43+ _generate_table_ols ,
3644)
3745from causalpy .skl_models import create_causalpy_compatible_class
3846
@@ -129,34 +137,34 @@ def effect_summary(
129137 treated_unit : Optional [str ] = None ,
130138 ) -> EffectSummary :
131139 """
132- Generate a decision-ready summary of causal effects from posterior draws .
140+ Generate a decision-ready summary of causal effects.
133141
134- Supports Interrupted Time Series (ITS), Synthetic Control, and
135- Difference-in-Differences (DiD ) experiments. Automatically detects experiment
136- type and generates appropriate summary.
142+ Supports Interrupted Time Series (ITS), Synthetic Control, Difference-in-Differences (DiD),
143+ and Regression Discontinuity (RD ) experiments. Works with both PyMC (Bayesian) and OLS models.
144+ Automatically detects experiment type and model type, generating appropriate summary.
137145
138146 Parameters
139147 ----------
140148 window : str, tuple, or slice, default="post"
141- Time window for analysis (ITS/SC only, ignored for DiD):
149+ Time window for analysis (ITS/SC only, ignored for DiD/RD ):
142150 - "post": All post-treatment time points (default)
143151 - (start, end): Tuple of start and end times (handles both datetime and integer indices)
144152 - slice: Python slice object for integer indices
145153 direction : {"increase", "decrease", "two-sided"}, default="increase"
146- Direction for tail probability calculation:
154+ Direction for tail probability calculation (PyMC only, ignored for OLS) :
147155 - "increase": P(effect > 0)
148156 - "decrease": P(effect < 0)
149157 - "two-sided": Two-sided p-value, report 1-p as "probability of effect"
150158 alpha : float, default=0.05
151- Significance level for HDI intervals (1-alpha confidence level)
159+ Significance level for HDI/CI intervals (1-alpha confidence level)
152160 cumulative : bool, default=True
153- Whether to include cumulative effect statistics (ITS/SC only, ignored for DiD)
161+ Whether to include cumulative effect statistics (ITS/SC only, ignored for DiD/RD )
154162 relative : bool, default=True
155163 Whether to include relative effect statistics (% change vs counterfactual)
156- (ITS/SC only, ignored for DiD)
164+ (ITS/SC only, ignored for DiD/RD )
157165 min_effect : float, optional
158- Region of Practical Equivalence (ROPE) threshold. If provided, reports
159- P(|effect| > min_effect) for two-sided or P(effect > min_effect) for one-sided.
166+ Region of Practical Equivalence (ROPE) threshold (PyMC only, ignored for OLS).
167+ If provided, reports P(|effect| > min_effect) for two-sided or P(effect > min_effect) for one-sided.
160168 treated_unit : str, optional
161169 For multi-unit experiments (Synthetic Control), specify which treated unit
162170 to analyze. If None and multiple units exist, uses first unit.
@@ -168,38 +176,62 @@ def effect_summary(
168176
169177 Examples
170178 --------
171- Interrupted Time Series:
179+ Interrupted Time Series (PyMC) :
172180
173181 import causalpy as cp
174- result = cp.InterruptedTimeSeries(...)
182+ result = cp.InterruptedTimeSeries(..., model=cp.pymc_models.LinearRegression(...) )
175183 stats = result.effect_summary()
176184 print(stats.table)
177185 print(stats.text)
178186
187+ Interrupted Time Series (OLS):
188+
189+ from sklearn.linear_model import LinearRegression
190+ result = cp.InterruptedTimeSeries(..., model=LinearRegression())
191+ stats = result.effect_summary()
192+ print(stats.table)
193+
179194 Difference-in-Differences:
180195
181196 result = cp.DifferenceInDifferences(...)
182197 stats = result.effect_summary()
183198 print(stats.table)
184- """
185- # Validate model type
186- if not isinstance (self .model , PyMCModel ):
187- raise ValueError (
188- "effect_summary currently only supports PyMC models. "
189- "OLS model support is planned for future release."
190- )
191199
200+ Regression Discontinuity:
201+
202+ result = cp.RegressionDiscontinuity(...)
203+ stats = result.effect_summary()
204+ print(stats.table)
205+ """
192206 # Detect experiment type
193207 experiment_type = _detect_experiment_type (self )
194208
195- if experiment_type == "did" :
196- # Difference-in-Differences: scalar effect, no time dimension
197- return _effect_summary_did (
209+ # Check if PyMC or OLS model
210+ is_pymc = isinstance (self .model , PyMCModel )
211+
212+ if experiment_type == "rd" :
213+ # Regression Discontinuity: scalar effect, no time dimension
214+ return _effect_summary_rd (
198215 self ,
199216 direction = direction ,
200217 alpha = alpha ,
201218 min_effect = min_effect ,
202219 )
220+ elif experiment_type == "did" :
221+ # Difference-in-Differences: scalar effect, no time dimension
222+ if is_pymc :
223+ return _effect_summary_did (
224+ self ,
225+ direction = direction ,
226+ alpha = alpha ,
227+ min_effect = min_effect ,
228+ )
229+ else :
230+ # OLS DiD
231+ stats = _compute_statistics_did_ols (self , alpha = alpha )
232+ table = _generate_table_did_ols (stats )
233+ text = _generate_prose_did_ols (stats , alpha = alpha )
234+ return EffectSummary (table = table , text = text )
203235 else :
204236 # ITS or Synthetic Control: time-series effects
205237 # Extract windowed impact data
@@ -212,29 +244,63 @@ def effect_summary(
212244 self , window_coords , treated_unit = treated_unit
213245 )
214246
215- # Compute statistics
216- hdi_prob = 1 - alpha
217- stats = _compute_statistics (
218- windowed_impact ,
219- counterfactual ,
220- hdi_prob = hdi_prob ,
221- direction = direction ,
222- cumulative = cumulative ,
223- relative = relative ,
224- min_effect = min_effect ,
225- )
226-
227- # Generate table
228- table = _generate_table (stats , cumulative = cumulative , relative = relative )
229-
230- # Generate prose
231- text = _generate_prose (
232- stats ,
233- window_coords ,
234- alpha = alpha ,
235- direction = direction ,
236- cumulative = cumulative ,
237- relative = relative ,
238- )
247+ if is_pymc :
248+ # PyMC model: use posterior draws
249+ hdi_prob = 1 - alpha
250+ stats = _compute_statistics (
251+ windowed_impact ,
252+ counterfactual ,
253+ hdi_prob = hdi_prob ,
254+ direction = direction ,
255+ cumulative = cumulative ,
256+ relative = relative ,
257+ min_effect = min_effect ,
258+ )
259+
260+ # Generate table
261+ table = _generate_table (stats , cumulative = cumulative , relative = relative )
262+
263+ # Generate prose
264+ text = _generate_prose (
265+ stats ,
266+ window_coords ,
267+ alpha = alpha ,
268+ direction = direction ,
269+ cumulative = cumulative ,
270+ relative = relative ,
271+ )
272+ else :
273+ # OLS model: use point estimates and CIs
274+ # Convert to numpy arrays if needed
275+ if hasattr (windowed_impact , "values" ):
276+ impact_array = windowed_impact .values
277+ else :
278+ impact_array = np .asarray (windowed_impact )
279+ if hasattr (counterfactual , "values" ):
280+ counterfactual_array = counterfactual .values
281+ else :
282+ counterfactual_array = np .asarray (counterfactual )
283+
284+ stats = _compute_statistics_ols (
285+ impact_array ,
286+ counterfactual_array ,
287+ alpha = alpha ,
288+ cumulative = cumulative ,
289+ relative = relative ,
290+ )
291+
292+ # Generate table
293+ table = _generate_table_ols (
294+ stats , cumulative = cumulative , relative = relative
295+ )
296+
297+ # Generate prose
298+ text = _generate_prose_ols (
299+ stats ,
300+ window_coords ,
301+ alpha = alpha ,
302+ cumulative = cumulative ,
303+ relative = relative ,
304+ )
239305
240306 return EffectSummary (table = table , text = text )
0 commit comments