Skip to content

Commit 06f72c4

Browse files
author
Andrew Clark
authored
Merge pull request #57 from CITCOM-project/lr91_example
Merge case study progress as examples
2 parents 4a45aa0 + 877c3ef commit 06f72c4

File tree

11 files changed

+10943
-60
lines changed

11 files changed

+10943
-60
lines changed

causal_testing/testing/causal_test_engine.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,9 @@ def execute_test(self, estimator: Estimator, estimate_type: str = 'ate') -> Caus
9292
"""
9393
if self.scenario_execution_data_df.empty:
9494
raise Exception('No data has been loaded. Please call load_data prior to executing a causal test case.')
95+
if estimator.df is None:
96+
estimator.df = self.scenario_execution_data_df
9597
treatments = [v.name for v in self.treatment_variables]
96-
print("treatment_variables", self.treatment_variables)
9798
outcomes = [v.name for v in self.causal_test_case.outcome_variables]
9899
minimal_adjustment_sets = self.casual_dag.enumerate_minimal_adjustment_sets(treatments, outcomes)
99100
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
@@ -178,8 +179,6 @@ def _check_positivity_violation(self, variables_list):
178179
:return: True if positivity is violated, False otherwise.
179180
"""
180181
# TODO: Improve positivity checks to look for stratum-specific violations, not just missing variables in df
181-
print("variables_list", variables_list)
182-
print("columns", self.scenario_execution_data_df.columns)
183182
if not set(variables_list).issubset(self.scenario_execution_data_df.columns):
184183
missing_variables = set(variables_list) - set(self.scenario_execution_data_df.columns)
185184
logger.warning(f'Positivity violation: missing data for variables {missing_variables}.\n'

causal_testing/testing/estimators.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,14 @@ def __init__(self, treatment: tuple, treatment_values: float, control_values: fl
3737
self.adjustment_set = adjustment_set
3838
self.outcome = outcome
3939
self.df = df
40-
self.effect_modifiers = {k.name: v for k, v in effect_modifiers.items()} if effect_modifiers else dict()
40+
if effect_modifiers is None:
41+
self.effect_modifiers = dict()
42+
elif isinstance(effect_modifiers, set) or isinstance(effect_modifiers, list):
43+
self.effect_modifiers = {k.name for k in effect_modifiers}
44+
elif isinstance(effect_modifiers, dict):
45+
self.effect_modifiers = {k.name: v for k, v in effect_modifiers.items()}
46+
else:
47+
raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected iterable")
4148
self.modelling_assumptions = []
4249
logger.debug("Effect Modifiers: %s", self.effect_modifiers)
4350

@@ -140,9 +147,14 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
140147
"""
141148
model = self._run_linear_regression()
142149
# Create an empty individual for the control and treated
143-
individuals = pd.DataFrame(0, index=['control', 'treated'], columns=model.params.index)
150+
individuals = pd.DataFrame(1, index=['control', 'treated'], columns=model.params.index)
144151
individuals.loc['control', list(self.treatment)] = self.control_values
145152
individuals.loc['treated', list(self.treatment)] = self.treatment_values
153+
# This is a temporary hack
154+
for t in self.square_terms:
155+
individuals[t+'^2'] = individuals[t] ** 2
156+
for a, b in self.product_terms:
157+
individuals[f"{a}*{b}"] = individuals[a] * individuals[b]
146158

147159
# Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
148160
t_test_results = model.t_test(individuals.loc['treated'] - individuals.loc['control'])

examples/covasim_/causal_test_beta.py

Lines changed: 332 additions & 51 deletions
Large diffs are not rendered by default.

examples/covasim_/dag.dot

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,9 @@ digraph CausalDAG {
22
rankdir=LR;
33
"variants" -> "beta";
44
"beta" -> "cum_infections";
5-
"beta" -> "cum_deaths";
6-
"beta" -> "cum_recoveries";
75
"location" -> "variants";
86
"location" -> "avg_age";
7+
"location" -> "contacts";
8+
"contacts" -> "cum_infections";
99
"avg_age" -> "cum_infections";
10-
"avg_age" -> "cum_deaths";
11-
"avg_age" -> "cum_recoveries";
1210
}

examples/covasim_/dag.png

-11.6 KB
Loading

examples/covasim_/data/10k_observational_data.csv

Lines changed: 10001 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import pandas as pd
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
from causal_testing.specification.causal_dag import CausalDAG
5+
from causal_testing.specification.scenario import Scenario
6+
from causal_testing.specification.variable import Input, Output
7+
from causal_testing.specification.causal_specification import CausalSpecification
8+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
9+
from causal_testing.testing.causal_test_case import CausalTestCase
10+
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
11+
from causal_testing.testing.intervention import Intervention
12+
from causal_testing.testing.causal_test_engine import CausalTestEngine
13+
from causal_testing.testing.estimators import LinearRegressionEstimator
14+
from matplotlib.pyplot import rcParams
15+
OBSERVATIONAL_DATA_PATH = "./data/normalised_results.csv"
16+
17+
rc_fonts = {
18+
"font.size": 8,
19+
"figure.figsize": (6, 4),
20+
"text.usetex": True,
21+
"font.family": "serif",
22+
"text.latex.preamble": r"\usepackage{libertine}",
23+
}
24+
rcParams.update(rc_fonts)
25+
26+
27+
def effects_on_APD90(observational_data_path, treatment_var, control_val, treatment_val, expected_causal_effect):
28+
""" Perform causal testing for the scenario in which we investigate the causal effect of G_K on APD90.
29+
30+
:param observational_data_path: Path to observational data containing previous executions of the LR91 model.
31+
:param treatment_var: The input variable whose effect on APD90 we are interested in.
32+
:param control_val: The control value for the treatment variable (before intervention).
33+
:param treatment_val: The treatment value for the treatment variable (after intervention).
34+
:return: ATE for the effect of G_K on APD90
35+
"""
36+
# 1. Define Causal DAG
37+
causal_dag = CausalDAG('./dag.dot')
38+
39+
# 2. Specify all inputs
40+
g_na = Input('G_Na', float)
41+
g_si = Input('G_si', float)
42+
g_k = Input('G_K', float)
43+
g_k1 = Input('G_K1', float)
44+
g_kp = Input('G_Kp', float)
45+
g_b = Input('G_b', float)
46+
47+
# 3. Specify all outputs
48+
max_voltage = Output('max_voltage', float)
49+
rest_voltage = Output('rest_voltage', float)
50+
max_voltage_gradient = Output('max_voltage_gradient', float)
51+
dome_voltage = Output('dome_voltage', float)
52+
apd50 = Output('APD50', int)
53+
apd90 = Output('APD90', int)
54+
55+
# 3. Create scenario by applying constraints over a subset of the inputs
56+
scenario = Scenario(
57+
variables={g_na, g_si, g_k, g_k1, g_kp, g_b,
58+
max_voltage, rest_voltage, max_voltage_gradient, dome_voltage, apd50, apd90},
59+
constraints=set()
60+
)
61+
62+
# 4. Create a causal specification from the scenario and causal DAG
63+
causal_specification = CausalSpecification(scenario, causal_dag)
64+
65+
# 5. Create a causal test case
66+
causal_test_case = CausalTestCase(control_input_configuration={treatment_var: control_val},
67+
expected_causal_effect=expected_causal_effect,
68+
outcome_variables={apd90},
69+
intervention=Intervention((treatment_var,), (treatment_val,), ), )
70+
71+
# 6. Create a data collector
72+
data_collector = ObservationalDataCollector(scenario, observational_data_path)
73+
74+
# 7. Create an instance of the causal test engine
75+
causal_test_engine = CausalTestEngine(causal_test_case, causal_specification, data_collector)
76+
77+
# 8. Obtain the minimal adjustment set from the causal DAG
78+
minimal_adjustment_set = causal_test_engine.load_data(index_col=0)
79+
80+
linear_regression_estimator = LinearRegressionEstimator((treatment_var.name,), treatment_val, control_val,
81+
minimal_adjustment_set,
82+
('APD90',),
83+
)
84+
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, 'ate')
85+
print(causal_test_result)
86+
return causal_test_result.ate, causal_test_result.confidence_intervals
87+
88+
89+
def plot_ates_with_cis(results_dict, xs, save=True):
90+
"""
91+
Plot the average treatment effects for a given treatment against a list of x-values with confidence intervals.
92+
93+
:param results_dict: A dictionary containing results for sensitivity analysis of each input parameter.
94+
:param xs: Values to be plotted on the x-axis.
95+
"""
96+
fig, axes = plt.subplots()
97+
for treatment, results in results_dict.items():
98+
ates = results['ate']
99+
cis = results['cis']
100+
before_underscore, after_underscore = treatment.split('_')
101+
after_underscore_braces = f"{{{after_underscore}}}"
102+
latex_compatible_treatment_str = rf"${before_underscore}_{after_underscore_braces}$"
103+
cis_low = [ci[0] for ci in cis]
104+
cis_high = [ci[1] for ci in cis]
105+
106+
axes.fill_between(xs, cis_low, cis_high, alpha=.3, label=latex_compatible_treatment_str)
107+
axes.plot(xs, ates, color='black', linewidth=.5)
108+
axes.plot(xs, [0] * len(xs), color='red', alpha=.5, linestyle='--', linewidth=.5)
109+
axes.set_ylabel(r"ATE: Change in $APD_{90} (ms)$")
110+
axes.set_xlabel(r"Treatment value")
111+
axes.set_ylim(-150, 150)
112+
axes.set_xlim(min(xs), max(xs))
113+
box = axes.get_position()
114+
axes.set_position([box.x0, box.y0 + box.height * 0.3,
115+
box.width, box.height * 0.7])
116+
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), fancybox=True, ncol=6, title='Input Parameters')
117+
if save:
118+
plt.savefig(f"APD90_sensitivity.pdf", format="pdf")
119+
plt.show()
120+
121+
122+
def normalise_data(df, columns=None):
123+
""" Normalise the data in the dataframe into the range [0, 1]. """
124+
if columns:
125+
df[columns] = (df[columns] - df[columns].min())/(df[columns].max() - df[columns].min())
126+
return df
127+
else:
128+
return (df - df.min())/(df.max() - df.min())
129+
130+
131+
if __name__ == '__main__':
132+
df = pd.read_csv("data/results.csv")
133+
conductance_means = {'G_K': (.5, Positive),
134+
'G_b': (.5, Positive),
135+
'G_K1': (.5, Positive),
136+
'G_si': (.5, Negative),
137+
'G_Na': (.5, NoEffect),
138+
'G_Kp': (.5, NoEffect)}
139+
normalised_df = normalise_data(df, columns=list(conductance_means.keys()))
140+
normalised_df.to_csv("data/normalised_results.csv")
141+
142+
treatment_values = np.linspace(0, 1, 20)
143+
results_dict = {'G_K': {},
144+
'G_b': {},
145+
'G_K1': {},
146+
'G_si': {},
147+
'G_Na': {},
148+
'G_Kp': {}}
149+
for conductance_param, mean_and_oracle in conductance_means.items():
150+
average_treatment_effects = []
151+
confidence_intervals = []
152+
for treatment_value in treatment_values:
153+
mean, oracle = mean_and_oracle
154+
conductance_input = Input(conductance_param, float)
155+
ate, cis = effects_on_APD90(OBSERVATIONAL_DATA_PATH, conductance_input, 0, treatment_value, oracle)
156+
average_treatment_effects.append(ate)
157+
confidence_intervals.append(cis)
158+
results_dict[conductance_param] = {"ate": average_treatment_effects, "cis": confidence_intervals}
159+
plot_ates_with_cis(results_dict, treatment_values)

examples/lr91/dag.dot

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
digraph DAG {
2+
rankdir=LR;
3+
G_Na -> max_voltage;
4+
G_Na -> max_voltage_gradient;
5+
G_Na -> dome_voltage;
6+
G_Na -> APD50;
7+
8+
G_si -> rest_voltage;
9+
G_si -> dome_voltage;
10+
G_si -> APD50;
11+
G_si -> APD90;
12+
13+
G_K -> dome_voltage;
14+
G_K -> APD50;
15+
G_K -> APD90;
16+
17+
G_K1 -> max_voltage;
18+
G_K1 -> rest_voltage;
19+
G_K1 -> dome_voltage;
20+
G_K1 -> APD50;
21+
G_K1 -> APD90;
22+
23+
G_Kp -> dome_voltage;
24+
G_Kp -> APD50;
25+
26+
G_b -> max_voltage;
27+
G_b -> rest_voltage;
28+
G_b -> dome_voltage;
29+
G_b -> APD50;
30+
G_b -> APD90;
31+
}

examples/lr91/dag.png

72.2 KB
Loading

0 commit comments

Comments
 (0)