Skip to content

Commit ff34ff9

Browse files
author
AndrewC19
committed
Refactored sensitivity analysis case study
1 parent de524af commit ff34ff9

File tree

1 file changed

+86
-56
lines changed

1 file changed

+86
-56
lines changed

examples/lr91/causal_test_max_conductances.py

Lines changed: 86 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,73 @@
1212
from causal_testing.testing.causal_test_engine import CausalTestEngine
1313
from causal_testing.testing.estimators import LinearRegressionEstimator
1414
from matplotlib.pyplot import rcParams
15-
OBSERVATIONAL_DATA_PATH = "./data/normalised_results.csv"
1615

16+
# Make figures publication quality
1717
rc_fonts = {
1818
"font.size": 8,
19-
"figure.figsize": (6, 4),
19+
"figure.figsize": (5, 4),
2020
"text.usetex": True,
2121
"font.family": "serif",
2222
"text.latex.preamble": r"\usepackage{libertine}",
2323
}
2424
rcParams.update(rc_fonts)
25+
OBSERVATIONAL_DATA_PATH = "./data/normalised_results.csv"
26+
27+
28+
def causal_testing_sensitivity_analysis():
29+
"""Perform causal testing to evaluate the effect of six conductance inputs on one output, APD90, over the defined
30+
(normalised) design distribution to quantify the extent to which each input affects the output, and plot as a
31+
graph.
32+
"""
33+
# Read in the 200 model runs and define mean value and expected effect
34+
model_runs = pd.read_csv("data/results.csv")
35+
conductance_means = {'G_K': (.5, Positive),
36+
'G_b': (.5, Positive),
37+
'G_K1': (.5, Positive),
38+
'G_si': (.5, Negative),
39+
'G_Na': (.5, NoEffect),
40+
'G_Kp': (.5, NoEffect)}
41+
42+
# Normalise the inputs as per the original study
43+
normalised_df = normalise_data(model_runs, columns=list(conductance_means.keys()))
44+
normalised_df.to_csv("data/normalised_results.csv")
45+
46+
# For each input, perform 10 causal tests that change the input from its mean value (0.5) to the equidistant values
47+
# [0, 0.1, 0.2, ..., 0.9, 1] over the input space of each input, as defined by the normalised design distribution.
48+
# For each input, this will yield 10 causal test results that measure the extent the input causes APD90 to change,
49+
# enabling us to compare the magnitude and direction of each inputs' effect.
50+
treatment_values = np.linspace(0, 1, 11)
51+
results = {'G_K': {},
52+
'G_b': {},
53+
'G_K1': {},
54+
'G_si': {},
55+
'G_Na': {},
56+
'G_Kp': {}}
57+
for conductance_param, mean_and_oracle in conductance_means.items():
58+
average_treatment_effects = []
59+
confidence_intervals = []
60+
61+
# Perform each causal test for the given input
62+
for treatment_value in treatment_values:
63+
mean, oracle = mean_and_oracle
64+
conductance_input = Input(conductance_param, float)
65+
ate, ci = effects_on_APD90(OBSERVATIONAL_DATA_PATH, conductance_input, 0.5, treatment_value, oracle)
66+
67+
# Store results
68+
average_treatment_effects.append(ate)
69+
confidence_intervals.append(ci)
70+
results[conductance_param] = {"ate": average_treatment_effects, "cis": confidence_intervals}
71+
plot_ates_with_cis(results, treatment_values)
2572

2673

2774
def effects_on_APD90(observational_data_path, treatment_var, control_val, treatment_val, expected_causal_effect):
28-
""" Perform causal testing for the scenario in which we investigate the causal effect of G_K on APD90.
75+
"""Perform causal testing for the scenario in which we investigate the causal effect of a given input on APD90.
2976
3077
:param observational_data_path: Path to observational data containing previous executions of the LR91 model.
3178
:param treatment_var: The input variable whose effect on APD90 we are interested in.
3279
:param control_val: The control value for the treatment variable (before intervention).
3380
:param treatment_val: The treatment value for the treatment variable (after intervention).
81+
:param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect).
3482
:return: ATE for the effect of G_K on APD90
3583
"""
3684
# 1. Define Causal DAG
@@ -52,68 +100,77 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
52100
apd50 = Output('APD50', int)
53101
apd90 = Output('APD90', int)
54102

55-
# 3. Create scenario by applying constraints over a subset of the inputs
103+
# 4. Create scenario by applying constraints over a subset of the inputs
56104
scenario = Scenario(
57105
variables={g_na, g_si, g_k, g_k1, g_kp, g_b,
58106
max_voltage, rest_voltage, max_voltage_gradient, dome_voltage, apd50, apd90},
59107
constraints=set()
60108
)
61109

62-
# 4. Create a causal specification from the scenario and causal DAG
110+
# 5. Create a causal specification from the scenario and causal DAG
63111
causal_specification = CausalSpecification(scenario, causal_dag)
64112

65-
# 5. Create a causal test case
113+
# 6. Create a causal test case
66114
causal_test_case = CausalTestCase(control_input_configuration={treatment_var: control_val},
67115
expected_causal_effect=expected_causal_effect,
68116
outcome_variables={apd90},
69117
intervention=Intervention((treatment_var,), (treatment_val,), ), )
70118

71-
# 6. Create a data collector
119+
# 7. Create a data collector
72120
data_collector = ObservationalDataCollector(scenario, observational_data_path)
73121

74-
# 7. Create an instance of the causal test engine
122+
# 8. Create an instance of the causal test engine
75123
causal_test_engine = CausalTestEngine(causal_test_case, causal_specification, data_collector)
76124

77-
# 8. Obtain the minimal adjustment set from the causal DAG
125+
# 9. Obtain the minimal adjustment set from the causal DAG
78126
minimal_adjustment_set = causal_test_engine.load_data(index_col=0)
79-
80127
linear_regression_estimator = LinearRegressionEstimator((treatment_var.name,), treatment_val, control_val,
81128
minimal_adjustment_set,
82-
('APD90',),
83-
)
129+
('APD90',)
130+
)
131+
132+
# 10. Run the causal test and print results
84133
causal_test_result = causal_test_engine.execute_test(linear_regression_estimator, 'ate')
85134
print(causal_test_result)
86135
return causal_test_result.ate, causal_test_result.confidence_intervals
87136

88137

89-
def plot_ates_with_cis(results_dict, xs, save=True):
90-
"""
91-
Plot the average treatment effects for a given treatment against a list of x-values with confidence intervals.
138+
def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = True):
139+
"""Plot the average treatment effects for a given treatment against a list of x-values with confidence intervals.
92140
93141
:param results_dict: A dictionary containing results for sensitivity analysis of each input parameter.
94142
:param xs: Values to be plotted on the x-axis.
143+
:param save: Whether to save the plot.
95144
"""
96145
fig, axes = plt.subplots()
97-
for treatment, results in results_dict.items():
98-
ates = results['ate']
99-
cis = results['cis']
146+
input_colors = {'G_Na': 'red',
147+
'G_si': 'green',
148+
'G_K': 'blue',
149+
'G_K1': 'magenta',
150+
'G_Kp': 'cyan',
151+
'G_b': 'yellow'}
152+
for treatment, test_results in results_dict.items():
153+
ates = test_results['ate']
154+
cis = test_results['cis']
100155
before_underscore, after_underscore = treatment.split('_')
101156
after_underscore_braces = f"{{{after_underscore}}}"
102157
latex_compatible_treatment_str = rf"${before_underscore}_{after_underscore_braces}$"
103-
cis_low = [ci[0] for ci in cis]
104-
cis_high = [ci[1] for ci in cis]
158+
cis_low = [c[0] for c in cis]
159+
cis_high = [c[1] for c in cis]
105160

106-
axes.fill_between(xs, cis_low, cis_high, alpha=.3, label=latex_compatible_treatment_str)
107-
axes.plot(xs, ates, color='black', linewidth=.5)
108-
axes.plot(xs, [0] * len(xs), color='red', alpha=.5, linestyle='--', linewidth=.5)
161+
axes.fill_between(xs, cis_low, cis_high, alpha=.2, color=input_colors[treatment],
162+
label=latex_compatible_treatment_str)
163+
axes.plot(xs, ates, color=input_colors[treatment], linewidth=1)
164+
axes.plot(xs, [0] * len(xs), color='black', alpha=.5, linestyle='--', linewidth=1)
109165
axes.set_ylabel(r"ATE: Change in $APD_{90} (ms)$")
110166
axes.set_xlabel(r"Treatment value")
111-
axes.set_ylim(-150, 150)
167+
axes.set_ylim(-80, 80)
112168
axes.set_xlim(min(xs), max(xs))
113169
box = axes.get_position()
114170
axes.set_position([box.x0, box.y0 + box.height * 0.3,
115-
box.width, box.height * 0.7])
116-
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), fancybox=True, ncol=6, title='Input Parameters')
171+
box.width * 0.85, box.height * 0.7])
172+
plt.legend(loc='center left', bbox_to_anchor=(1.01, 0.5), fancybox=True, ncol=1,
173+
title=r'Input (95\% CIs)')
117174
if save:
118175
plt.savefig(f"APD90_sensitivity.pdf", format="pdf")
119176
plt.show()
@@ -122,38 +179,11 @@ def plot_ates_with_cis(results_dict, xs, save=True):
122179
def normalise_data(df, columns=None):
123180
""" Normalise the data in the dataframe into the range [0, 1]. """
124181
if columns:
125-
df[columns] = (df[columns] - df[columns].min())/(df[columns].max() - df[columns].min())
182+
df[columns] = (df[columns] - df[columns].min()) / (df[columns].max() - df[columns].min())
126183
return df
127184
else:
128-
return (df - df.min())/(df.max() - df.min())
185+
return (df - df.min()) / (df.max() - df.min())
129186

130187

131188
if __name__ == '__main__':
132-
df = pd.read_csv("data/results.csv")
133-
conductance_means = {'G_K': (.5, Positive),
134-
'G_b': (.5, Positive),
135-
'G_K1': (.5, Positive),
136-
'G_si': (.5, Negative),
137-
'G_Na': (.5, NoEffect),
138-
'G_Kp': (.5, NoEffect)}
139-
normalised_df = normalise_data(df, columns=list(conductance_means.keys()))
140-
normalised_df.to_csv("data/normalised_results.csv")
141-
142-
treatment_values = np.linspace(0, 1, 20)
143-
results_dict = {'G_K': {},
144-
'G_b': {},
145-
'G_K1': {},
146-
'G_si': {},
147-
'G_Na': {},
148-
'G_Kp': {}}
149-
for conductance_param, mean_and_oracle in conductance_means.items():
150-
average_treatment_effects = []
151-
confidence_intervals = []
152-
for treatment_value in treatment_values:
153-
mean, oracle = mean_and_oracle
154-
conductance_input = Input(conductance_param, float)
155-
ate, cis = effects_on_APD90(OBSERVATIONAL_DATA_PATH, conductance_input, 0, treatment_value, oracle)
156-
average_treatment_effects.append(ate)
157-
confidence_intervals.append(cis)
158-
results_dict[conductance_param] = {"ate": average_treatment_effects, "cis": confidence_intervals}
159-
plot_ates_with_cis(results_dict, treatment_values)
189+
causal_testing_sensitivity_analysis()

0 commit comments

Comments
 (0)