Skip to content

Commit 65765cb

Browse files
authored
Merge pull request #145 from CITCOM-project/instrumental-variables
Basic instrumental variables support
2 parents be1c55e + c31b021 commit 65765cb

File tree

16 files changed

+292
-89
lines changed

16 files changed

+292
-89
lines changed

causal_testing/json_front/json_class.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,12 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
208208
treatment_var = causal_test_case.treatment_variable
209209
minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
210210
estimation_model = estimator(
211-
(treatment_var.name,),
212-
causal_test_case.treatment_value,
213-
causal_test_case.control_value,
214-
minimal_adjustment_set,
215-
(causal_test_case.outcome_variable.name,),
216-
causal_test_engine.scenario_execution_data_df,
211+
treatment=treatment_var.name,
212+
treatment_value=causal_test_case.treatment_value,
213+
control_value=causal_test_case.control_value,
214+
adjustment_set=minimal_adjustment_set,
215+
outcome=causal_test_case.outcome_variable.name,
216+
df=causal_test_engine.scenario_execution_data_df,
217217
effect_modifiers=causal_test_case.effect_modifier_configuration,
218218
)
219219

causal_testing/specification/causal_dag.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,36 @@ def __init__(self, dot_path: str = None, **attr):
138138
if not self.is_acyclic():
139139
raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")
140140

141+
def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
142+
"""
143+
Checks the three instrumental variable assumptions, raising a
144+
ValueError if any are violated.
145+
146+
:return: True if all three IV assumptions hold.
147+
"""
148+
# (i) Instrument is associated with treatment
149+
if nx.d_separated(self.graph, {instrument}, {treatment}, set()):
150+
raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG")
151+
152+
# (ii) Instrument does not affect outcome except through its potential effect on treatment
153+
if not all([treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome)]):
154+
raise ValueError(
155+
f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}"
156+
)
157+
158+
# (iii) Instrument and outcome do not share causes
159+
if any(
160+
[
161+
cause
162+
for cause in self.graph.nodes
163+
if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
164+
and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
165+
]
166+
):
167+
raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
168+
169+
return True
170+
141171
def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
142172
"""Add an edge to the causal DAG.
143173

causal_testing/testing/causal_test_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,11 +89,11 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
8989
treatment_value = test.treatment_value
9090
control_value = test.control_value
9191
estimator = estimator_class(
92-
(treatment_variable.name,),
92+
treatment_variable.name,
9393
treatment_value,
9494
control_value,
9595
minimal_adjustment_set,
96-
(test.outcome_variable.name,),
96+
test.outcome_variable.name,
9797
)
9898
if estimator.df is None:
9999
estimator.df = self.scenario_execution_data_df

causal_testing/testing/causal_test_result.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,22 @@ def to_dict(self):
7474
"adjustment_set": self.adjustment_set,
7575
"test_value": self.test_value,
7676
}
77-
if self.confidence_intervals:
77+
if self.confidence_intervals and all(self.confidence_intervals):
7878
base_dict["ci_low"] = min(self.confidence_intervals)
7979
base_dict["ci_high"] = max(self.confidence_intervals)
8080
return base_dict
8181

8282
def ci_low(self):
8383
"""Return the lower bracket of the confidence intervals."""
84-
if not self.confidence_intervals:
85-
return None
86-
return min(self.confidence_intervals)
84+
if self.confidence_intervals and all(self.confidence_intervals):
85+
return min(self.confidence_intervals)
86+
return None
8787

8888
def ci_high(self):
8989
"""Return the higher bracket of the confidence intervals."""
90-
if not self.confidence_intervals:
91-
return None
92-
return max(self.confidence_intervals)
90+
if self.confidence_intervals and all(self.confidence_intervals):
91+
return max(self.confidence_intervals)
92+
return None
9393

9494
def summary(self):
9595
"""Summarise the causal test result as an intuitive sentence."""

causal_testing/testing/estimators.py

Lines changed: 84 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ class Estimator(ABC):
3636

3737
def __init__(
3838
self,
39-
treatment: tuple,
39+
treatment: str,
4040
treatment_value: float,
4141
control_value: float,
4242
adjustment_set: set,
43-
outcome: tuple,
43+
outcome: str,
4444
df: pd.DataFrame = None,
4545
effect_modifiers: dict[Variable:Any] = None,
4646
):
@@ -93,11 +93,11 @@ class LogisticRegressionEstimator(Estimator):
9393

9494
def __init__(
9595
self,
96-
treatment: tuple,
96+
treatment: str,
9797
treatment_value: float,
9898
control_value: float,
9999
adjustment_set: set,
100-
outcome: tuple,
100+
outcome: str,
101101
df: pd.DataFrame = None,
102102
effect_modifiers: dict[Variable:Any] = None,
103103
intercept: int = 1,
@@ -133,20 +133,20 @@ def _run_logistic_regression(self, data) -> RegressionResultsWrapper:
133133
"""
134134
# 1. Reduce dataframe to contain only the necessary columns
135135
reduced_df = data.copy()
136-
necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
136+
necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
137137
missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
138138
reduced_df = reduced_df[~missing_rows]
139-
reduced_df = reduced_df.sort_values(list(self.treatment))
139+
reduced_df = reduced_df.sort_values([self.treatment])
140140
logger.debug(reduced_df[necessary_cols])
141141

142142
# 2. Add intercept
143143
reduced_df["Intercept"] = self.intercept
144144

145145
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
146-
cols = list(self.treatment)
146+
cols = [self.treatment]
147147
cols += [x for x in self.adjustment_set if x not in cols]
148148
treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
149-
outcome_col = reduced_df[list(self.outcome)]
149+
outcome_col = reduced_df[[self.outcome]]
150150
for col in treatment_and_adjustments_cols:
151151
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
152152
treatment_and_adjustments_cols = pd.get_dummies(
@@ -165,7 +165,7 @@ def estimate(self, data: pd.DataFrame) -> RegressionResultsWrapper:
165165
self.model = model
166166

167167
x = pd.DataFrame()
168-
x[self.treatment[0]] = [self.treatment_value, self.control_value]
168+
x[self.treatment] = [self.treatment_value, self.control_value]
169169
x["Intercept"] = self.intercept
170170
for k, v in self.effect_modifiers.items():
171171
x[k] = v
@@ -238,7 +238,7 @@ def estimate_ate(self, bootstrap_size=100) -> float:
238238
ci_high = bootstraps[bootstrap_size - bound]
239239

240240
logger.info(
241-
f"Changing {self.treatment[0]} from {self.control_value} to {self.treatment_value} gives an estimated "
241+
f"Changing {self.treatment} from {self.control_value} to {self.treatment_value} gives an estimated "
242242
f"ATE of {ci_low} < {estimate} < {ci_high}"
243243
)
244244
assert ci_low < estimate < ci_high, f"Expecting {ci_low} < {estimate} < {ci_high}"
@@ -268,7 +268,7 @@ def estimate_risk_ratio(self, bootstrap_size=100) -> float:
268268
ci_high = bootstraps[bootstrap_size - bound]
269269

270270
logger.info(
271-
f"Changing {self.treatment[0]} from {self.control_value} to {self.treatment_value} gives an estimated "
271+
f"Changing {self.treatment} from {self.control_value} to {self.treatment_value} gives an estimated "
272272
f"risk ratio of {ci_low} < {estimate} < {ci_high}"
273273
)
274274
assert ci_low < estimate < ci_high, f"Expecting {ci_low} < {estimate} < {ci_high}"
@@ -282,7 +282,7 @@ def estimate_unit_odds_ratio(self) -> float:
282282
:return: The odds ratio. Confidence intervals are not yet supported.
283283
"""
284284
model = self._run_logistic_regression(self.df)
285-
return np.exp(model.params[self.treatment[0]])
285+
return np.exp(model.params[self.treatment])
286286

287287

288288
class LinearRegressionEstimator(Estimator):
@@ -292,11 +292,11 @@ class LinearRegressionEstimator(Estimator):
292292

293293
def __init__(
294294
self,
295-
treatment: tuple,
295+
treatment: str,
296296
treatment_value: float,
297297
control_value: float,
298298
adjustment_set: set,
299-
outcome: tuple,
299+
outcome: str,
300300
df: pd.DataFrame = None,
301301
effect_modifiers: dict[Variable:Any] = None,
302302
product_terms: list[tuple[Variable, Variable]] = None,
@@ -382,7 +382,7 @@ def estimate_unit_ate(self) -> float:
382382
:return: The unit average treatment effect and the 95% Wald confidence intervals.
383383
"""
384384
model = self._run_linear_regression()
385-
unit_effect = model.params[list(self.treatment)].values[0] # Unit effect is the coefficient of the treatment
385+
unit_effect = model.params[[self.treatment]].values[0] # Unit effect is the coefficient of the treatment
386386
[ci_low, ci_high] = self._get_confidence_intervals(model)
387387

388388
return unit_effect * self.treatment_value - unit_effect * self.control_value, [ci_low, ci_high]
@@ -406,8 +406,8 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
406406

407407
# It is ABSOLUTELY CRITICAL that these go last, otherwise we can't index
408408
# the effect with "ate = t_test_results.effect[0]"
409-
individuals.loc["control", list(self.treatment)] = self.control_value
410-
individuals.loc["treated", list(self.treatment)] = self.treatment_value
409+
individuals.loc["control", [self.treatment]] = self.control_value
410+
individuals.loc["treated", [self.treatment]] = self.treatment_value
411411

412412
# Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
413413
t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
@@ -428,7 +428,7 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
428428
self.model = model
429429

430430
x = pd.DataFrame()
431-
x[self.treatment[0]] = [self.treatment_value, self.control_value]
431+
x[self.treatment] = [self.treatment_value, self.control_value]
432432
x["Intercept"] = self.intercept
433433
for k, v in adjustment_config.items():
434434
x[k] = v
@@ -484,7 +484,7 @@ def estimate_cates(self) -> tuple[float, list[float, float]]:
484484
self.effect_modifiers
485485
), f"Must have at least one effect modifier to compute CATE - {self.effect_modifiers}."
486486
x = pd.DataFrame()
487-
x[self.treatment[0]] = [self.treatment_value, self.control_value]
487+
x[self.treatment] = [self.treatment_value, self.control_value]
488488
x["Intercept"] = self.intercept
489489
for k, v in self.effect_modifiers.items():
490490
self.adjustment_set.add(k)
@@ -510,20 +510,20 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
510510
"""
511511
# 1. Reduce dataframe to contain only the necessary columns
512512
reduced_df = self.df.copy()
513-
necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
513+
necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
514514
missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
515515
reduced_df = reduced_df[~missing_rows]
516-
reduced_df = reduced_df.sort_values(list(self.treatment))
516+
reduced_df = reduced_df.sort_values([self.treatment])
517517
logger.debug(reduced_df[necessary_cols])
518518

519519
# 2. Add intercept
520520
reduced_df["Intercept"] = self.intercept
521521

522522
# 3. Estimate the unit difference in outcome caused by unit difference in treatment
523-
cols = list(self.treatment)
523+
cols = [self.treatment]
524524
cols += [x for x in self.adjustment_set if x not in cols]
525525
treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
526-
outcome_col = reduced_df[list(self.outcome)]
526+
outcome_col = reduced_df[[self.outcome]]
527527
for col in treatment_and_adjustments_cols:
528528
if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
529529
treatment_and_adjustments_cols = pd.get_dummies(
@@ -536,12 +536,65 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
536536
def _get_confidence_intervals(self, model):
537537
confidence_intervals = model.conf_int(alpha=0.05, cols=None)
538538
ci_low, ci_high = (
539-
confidence_intervals[0][list(self.treatment)],
540-
confidence_intervals[1][list(self.treatment)],
539+
confidence_intervals[0][[self.treatment]],
540+
confidence_intervals[1][[self.treatment]],
541541
)
542542
return [ci_low.values[0], ci_high.values[0]]
543543

544544

545+
class InstrumentalVariableEstimator(Estimator):
546+
"""
547+
Carry out estimation using instrumental variable adjustment rather than conventional adjustment. This means we do
548+
not need to observe all confounders in order to adjust for them. A key assumption here is linearity.
549+
"""
550+
551+
def __init__(
552+
self,
553+
treatment: str,
554+
treatment_value: float,
555+
control_value: float,
556+
adjustment_set: set,
557+
outcome: str,
558+
instrument: str,
559+
df: pd.DataFrame = None,
560+
intercept: int = 1,
561+
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
562+
):
563+
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
564+
self.intercept = intercept
565+
self.model = None
566+
self.instrument = instrument
567+
568+
def add_modelling_assumptions(self):
569+
"""
570+
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
571+
must hold if the resulting causal inference is to be considered valid.
572+
"""
573+
self.modelling_assumptions += """The instrument and the treatment, and the treatment and the outcome must be
574+
related linearly in the form Y = aX + b."""
575+
self.modelling_assumptions += """The three IV conditions must hold
576+
(i) Instrument is associated with treatment
577+
(ii) Instrument does not affect outcome except through its potential effect on treatment
578+
(iii) Instrument and outcome do not share causes
579+
"""
580+
581+
def estimate_coefficient(self):
582+
"""
583+
Estimate the linear regression coefficient of the treatment on the outcome.
584+
"""
585+
# Estimate the total effect of instrument I on outcome Y = abI + c1
586+
ab = sm.OLS(self.df[self.outcome], self.df[[self.instrument]]).fit().params[self.instrument]
587+
588+
# Estimate the direct effect of instrument I on treatment X = aI + c2
589+
a = sm.OLS(self.df[self.treatment], self.df[[self.instrument]]).fit().params[self.instrument]
590+
591+
# Estimate the coefficient of treatment X on outcome Y by cancelling out a, the effect of I on X
592+
return ab / a
593+
594+
def estimate_ate(self):
595+
return (self.treatment_value - self.control_value) * self.estimate_coefficient(), (None, None)
596+
597+
545598
class CausalForestEstimator(Estimator):
546599
"""A causal random forest estimator is a non-parametric estimator which recursively partitions the covariate space
547600
to learn a low-dimensional representation of treatment effect heterogeneity. This form of estimator is best suited
@@ -563,7 +616,7 @@ def estimate_ate(self) -> float:
563616
"""
564617
# Remove any NA containing rows
565618
reduced_df = self.df.copy()
566-
necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
619+
necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
567620
missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
568621
reduced_df = reduced_df[~missing_rows]
569622

@@ -574,8 +627,8 @@ def estimate_ate(self) -> float:
574627
else:
575628
effect_modifier_df = reduced_df[list(self.adjustment_set)]
576629
confounders_df = reduced_df[list(self.adjustment_set)]
577-
treatment_df = np.ravel(reduced_df[list(self.treatment)])
578-
outcome_df = np.ravel(reduced_df[list(self.outcome)])
630+
treatment_df = np.ravel(reduced_df[[self.treatment]])
631+
outcome_df = np.ravel(reduced_df[[self.outcome]])
579632

580633
# Fit the model to the data using a gradient boosting regressor for both the treatment and outcome model
581634
model = CausalForestDML(
@@ -603,7 +656,7 @@ def estimate_cates(self) -> pd.DataFrame:
603656

604657
# Remove any NA containing rows
605658
reduced_df = self.df.copy()
606-
necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
659+
necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
607660
missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
608661
reduced_df = reduced_df[~missing_rows]
609662

@@ -617,8 +670,8 @@ def estimate_cates(self) -> pd.DataFrame:
617670
confounders_df = reduced_df[list(self.adjustment_set)]
618671
else:
619672
confounders_df = None
620-
treatment_df = reduced_df[list(self.treatment)]
621-
outcome_df = reduced_df[list(self.outcome)]
673+
treatment_df = reduced_df[[self.treatment]]
674+
outcome_df = reduced_df[[self.outcome]]
622675

623676
# Fit a model to the data
624677
model = CausalForestDML(model_y=GradientBoostingRegressor(), model_t=GradientBoostingRegressor())

examples/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
*.pdf

0 commit comments

Comments
 (0)