Skip to content

Commit 87bc91d

Browse files
Add example file using test_suites
1 parent 33f85bd commit 87bc91d

File tree

1 file changed

+188
-0
lines changed

1 file changed

+188
-0
lines changed
Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
import pandas as pd
2+
import numpy as np
3+
import matplotlib.pyplot as plt
4+
from causal_testing.specification.causal_dag import CausalDAG
5+
from causal_testing.specification.scenario import Scenario
6+
from causal_testing.specification.variable import Input, Output
7+
from causal_testing.specification.causal_specification import CausalSpecification
8+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
9+
from causal_testing.testing.causal_test_case import CausalTestCase
10+
from causal_testing.testing.causal_test_outcome import Positive, Negative, NoEffect
11+
from causal_testing.testing.causal_test_engine import CausalTestEngine
12+
from causal_testing.testing.estimators import LinearRegressionEstimator
13+
from causal_testing.testing.base_causal_test import BaseCausalTest
14+
from matplotlib.pyplot import rcParams
15+
16+
# Uncommenting the code below will make all graphs publication quality but requires a suitable latex installation
17+
18+
# rc_fonts = {
19+
# "font.size": 8,
20+
# "figure.figsize": (5, 4),
21+
# "text.usetex": True,
22+
# "font.family": "serif",
23+
# "text.latex.preamble": r"\usepackage{libertine}",
24+
# }
25+
# rcParams.update(rc_fonts)
26+
OBSERVATIONAL_DATA_PATH = "./data/normalised_results.csv"
27+
28+
29+
def causal_testing_sensitivity_analysis():
30+
"""Perform causal testing to evaluate the effect of six conductance inputs on one output, APD90, over the defined
31+
(normalised) design distribution to quantify the extent to which each input affects the output, and plot as a
32+
graph.
33+
"""
34+
# Read in the 200 model runs and define mean value and expected effect
35+
model_runs = pd.read_csv("data/results.csv")
36+
conductance_means = {'G_K': (.5, Positive),
37+
'G_b': (.5, Positive),
38+
'G_K1': (.5, Positive),
39+
'G_si': (.5, Negative),
40+
'G_Na': (.5, NoEffect),
41+
'G_Kp': (.5, NoEffect)}
42+
43+
# Normalise the inputs as per the original study
44+
normalised_df = normalise_data(model_runs, columns=list(conductance_means.keys()))
45+
normalised_df.to_csv("data/normalised_results.csv")
46+
47+
# For each input, perform 10 causal tests that change the input from its mean value (0.5) to the equidistant values
48+
# [0, 0.1, 0.2, ..., 0.9, 1] over the input space of each input, as defined by the normalised design distribution.
49+
# For each input, this will yield 10 causal test results that measure the extent the input causes APD90 to change,
50+
# enabling us to compare the magnitude and direction of each inputs' effect.
51+
treatment_values = np.linspace(0, 1, 11)
52+
results = {'G_K': {},
53+
'G_b': {},
54+
'G_K1': {},
55+
'G_si': {},
56+
'G_Na': {},
57+
'G_Kp': {}}
58+
59+
apd90 = Output('APD90', int)
60+
outcome_variable = apd90
61+
test_suite = {}
62+
for conductance_param, mean_and_oracle in conductance_means.items():
63+
treatment_variable = Input(conductance_param, float)
64+
base_test_case = BaseCausalTest(treatment_variable, outcome_variable)
65+
test_list = []
66+
control_value = 0.5
67+
mean, oracle = mean_and_oracle
68+
for treatment_value in treatment_values:
69+
test_list.append(CausalTestCase(base_test_case, oracle, control_value, treatment_value))
70+
71+
test_suite[base_test_case] = {'tests': test_list,
72+
'estimators': [LinearRegressionEstimator],
73+
'estimate_type': "ate"}
74+
75+
causal_test_results = effects_on_APD90(OBSERVATIONAL_DATA_PATH, test_suite)
76+
77+
for base_test_case in causal_test_results:
78+
results[base_test_case.treatment_variable.name] = {"ate": [result.ate for result in causal_test_results[base_test_case][0]],
79+
"cis": [result.confidence_intervals for result in
80+
causal_test_results[base_test_case][0]]}
81+
82+
plot_ates_with_cis(results, treatment_values)
83+
84+
85+
def effects_on_APD90(observational_data_path, test_suite):
86+
"""Perform causal testing for the scenario in which we investigate the causal effect of a given input on APD90.
87+
88+
:param observational_data_path: Path to observational data containing previous executions of the LR91 model.
89+
:param treatment_var: The input variable whose effect on APD90 we are interested in.
90+
:param control_val: The control value for the treatment variable (before intervention).
91+
:param treatment_val: The treatment value for the treatment variable (after intervention).
92+
:param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect).
93+
:return: ATE for the effect of G_K on APD90
94+
"""
95+
# 1. Define Causal DAG
96+
causal_dag = CausalDAG('./dag.dot')
97+
98+
# 2. Specify all inputs
99+
g_na = Input('G_Na', float)
100+
g_si = Input('G_si', float)
101+
g_k = Input('G_K', float)
102+
g_k1 = Input('G_K1', float)
103+
g_kp = Input('G_Kp', float)
104+
g_b = Input('G_b', float)
105+
106+
# 3. Specify all outputs
107+
max_voltage = Output('max_voltage', float)
108+
rest_voltage = Output('rest_voltage', float)
109+
max_voltage_gradient = Output('max_voltage_gradient', float)
110+
dome_voltage = Output('dome_voltage', float)
111+
apd50 = Output('APD50', int)
112+
apd90 = Output('APD90', int)
113+
114+
# 4. Create scenario by applying constraints over a subset of the inputs
115+
scenario = Scenario(
116+
variables={g_na, g_si, g_k, g_k1, g_kp, g_b,
117+
max_voltage, rest_voltage, max_voltage_gradient, dome_voltage, apd50, apd90},
118+
constraints=set()
119+
)
120+
121+
# 5. Create a causal specification from the scenario and causal DAG
122+
causal_specification = CausalSpecification(scenario, causal_dag)
123+
124+
# 7. Create a data collector
125+
data_collector = ObservationalDataCollector(scenario, observational_data_path)
126+
127+
# 8. Create an instance of the causal test engine
128+
causal_test_engine = CausalTestEngine(causal_specification, data_collector)
129+
130+
# 9. Obtain the minimal adjustment set from the causal DAG
131+
132+
# 10. Run the causal test and print results
133+
causal_test_results = causal_test_engine.execute_test_suite(test_suite)
134+
print(causal_test_results)
135+
return causal_test_results
136+
137+
138+
def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = True):
139+
"""Plot the average treatment effects for a given treatment against a list of x-values with confidence intervals.
140+
141+
:param results_dict: A dictionary containing results for sensitivity analysis of each input parameter.
142+
:param xs: Values to be plotted on the x-axis.
143+
:param save: Whether to save the plot.
144+
"""
145+
fig, axes = plt.subplots()
146+
input_colors = {'G_Na': 'red',
147+
'G_si': 'green',
148+
'G_K': 'blue',
149+
'G_K1': 'magenta',
150+
'G_Kp': 'cyan',
151+
'G_b': 'yellow'}
152+
for treatment, test_results in results_dict.items():
153+
ates = test_results['ate']
154+
cis = test_results['cis']
155+
before_underscore, after_underscore = treatment.split('_')
156+
after_underscore_braces = f"{{{after_underscore}}}"
157+
latex_compatible_treatment_str = rf"${before_underscore}_{after_underscore_braces}$"
158+
cis_low = [c[0] for c in cis]
159+
cis_high = [c[1] for c in cis]
160+
axes.fill_between(xs, cis_low, cis_high, alpha=.2, color=input_colors[treatment],
161+
label=latex_compatible_treatment_str)
162+
axes.plot(xs, ates, color=input_colors[treatment], linewidth=1)
163+
axes.plot(xs, [0] * len(xs), color='black', alpha=.5, linestyle='--', linewidth=1)
164+
axes.set_ylabel(r"ATE: Change in $APD_{90} (ms)$")
165+
axes.set_xlabel(r"Treatment value")
166+
axes.set_ylim(-80, 80)
167+
axes.set_xlim(min(xs), max(xs))
168+
box = axes.get_position()
169+
axes.set_position([box.x0, box.y0 + box.height * 0.3,
170+
box.width * 0.85, box.height * 0.7])
171+
plt.legend(loc='center left', bbox_to_anchor=(1.01, 0.5), fancybox=True, ncol=1,
172+
title=r'Input (95\% CIs)')
173+
if save:
174+
plt.savefig(f"APD90_sensitivity.pdf", format="pdf")
175+
plt.show()
176+
177+
178+
def normalise_data(df, columns=None):
179+
""" Normalise the data in the dataframe into the range [0, 1]. """
180+
if columns:
181+
df[columns] = (df[columns] - df[columns].min()) / (df[columns].max() - df[columns].min())
182+
return df
183+
else:
184+
return (df - df.min()) / (df.max() - df.min())
185+
186+
187+
if __name__ == '__main__':
188+
causal_testing_sensitivity_analysis()

0 commit comments

Comments
 (0)