Skip to content

Commit e007b94

Browse files
committed
Add the epsilon kwarg into the sci-kit learn version of the RegressionDiscontinuity class
1 parent 9bda3d9 commit e007b94

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

causalpy/skl_experiments.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,21 @@ def plot(self):
346346

347347
class RegressionDiscontinuity(ExperimentalDesign):
348348
"""
349-
Analyse data from regression discontinuity experiments.
350-
351-
.. note::
352-
353-
There is no pre/post intervention data distinction for the regression
354-
discontinuity design, we fit all the data available.
355-
349+
A class to analyse regression discontinuity experiments.
350+
351+
:param data:
352+
A pandas dataframe
353+
:param formula:
354+
A statistical model formula
355+
:param treatment_threshold:
356+
A scalar threshold value at which the treatment is applied
357+
:param model:
358+
A sci-kit learn model object
359+
:param running_variable_name:
360+
The name of the predictor variable that the treatment threshold is based upon
361+
:param epsilon:
362+
A small scalar value which determines how far above and below the treatment
363+
threshold to evaluate the causal impact.
356364
"""
357365

358366
def __init__(
@@ -362,13 +370,15 @@ def __init__(
362370
treatment_threshold,
363371
model=None,
364372
running_variable_name="x",
373+
epsilon: float = 0.001,
365374
**kwargs,
366375
):
367376
super().__init__(model=model, **kwargs)
368377
self.data = data
369378
self.formula = formula
370379
self.running_variable_name = running_variable_name
371380
self.treatment_threshold = treatment_threshold
381+
self.epsilon = epsilon
372382
y, X = dmatrices(formula, self.data)
373383
self._y_design_info = y.design_info
374384
self._x_design_info = X.design_info
@@ -404,7 +414,10 @@ def __init__(
404414
self.x_discon = pd.DataFrame(
405415
{
406416
self.running_variable_name: np.array(
407-
[self.treatment_threshold - 0.001, self.treatment_threshold + 0.001]
417+
[
418+
self.treatment_threshold - self.epsilon,
419+
self.treatment_threshold + self.epsilon,
420+
]
408421
),
409422
"treated": np.array([0, 1]),
410423
}

0 commit comments

Comments
 (0)