Skip to content

Commit 62ab14d

Browse files
committed
add ability to use a bandwidth in Regression Discontinuity designs
1 parent e007b94 commit 62ab14d

File tree

7 files changed

+673
-244
lines changed

7 files changed

+673
-244
lines changed

causalpy/pymc_experiments.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Union
1+
from typing import Optional, Union
22

33
import arviz as az
44
import matplotlib.pyplot as plt
@@ -556,6 +556,9 @@ class RegressionDiscontinuity(ExperimentalDesign):
556556
:param epsilon:
557557
A small scalar value which determines how far above and below the treatment
558558
threshold to evaluate the causal impact.
559+
:param bandwidth:
560+
Only data within ``treatment_threshold`` +/- ``bandwidth`` is used to fit the
560+
model. If ``None`` (the default), all of the data is used.
559562
"""
560563

561564
def __init__(
@@ -566,6 +569,7 @@ def __init__(
566569
model=None,
567570
running_variable_name: str = "x",
568571
epsilon: float = 0.001,
572+
bandwidth: Optional[float] = None,
569573
**kwargs,
570574
):
571575
super().__init__(model=model, **kwargs)
@@ -575,9 +579,17 @@ def __init__(
575579
self.running_variable_name = running_variable_name
576580
self.treatment_threshold = treatment_threshold
577581
self.epsilon = epsilon
582+
self.bandwidth = bandwidth
578583
self._input_validation()
579584

580-
y, X = dmatrices(formula, self.data)
585+
if self.bandwidth is not None:
586+
fmin = self.treatment_threshold - self.bandwidth
587+
fmax = self.treatment_threshold + self.bandwidth
588+
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
589+
y, X = dmatrices(formula, filtered_data)
590+
else:
591+
y, X = dmatrices(formula, self.data)
592+
581593
self._y_design_info = y.design_info
582594
self._x_design_info = X.design_info
583595
self.labels = X.design_info.column_names
@@ -594,11 +606,14 @@ def __init__(
594606
self.score = self.model.score(X=self.X, y=self.y)
595607

596608
# get the model predictions of the observed data
597-
xi = np.linspace(
598-
np.min(self.data[self.running_variable_name]),
599-
np.max(self.data[self.running_variable_name]),
600-
200,
601-
)
609+
if self.bandwidth is not None:
610+
xi = np.linspace(fmin, fmax, 200)
611+
else:
612+
xi = np.linspace(
613+
np.min(self.data[self.running_variable_name]),
614+
np.max(self.data[self.running_variable_name]),
615+
200,
616+
)
602617
self.x_pred = pd.DataFrame(
603618
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
604619
)

causalpy/skl_experiments.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
import matplotlib.pyplot as plt
24
import numpy as np
35
import pandas as pd
@@ -361,6 +363,9 @@ class RegressionDiscontinuity(ExperimentalDesign):
361363
:param epsilon:
362364
A small scalar value which determines how far above and below the treatment
363365
threshold to evaluate the causal impact.
366+
:param bandwidth:
367+
Only data within ``treatment_threshold`` +/- ``bandwidth`` is used to fit the
367+
model. If ``None`` (the default), all of the data is used.
364369
"""
365370

366371
def __init__(
@@ -371,15 +376,25 @@ def __init__(
371376
model=None,
372377
running_variable_name="x",
373378
epsilon: float = 0.001,
379+
bandwidth: Optional[float] = None,
374380
**kwargs,
375381
):
376382
super().__init__(model=model, **kwargs)
377383
self.data = data
378384
self.formula = formula
379385
self.running_variable_name = running_variable_name
380386
self.treatment_threshold = treatment_threshold
387+
self.bandwidth = bandwidth
381388
self.epsilon = epsilon
382-
y, X = dmatrices(formula, self.data)
389+
390+
if self.bandwidth is not None:
391+
fmin = self.treatment_threshold - self.bandwidth
392+
fmax = self.treatment_threshold + self.bandwidth
393+
filtered_data = self.data.query(f"{fmin} <= x <= {fmax}")
394+
y, X = dmatrices(formula, filtered_data)
395+
else:
396+
y, X = dmatrices(formula, self.data)
397+
383398
self._y_design_info = y.design_info
384399
self._x_design_info = X.design_info
385400
self.labels = X.design_info.column_names
@@ -396,11 +411,14 @@ def __init__(
396411
self.score = self.model.score(X=self.X, y=self.y)
397412

398413
# get the model predictions of the observed data
399-
xi = np.linspace(
400-
np.min(self.data[self.running_variable_name]),
401-
np.max(self.data[self.running_variable_name]),
402-
1000,
403-
)
414+
if self.bandwidth is not None:
415+
xi = np.linspace(fmin, fmax, 200)
416+
else:
417+
xi = np.linspace(
418+
np.min(self.data[self.running_variable_name]),
419+
np.max(self.data[self.running_variable_name]),
420+
200,
421+
)
404422
self.x_pred = pd.DataFrame(
405423
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
406424
)

causalpy/tests/test_integration_pymc_examples.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,23 @@ def test_rd():
120120
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
121121

122122

123+
@pytest.mark.integration
124+
def test_rd_bandwidth():
125+
df = cp.load_data("rd")
126+
result = cp.pymc_experiments.RegressionDiscontinuity(
127+
df,
128+
formula="y ~ 1 + x + treated + x:treated",
129+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
130+
treatment_threshold=0.5,
131+
epsilon=0.001,
132+
bandwidth=0.3,
133+
)
134+
assert isinstance(df, pd.DataFrame)
135+
assert isinstance(result, cp.pymc_experiments.RegressionDiscontinuity)
136+
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
137+
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]
138+
139+
123140
@pytest.mark.integration
124141
def test_rd_drinking():
125142
df = (

causalpy/tests/test_integration_skl_examples.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,21 @@ def test_rd_linear_main_effects():
8888
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
8989

9090

91+
@pytest.mark.integration
92+
def test_rd_linear_main_effects_bandwidth():
93+
data = cp.load_data("rd")
94+
result = cp.skl_experiments.RegressionDiscontinuity(
95+
data,
96+
formula="y ~ 1 + x + treated",
97+
model=LinearRegression(),
98+
treatment_threshold=0.5,
99+
epsilon=0.001,
100+
bandwidth=0.3,
101+
)
102+
assert isinstance(data, pd.DataFrame)
103+
assert isinstance(result, cp.skl_experiments.RegressionDiscontinuity)
104+
105+
91106
@pytest.mark.integration
92107
def test_rd_linear_with_interaction():
93108
data = cp.load_data("rd")

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

docs/source/notebooks/rd_pymc.ipynb

Lines changed: 538 additions & 218 deletions
Large diffs are not rendered by default.

docs/source/notebooks/rd_skl.ipynb

Lines changed: 54 additions & 10 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)