
Commit a316203

Merge branch 'instrumental-variables' into functional_form
2 parents: cd7ae1f + ac10df3

15 files changed (+233 additions, -72 deletions)


causal_testing/json_front/json_class.py

Lines changed: 6 additions & 6 deletions
@@ -208,12 +208,12 @@ def _setup_test(self, causal_test_case: CausalTestCase, estimator: Estimator) ->
         treatment_var = causal_test_case.treatment_variable
         minimal_adjustment_set = minimal_adjustment_set - {treatment_var}
         estimation_model = estimator(
-            (treatment_var.name,),
-            causal_test_case.treatment_value,
-            causal_test_case.control_value,
-            minimal_adjustment_set,
-            (causal_test_case.outcome_variable.name,),
-            causal_test_engine.scenario_execution_data_df,
+            treatment=treatment_var.name,
+            treatment_value=causal_test_case.treatment_value,
+            control_value=causal_test_case.control_value,
+            adjustment_set=minimal_adjustment_set,
+            outcome=causal_test_case.outcome_variable.name,
+            df=causal_test_engine.scenario_execution_data_df,
             effect_modifiers=causal_test_case.effect_modifier_configuration,
         )

causal_testing/specification/causal_dag.py

Lines changed: 30 additions & 0 deletions
@@ -138,6 +138,36 @@ def __init__(self, dot_path: str = None, **attr):
         if not self.is_acyclic():
             raise nx.HasACycle("Invalid Causal DAG: contains a cycle.")

+    def check_iv_assumptions(self, treatment, outcome, instrument) -> bool:
+        """
+        Checks the three instrumental variable assumptions, raising a
+        ValueError if any are violated.
+
+        :return Boolean True if the three IV assumptions hold.
+        """
+        # (i) Instrument is associated with treatment
+        if nx.d_separated(self.graph, {instrument}, {treatment}, set()):
+            raise ValueError(f"Instrument {instrument} is not associated with treatment {treatment} in the DAG")
+
+        # (ii) Instrument does not affect outcome except through its potential effect on treatment
+        if not all([treatment in path for path in nx.all_simple_paths(self.graph, source=instrument, target=outcome)]):
+            raise ValueError(
+                f"Instrument {instrument} affects the outcome {outcome} other than through the treatment {treatment}"
+            )
+
+        # (iii) Instrument and outcome do not share causes
+        if any(
+            [
+                cause
+                for cause in self.graph.nodes
+                if list(nx.all_simple_paths(self.graph, source=cause, target=instrument))
+                and list(nx.all_simple_paths(self.graph, source=cause, target=outcome))
+            ]
+        ):
+            raise ValueError(f"Instrument {instrument} and outcome {outcome} share common causes")
+
+        return True
+
     def add_edge(self, u_of_edge: Node, v_of_edge: Node, **attr):
         """Add an edge to the causal DAG.

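For orientation, a minimal usage sketch of the new check (not part of the commit). It assumes that constructing CausalDAG() without a dot_path yields an empty DAG, as the dot_path: str = None default suggests, and uses a toy graph with instrument Z, treatment X, outcome Y and a confounder U:

from causal_testing.specification.causal_dag import CausalDAG

dag = CausalDAG()          # assumed: no dot_path gives an empty DAG
dag.add_edge("Z", "X")     # (i) the instrument Z is associated with the treatment X
dag.add_edge("X", "Y")
dag.add_edge("U", "X")     # U confounds X and Y but does not reach Z,
dag.add_edge("U", "Y")     # so (ii) and (iii) still hold

assert dag.check_iv_assumptions(treatment="X", outcome="Y", instrument="Z")

dag.add_edge("U", "Z")     # now Z and Y share the common cause U, so the same call
# would raise ValueError("Instrument Z and outcome Y share common causes")
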
causal_testing/testing/causal_test_engine.py

Lines changed: 2 additions & 2 deletions
@@ -89,11 +89,11 @@ def execute_test_suite(self, test_suite: CausalTestSuite) -> list[CausalTestResu
             treatment_value = test.treatment_value
             control_value = test.control_value
             estimator = estimator_class(
-                (treatment_variable.name,),
+                treatment_variable.name,
                 treatment_value,
                 control_value,
                 minimal_adjustment_set,
-                (test.outcome_variable.name,),
+                test.outcome_variable.name,
             )
             if estimator.df is None:
                 estimator.df = self.scenario_execution_data_df

causal_testing/testing/causal_test_result.py

Lines changed: 4 additions & 0 deletions
@@ -83,12 +83,16 @@ def ci_low(self):
         """Return the lower bracket of the confidence intervals."""
         if not self.confidence_intervals:
             return None
+        if any([x is None for x in self.confidence_intervals]):
+            return None
         return min(self.confidence_intervals)

     def ci_high(self):
         """Return the higher bracket of the confidence intervals."""
         if not self.confidence_intervals:
             return None
+        if any([x is None for x in self.confidence_intervals]):
+            return None
         return max(self.confidence_intervals)

     def summary(self):

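The new guard matters because some estimators report missing confidence intervals as (None, None) — the InstrumentalVariableEstimator added below does exactly this — and min()/max() over a tuple containing None raises a TypeError. A standalone illustration of the pattern (not the library's own API):

def ci_low(confidence_intervals):
    """Mirror of the guarded accessor: None when intervals are absent or incomplete."""
    if not confidence_intervals:
        return None
    if any(x is None for x in confidence_intervals):
        return None
    return min(confidence_intervals)

print(ci_low((None, None)))  # None, rather than a TypeError from min()
print(ci_low((0.1, 0.9)))    # 0.1
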
causal_testing/testing/estimators.py

Lines changed: 79 additions & 25 deletions
@@ -135,20 +135,20 @@ def _run_logistic_regression(self, data) -> RegressionResultsWrapper:
         """
         # 1. Reduce dataframe to contain only the necessary columns
         reduced_df = data.copy()
-        necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
+        necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
         missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
         reduced_df = reduced_df[~missing_rows]
-        reduced_df = reduced_df.sort_values(list(self.treatment))
+        reduced_df = reduced_df.sort_values([self.treatment])
         logger.debug(reduced_df[necessary_cols])

         # 2. Add intercept
         reduced_df["Intercept"] = self.intercept

         # 3. Estimate the unit difference in outcome caused by unit difference in treatment
-        cols = list(self.treatment)
+        cols = [self.treatment]
         cols += [x for x in self.adjustment_set if x not in cols]
         treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
-        outcome_col = reduced_df[list(self.outcome)]
+        outcome_col = reduced_df[[self.outcome]]
         for col in treatment_and_adjustments_cols:
             if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
                 treatment_and_adjustments_cols = pd.get_dummies(
@@ -167,7 +167,7 @@ def estimate(self, data: pd.DataFrame) -> RegressionResultsWrapper:
         self.model = model

         x = pd.DataFrame()
-        x[self.treatment[0]] = [self.treatment_value, self.control_value]
+        x[self.treatment] = [self.treatment_value, self.control_value]
         x["Intercept"] = self.intercept
         for k, v in self.effect_modifiers.items():
             x[k] = v
@@ -240,7 +240,7 @@ def estimate_ate(self, bootstrap_size=100) -> float:
         ci_high = bootstraps[bootstrap_size - bound]

         logger.info(
-            f"Changing {self.treatment[0]} from {self.control_value} to {self.treatment_value} gives an estimated "
+            f"Changing {self.treatment} from {self.control_value} to {self.treatment_value} gives an estimated "
             f"ATE of {ci_low} < {estimate} < {ci_high}"
         )
         assert ci_low < estimate < ci_high, f"Expecting {ci_low} < {estimate} < {ci_high}"
@@ -270,7 +270,7 @@ def estimate_risk_ratio(self, bootstrap_size=100) -> float:
         ci_high = bootstraps[bootstrap_size - bound]

         logger.info(
-            f"Changing {self.treatment[0]} from {self.control_value} to {self.treatment_value} gives an estimated "
+            f"Changing {self.treatment} from {self.control_value} to {self.treatment_value} gives an estimated "
             f"risk ratio of {ci_low} < {estimate} < {ci_high}"
         )
         assert ci_low < estimate < ci_high, f"Expecting {ci_low} < {estimate} < {ci_high}"
@@ -284,7 +284,7 @@ def estimate_unit_odds_ratio(self) -> float:
         :return: The odds ratio. Confidence intervals are not yet supported.
         """
         model = self._run_logistic_regression(self.df)
-        return np.exp(model.params[self.treatment[0]])
+        return np.exp(model.params[self.treatment])


 class LinearRegressionEstimator(Estimator):
@@ -385,7 +385,7 @@ def estimate_unit_ate(self) -> float:
         :return: The unit average treatment effect and the 95% Wald confidence intervals.
         """
         model = self._run_linear_regression()
-        unit_effect = model.params[list(self.treatment)].values[0]  # Unit effect is the coefficient of the treatment
+        unit_effect = model.params[[self.treatment]].values[0]  # Unit effect is the coefficient of the treatment
         [ci_low, ci_high] = self._get_confidence_intervals(model)

         return unit_effect * self.treatment_value - unit_effect * self.control_value, [ci_low, ci_high]
@@ -409,8 +409,8 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:

         # It is ABSOLUTELY CRITICAL that these go last, otherwise we can't index
         # the effect with "ate = t_test_results.effect[0]"
-        individuals.loc["control", list(self.treatment)] = self.control_value
-        individuals.loc["treated", list(self.treatment)] = self.treatment_value
+        individuals.loc["control", [self.treatment]] = self.control_value
+        individuals.loc["treated", [self.treatment]] = self.treatment_value

         # Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
         t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
@@ -431,7 +431,7 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd
         self.model = model

         x = pd.DataFrame()
-        x[self.treatment[0]] = [self.treatment_value, self.control_value]
+        x[self.treatment] = [self.treatment_value, self.control_value]
         x["Intercept"] = 1#self.intercept
         for k, v in adjustment_config.items():
             x[k] = v
@@ -487,7 +487,7 @@ def estimate_cates(self) -> tuple[float, list[float, float]]:
             self.effect_modifiers
         ), f"Must have at least one effect modifier to compute CATE - {self.effect_modifiers}."
         x = pd.DataFrame()
-        x[self.treatment[0]] = [self.treatment_value, self.control_value]
+        x[self.treatment] = [self.treatment_value, self.control_value]
         x["Intercept"] = 1#self.intercept
         for k, v in self.effect_modifiers.items():
             self.adjustment_set.add(k)
@@ -513,20 +513,20 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
         """
         # 1. Reduce dataframe to contain only the necessary columns
         reduced_df = self.df.copy()
-        necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
+        necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
         missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
         reduced_df = reduced_df[~missing_rows]
-        reduced_df = reduced_df.sort_values(list(self.treatment))
+        reduced_df = reduced_df.sort_values([self.treatment])
         logger.debug(reduced_df[necessary_cols])

         # 2. Add intercept
         reduced_df["Intercept"] = 1#self.intercept

         # 3. Estimate the unit difference in outcome caused by unit difference in treatment
-        cols = list(self.treatment)
+        cols = [self.treatment]
         cols += [x for x in self.adjustment_set if x not in cols]
         treatment_and_adjustments_cols = reduced_df[cols + ["Intercept"]]
-        outcome_col = reduced_df[list(self.outcome)]
+        outcome_col = reduced_df[[self.outcome]]
         for col in treatment_and_adjustments_cols:
             if str(treatment_and_adjustments_cols.dtypes[col]) == "object":
                 treatment_and_adjustments_cols = pd.get_dummies(
@@ -539,12 +539,66 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
     def _get_confidence_intervals(self, model):
         confidence_intervals = model.conf_int(alpha=0.05, cols=None)
         ci_low, ci_high = (
-            confidence_intervals[0][list(self.treatment)],
-            confidence_intervals[1][list(self.treatment)],
+            confidence_intervals[0][[self.treatment]],
+            confidence_intervals[1][[self.treatment]],
         )
         return [ci_low.values[0], ci_high.values[0]]


+class InstrumentalVariableEstimator(Estimator):
+    """
+    Carry out estimation using instrumental variable adjustment rather than conventional adjustment. This means we do
+    not need to observe all confounders in order to adjust for them. A key assumption here is linearity.
+    """
+
+    def __init__(
+        self,
+        treatment: str,
+        treatment_value: float,
+        control_value: float,
+        adjustment_set: set,
+        outcome: str,
+        instrument: str,
+        df: pd.DataFrame = None,
+        intercept: int = 1,
+        effect_modifiers: dict = None  # Not used (yet?). Needed for compatibility
+    ):
+        super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
+        self.intercept = intercept
+        self.model = None
+        self.instrument = instrument
+
+    def add_modelling_assumptions(self):
+        """
+        Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
+        must hold if the resulting causal inference is to be considered valid.
+        """
+        self.modelling_assumptions += """The instrument and the treatment, and the treatment and the outcome must be
+        related linearly in the form Y = aX + b."""
+        self.modelling_assumptions += """The three IV conditions must hold
+            (i) Instrument is associated with treatment
+            (ii) Instrument does not affect outcome except through its potential effect on treatment
+            (iii) Instrument and outcome do not share causes
+        """
+
+    def estimate_coefficient(self):
+        """
+        Estimate the linear regression coefficient of the treatment on the outcome.
+        """
+        # Estimate the total effect of instrument I on outcome Y = abI + c1
+        ab = sm.OLS(self.df[self.outcome], self.df[[self.instrument]]).fit().params[self.instrument]
+
+        # Estimate the direct effect of instrument I on treatment X = aI + c1
+        a = sm.OLS(self.df[self.treatment], self.df[[self.instrument]]).fit().params[self.instrument]
+
+        # Estimate the coefficient of the treatment on the outcome by cancelling: (ab) / a
+        return ab / a
+
+    def estimate_ate(self):
+        return (self.treatment_value - self.control_value) * self.estimate_coefficient(), (None, None)
+
+
 class CausalForestEstimator(Estimator):
     """A causal random forest estimator is a non-parametric estimator which recursively partitions the covariate space
     to learn a low-dimensional representation of treatment effect heterogeneity. This form of estimator is best suited
@@ -566,7 +620,7 @@ def estimate_ate(self) -> float:
         """
         # Remove any NA containing rows
         reduced_df = self.df.copy()
-        necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
+        necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
         missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
         reduced_df = reduced_df[~missing_rows]

@@ -577,8 +631,8 @@ def estimate_ate(self) -> float:
         else:
             effect_modifier_df = reduced_df[list(self.adjustment_set)]
         confounders_df = reduced_df[list(self.adjustment_set)]
-        treatment_df = np.ravel(reduced_df[list(self.treatment)])
-        outcome_df = np.ravel(reduced_df[list(self.outcome)])
+        treatment_df = np.ravel(reduced_df[[self.treatment]])
+        outcome_df = np.ravel(reduced_df[[self.outcome]])

         # Fit the model to the data using a gradient boosting regressor for both the treatment and outcome model
         model = CausalForestDML(
@@ -606,7 +660,7 @@ def estimate_cates(self) -> pd.DataFrame:

         # Remove any NA containing rows
         reduced_df = self.df.copy()
-        necessary_cols = list(self.treatment) + list(self.adjustment_set) + list(self.outcome)
+        necessary_cols = [self.treatment] + list(self.adjustment_set) + [self.outcome]
         missing_rows = reduced_df[necessary_cols].isnull().any(axis=1)
         reduced_df = reduced_df[~missing_rows]

@@ -620,8 +674,8 @@ def estimate_cates(self) -> pd.DataFrame:
             confounders_df = reduced_df[list(self.adjustment_set)]
         else:
             confounders_df = None
-        treatment_df = reduced_df[list(self.treatment)]
-        outcome_df = reduced_df[list(self.outcome)]
+        treatment_df = reduced_df[[self.treatment]]
+        outcome_df = reduced_df[[self.outcome]]

         # Fit a model to the data
         model = CausalForestDML(model_y=GradientBoostingRegressor(), model_t=GradientBoostingRegressor())

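The new estimator is the classic ratio (Wald) estimator: the instrument-to-outcome coefficient ab divided by the instrument-to-treatment coefficient a recovers the coefficient of the treatment on the outcome, even when the confounders of treatment and outcome are unobserved. A usage sketch on synthetic data (illustrative only; the column names, data-generating process, and the expected value of roughly 2 are assumptions, while the constructor arguments match the signature added above):

import numpy as np
import pandas as pd
from causal_testing.testing.estimators import InstrumentalVariableEstimator

# Synthetic data (assumed for illustration): Z -> X -> Y with an unobserved confounder U.
rng = np.random.default_rng(0)
n = 1000
z = rng.normal(size=n)                   # instrument
u = rng.normal(size=n)                   # unobserved confounder of X and Y
x = 0.5 * z + u + rng.normal(size=n)     # treatment
y = 2.0 * x + u + rng.normal(size=n)     # outcome; true unit effect of X on Y is 2
df = pd.DataFrame({"Z": z, "X": x, "Y": y})

iv_estimator = InstrumentalVariableEstimator(
    treatment="X",
    treatment_value=1,
    control_value=0,
    adjustment_set=set(),
    outcome="Y",
    instrument="Z",
    df=df,
)
ate, (ci_low, ci_high) = iv_estimator.estimate_ate()
print(ate)               # close to 2.0, despite the unobserved confounding
print(ci_low, ci_high)   # (None, None) -- confidence intervals are not computed yet
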
examples/.gitignore

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+*.pdf

examples/covasim_/doubling_beta/causal_test_beta.py

Lines changed: 6 additions & 6 deletions
@@ -47,19 +47,19 @@ def doubling_beta_CATE_on_csv(observational_data_path: str, simulate_counterfact
     past_execution_df = pd.read_csv(observational_data_path)
     _, causal_test_engine, causal_test_case = engine_setup(observational_data_path)

-    linear_regression_estimator = LinearRegressionEstimator(('beta',), 0.032, 0.016,
+    linear_regression_estimator = LinearRegressionEstimator('beta', 0.032, 0.016,
                                                             {'avg_age', 'contacts'},  # We use custom adjustment set
-                                                            ('cum_infections',),
+                                                            'cum_infections',
                                                             df=past_execution_df)

     # Add squared terms for beta, since it has a quadratic relationship with cumulative infections
     linear_regression_estimator.add_squared_term_to_df('beta')
     causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')

     # Repeat for association estimate (no adjustment)
-    no_adjustment_linear_regression_estimator = LinearRegressionEstimator(('beta',), 0.032, 0.016,
+    no_adjustment_linear_regression_estimator = LinearRegressionEstimator('beta', 0.032, 0.016,
                                                                           set(),
-                                                                          ('cum_infections',),
+                                                                          'cum_infections',
                                                                           df=past_execution_df)
     no_adjustment_linear_regression_estimator.add_squared_term_to_df('beta')
     association_test_result = causal_test_engine.execute_test(no_adjustment_linear_regression_estimator, causal_test_case, 'ate')
@@ -79,9 +79,9 @@ def doubling_beta_CATE_on_csv(observational_data_path: str, simulate_counterfact
     # Repeat causal inference after deleting all rows with treatment value to obtain counterfactual inferences
     if simulate_counterfactuals:
         counterfactual_past_execution_df = past_execution_df[past_execution_df['beta'] != 0.032]
-        counterfactual_linear_regression_estimator = LinearRegressionEstimator(('beta',), 0.032, 0.016,
+        counterfactual_linear_regression_estimator = LinearRegressionEstimator('beta', 0.032, 0.016,
                                                                                {'avg_age', 'contacts'},
-                                                                               ('cum_infections',),
+                                                                               'cum_infections',
                                                                                df=counterfactual_past_execution_df)
         counterfactual_linear_regression_estimator.add_squared_term_to_df('beta')
         counterfactual_causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')

examples/covasim_/vaccinating_elderly/causal_test_vaccine.py

Lines changed: 2 additions & 2 deletions
@@ -85,9 +85,9 @@ def experimental_causal_test_vaccinate_elderly(runs_per_test_per_config: int = 3
     minimal_adjustment_set = causal_dag.identification(base_test_case)

     # 9. Build statistical model
-    linear_regression_estimator = LinearRegressionEstimator((vaccine.name,), 1, 0,
+    linear_regression_estimator = LinearRegressionEstimator(vaccine.name, 1, 0,
                                                             minimal_adjustment_set,
-                                                            (outcome_variable.name,))
+                                                            outcome_variable.name)

     # 10. Execute test and save results in dict
     causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, causal_test_case, 'ate')
