Skip to content

Commit db004c5

Browse files
committed
Instrumental variable estimation
1 parent 1e030d8 commit db004c5

File tree

2 files changed

+82
-0
lines changed

2 files changed

+82
-0
lines changed

causal_testing/testing/estimators.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,59 @@ def _get_confidence_intervals(self, model):
543543
return [ci_low.values[0], ci_high.values[0]]
544544

545545

546+
class InstrumentalVariableEstimator(Estimator):
    """
    Carry out estimation using instrumental variable adjustment rather than conventional adjustment. This means we do
    not need to observe all confounders in order to adjust for them. A key assumption here is linearity.
    """

    def __init__(
        self,
        treatment: tuple,
        treatment_value: float,
        control_value: float,
        adjustment_set: set,
        outcome: tuple,
        instrument: str,
        df: pd.DataFrame = None,
        intercept: int = 1,
    ):
        """
        Set up the estimator.

        :param treatment: Label of the treatment column; used to index the data frame during estimation.
        :param treatment_value: Value of the treatment in the treated scenario.
        :param control_value: Value of the treatment in the control scenario.
        :param adjustment_set: Forwarded unchanged to the parent estimator.
        :param outcome: Label of the outcome column; used to index the data frame during estimation.
        :param instrument: Label of the instrument column in the data frame.
        :param df: Data frame holding the instrument, treatment, and outcome columns.
        :param intercept: Stored on the instance. NOTE(review): not currently read by any method of this class.
        """
        # NOTE(review): `treatment` and `outcome` are annotated as tuples but are used below as single column
        # labels (e.g. "X", "Y") -- confirm the intended type against the parent class.
        super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
        self.intercept = intercept
        self.model = None
        self.instrument = instrument

    def add_modelling_assumptions(self):
        """
        Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
        must hold if the resulting causal inference is to be considered valid.
        """
        # Use append rather than `+=`: in-place addition of a str to a list extends the list one character at a
        # time instead of adding the assumption as a single entry.
        self.modelling_assumptions.append(
            """The instrument and the treatment, and the treatment and the outcome must be
        related linearly in the form Y = aX + b."""
        )
        self.modelling_assumptions.append(
            """The three IV conditions must hold
            (i) Instrument is associated with treatment
            (ii) Instrument does not affect outcome except through its potential effect on treatment
            (iii) Instrument and outcome do not share causes
        """
        )

    def estimate_coefficient(self) -> float:
        """
        Estimate the linear regression coefficient of the treatment on the outcome.

        The coefficient is the ratio of two ordinary least squares slopes: the slope of the outcome on the
        instrument divided by the slope of the treatment on the instrument.

        :return: The estimated coefficient of the treatment on the outcome.
        """
        # NOTE(review): both regressions use the instrument as the only regressor, i.e. no constant column is
        # added and `self.intercept` is not consulted, although the comments below describe models with a
        # constant term. Confirm whether fitting through the origin is intended.

        # Estimate the total effect of instrument I on outcome Y = abI + c1
        ab = sm.OLS(self.df[self.outcome], self.df[[self.instrument]]).fit().params[self.instrument]

        # Estimate the direct effect of instrument I on treatment X = aI + c1
        a = sm.OLS(self.df[self.treatment], self.df[[self.instrument]]).fit().params[self.instrument]

        # Estimate the coefficient of X on Y by cancelling the shared factor a
        return ab / a

    def estimate_ate(self) -> tuple:
        """
        Estimate the effect of moving the treatment from the control value to the treatment value.

        The estimated coefficient is scaled by the difference between the treatment and control values.

        :return: A pair whose first element is the scaled estimate and whose second element is the placeholder
                 ``(None, None)``; no interval is computed by this method.
        """
        return (self.treatment_value - self.control_value) * self.estimate_coefficient(), (None, None)
597+
598+
546599
class CausalForestEstimator(Estimator):
547600
"""A causal random forest estimator is a non-parametric estimator which recursively partitions the covariate space
548601
to learn a low-dimensional representation of treatment effect heterogeneity. This form of estimator is best suited

tests/testing_tests/test_estimators.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
LinearRegressionEstimator,
77
CausalForestEstimator,
88
LogisticRegressionEstimator,
9+
InstrumentalVariableEstimator,
910
)
1011
from causal_testing.specification.variable import Input
1112

@@ -110,6 +111,34 @@ def test_odds_ratio(self):
110111
self.assertEqual(round(odds, 4), 0.8948)
111112

112113

114+
class TestInstrumentalVariableEstimator(unittest.TestCase):
    """
    Test the instrumental variable estimator.
    """

    @classmethod
    def setUpClass(cls) -> None:
        # Noise-free linear chain Z -> X -> Y with X = 2Z and Y = 2X, so the coefficient of X on Y is 2.
        Z = np.linspace(0, 10)
        X = 2 * Z
        Y = 2 * X
        cls.df = pd.DataFrame({"Z": Z, "X": X, "Y": Y})

    def test_estimate_coefficient(self):
        """
        Test we get the correct coefficient.
        """
        iv_estimator = InstrumentalVariableEstimator(
            treatment="X",
            treatment_value=None,
            control_value=None,
            adjustment_set=set(),
            outcome="Y",
            instrument="Z",
            df=self.df,
        )
        # The estimate is a ratio of two fitted floating-point slopes, so compare with a tolerance rather than
        # requiring bit-exact equality.
        self.assertAlmostEqual(iv_estimator.estimate_coefficient(), 2)
140+
141+
113142
class TestLinearRegressionEstimator(unittest.TestCase):
114143
"""Test the linear regression estimator against the programming exercises in Section 2 of Hernán and Robins [1].
115144

0 commit comments

Comments
 (0)