12
12
from causal_testing .testing .causal_test_engine import CausalTestEngine
13
13
from causal_testing .testing .estimators import LinearRegressionEstimator
14
14
from matplotlib .pyplot import rcParams
15
- OBSERVATIONAL_DATA_PATH = "./data/normalised_results.csv"
16
15
16
+ # Make figures publication quality
17
17
rc_fonts = {
18
18
"font.size" : 8 ,
19
- "figure.figsize" : (6 , 4 ),
19
+ "figure.figsize" : (5 , 4 ),
20
20
"text.usetex" : True ,
21
21
"font.family" : "serif" ,
22
22
"text.latex.preamble" : r"\usepackage{libertine}" ,
23
23
}
24
24
rcParams .update (rc_fonts )
25
+ OBSERVATIONAL_DATA_PATH = "./data/normalised_results.csv"
26
+
27
+
28
+ def causal_testing_sensitivity_analysis ():
29
+ """Perform causal testing to evaluate the effect of six conductance inputs on one output, APD90, over the defined
30
+ (normalised) design distribution to quantify the extent to which each input affects the output, and plot as a
31
+ graph.
32
+ """
33
+ # Read in the 200 model runs and define mean value and expected effect
34
+ model_runs = pd .read_csv ("data/results.csv" )
35
+ conductance_means = {'G_K' : (.5 , Positive ),
36
+ 'G_b' : (.5 , Positive ),
37
+ 'G_K1' : (.5 , Positive ),
38
+ 'G_si' : (.5 , Negative ),
39
+ 'G_Na' : (.5 , NoEffect ),
40
+ 'G_Kp' : (.5 , NoEffect )}
41
+
42
+ # Normalise the inputs as per the original study
43
+ normalised_df = normalise_data (model_runs , columns = list (conductance_means .keys ()))
44
+ normalised_df .to_csv ("data/normalised_results.csv" )
45
+
46
+ # For each input, perform 10 causal tests that change the input from its mean value (0.5) to the equidistant values
47
+ # [0, 0.1, 0.2, ..., 0.9, 1] over the input space of each input, as defined by the normalised design distribution.
48
+ # For each input, this will yield 10 causal test results that measure the extent the input causes APD90 to change,
49
+ # enabling us to compare the magnitude and direction of each inputs' effect.
50
+ treatment_values = np .linspace (0 , 1 , 11 )
51
+ results = {'G_K' : {},
52
+ 'G_b' : {},
53
+ 'G_K1' : {},
54
+ 'G_si' : {},
55
+ 'G_Na' : {},
56
+ 'G_Kp' : {}}
57
+ for conductance_param , mean_and_oracle in conductance_means .items ():
58
+ average_treatment_effects = []
59
+ confidence_intervals = []
60
+
61
+ # Perform each causal test for the given input
62
+ for treatment_value in treatment_values :
63
+ mean , oracle = mean_and_oracle
64
+ conductance_input = Input (conductance_param , float )
65
+ ate , ci = effects_on_APD90 (OBSERVATIONAL_DATA_PATH , conductance_input , 0.5 , treatment_value , oracle )
66
+
67
+ # Store results
68
+ average_treatment_effects .append (ate )
69
+ confidence_intervals .append (ci )
70
+ results [conductance_param ] = {"ate" : average_treatment_effects , "cis" : confidence_intervals }
71
+ plot_ates_with_cis (results , treatment_values )
25
72
26
73
27
74
def effects_on_APD90 (observational_data_path , treatment_var , control_val , treatment_val , expected_causal_effect ):
28
- """ Perform causal testing for the scenario in which we investigate the causal effect of G_K on APD90.
75
+ """Perform causal testing for the scenario in which we investigate the causal effect of a given input on APD90.
29
76
30
77
:param observational_data_path: Path to observational data containing previous executions of the LR91 model.
31
78
:param treatment_var: The input variable whose effect on APD90 we are interested in.
32
79
:param control_val: The control value for the treatment variable (before intervention).
33
80
:param treatment_val: The treatment value for the treatment variable (after intervention).
81
+ :param expected_causal_effect: The expected causal effect (Positive, Negative, No Effect).
34
82
:return: ATE for the effect of G_K on APD90
35
83
"""
36
84
# 1. Define Causal DAG
@@ -52,68 +100,77 @@ def effects_on_APD90(observational_data_path, treatment_var, control_val, treatm
52
100
apd50 = Output ('APD50' , int )
53
101
apd90 = Output ('APD90' , int )
54
102
55
- # 3 . Create scenario by applying constraints over a subset of the inputs
103
+ # 4 . Create scenario by applying constraints over a subset of the inputs
56
104
scenario = Scenario (
57
105
variables = {g_na , g_si , g_k , g_k1 , g_kp , g_b ,
58
106
max_voltage , rest_voltage , max_voltage_gradient , dome_voltage , apd50 , apd90 },
59
107
constraints = set ()
60
108
)
61
109
62
- # 4 . Create a causal specification from the scenario and causal DAG
110
+ # 5 . Create a causal specification from the scenario and causal DAG
63
111
causal_specification = CausalSpecification (scenario , causal_dag )
64
112
65
- # 5 . Create a causal test case
113
+ # 6 . Create a causal test case
66
114
causal_test_case = CausalTestCase (control_input_configuration = {treatment_var : control_val },
67
115
expected_causal_effect = expected_causal_effect ,
68
116
outcome_variables = {apd90 },
69
117
intervention = Intervention ((treatment_var ,), (treatment_val ,), ), )
70
118
71
- # 6 . Create a data collector
119
+ # 7 . Create a data collector
72
120
data_collector = ObservationalDataCollector (scenario , observational_data_path )
73
121
74
- # 7 . Create an instance of the causal test engine
122
+ # 8 . Create an instance of the causal test engine
75
123
causal_test_engine = CausalTestEngine (causal_test_case , causal_specification , data_collector )
76
124
77
- # 8 . Obtain the minimal adjustment set from the causal DAG
125
+ # 9 . Obtain the minimal adjustment set from the causal DAG
78
126
minimal_adjustment_set = causal_test_engine .load_data (index_col = 0 )
79
-
80
127
linear_regression_estimator = LinearRegressionEstimator ((treatment_var .name ,), treatment_val , control_val ,
81
128
minimal_adjustment_set ,
82
- ('APD90' ,),
83
- )
129
+ ('APD90' ,)
130
+ )
131
+
132
+ # 10. Run the causal test and print results
84
133
causal_test_result = causal_test_engine .execute_test (linear_regression_estimator , 'ate' )
85
134
print (causal_test_result )
86
135
return causal_test_result .ate , causal_test_result .confidence_intervals
87
136
88
137
89
- def plot_ates_with_cis (results_dict , xs , save = True ):
90
- """
91
- Plot the average treatment effects for a given treatment against a list of x-values with confidence intervals.
138
+ def plot_ates_with_cis (results_dict : dict , xs : list , save : bool = True ):
139
+ """Plot the average treatment effects for a given treatment against a list of x-values with confidence intervals.
92
140
93
141
:param results_dict: A dictionary containing results for sensitivity analysis of each input parameter.
94
142
:param xs: Values to be plotted on the x-axis.
143
+ :param save: Whether to save the plot.
95
144
"""
96
145
fig , axes = plt .subplots ()
97
- for treatment , results in results_dict .items ():
98
- ates = results ['ate' ]
99
- cis = results ['cis' ]
146
+ input_colors = {'G_Na' : 'red' ,
147
+ 'G_si' : 'green' ,
148
+ 'G_K' : 'blue' ,
149
+ 'G_K1' : 'magenta' ,
150
+ 'G_Kp' : 'cyan' ,
151
+ 'G_b' : 'yellow' }
152
+ for treatment , test_results in results_dict .items ():
153
+ ates = test_results ['ate' ]
154
+ cis = test_results ['cis' ]
100
155
before_underscore , after_underscore = treatment .split ('_' )
101
156
after_underscore_braces = f"{{{ after_underscore } }}"
102
157
latex_compatible_treatment_str = rf"${ before_underscore } _{ after_underscore_braces } $"
103
- cis_low = [ci [0 ] for ci in cis ]
104
- cis_high = [ci [1 ] for ci in cis ]
158
+ cis_low = [c [0 ] for c in cis ]
159
+ cis_high = [c [1 ] for c in cis ]
105
160
106
- axes .fill_between (xs , cis_low , cis_high , alpha = .3 , label = latex_compatible_treatment_str )
107
- axes .plot (xs , ates , color = 'black' , linewidth = .5 )
108
- axes .plot (xs , [0 ] * len (xs ), color = 'red' , alpha = .5 , linestyle = '--' , linewidth = .5 )
161
+ axes .fill_between (xs , cis_low , cis_high , alpha = .2 , color = input_colors [treatment ],
162
+ label = latex_compatible_treatment_str )
163
+ axes .plot (xs , ates , color = input_colors [treatment ], linewidth = 1 )
164
+ axes .plot (xs , [0 ] * len (xs ), color = 'black' , alpha = .5 , linestyle = '--' , linewidth = 1 )
109
165
axes .set_ylabel (r"ATE: Change in $APD_{90} (ms)$" )
110
166
axes .set_xlabel (r"Treatment value" )
111
- axes .set_ylim (- 150 , 150 )
167
+ axes .set_ylim (- 80 , 80 )
112
168
axes .set_xlim (min (xs ), max (xs ))
113
169
box = axes .get_position ()
114
170
axes .set_position ([box .x0 , box .y0 + box .height * 0.3 ,
115
- box .width , box .height * 0.7 ])
116
- plt .legend (loc = 'upper center' , bbox_to_anchor = (0.5 , - 0.15 ), fancybox = True , ncol = 6 , title = 'Input Parameters' )
171
+ box .width * 0.85 , box .height * 0.7 ])
172
+ plt .legend (loc = 'center left' , bbox_to_anchor = (1.01 , 0.5 ), fancybox = True , ncol = 1 ,
173
+ title = r'Input (95\% CIs)' )
117
174
if save :
118
175
plt .savefig (f"APD90_sensitivity.pdf" , format = "pdf" )
119
176
plt .show ()
@@ -122,38 +179,11 @@ def plot_ates_with_cis(results_dict, xs, save=True):
122
179
def normalise_data (df , columns = None ):
123
180
""" Normalise the data in the dataframe into the range [0, 1]. """
124
181
if columns :
125
- df [columns ] = (df [columns ] - df [columns ].min ())/ (df [columns ].max () - df [columns ].min ())
182
+ df [columns ] = (df [columns ] - df [columns ].min ()) / (df [columns ].max () - df [columns ].min ())
126
183
return df
127
184
else :
128
- return (df - df .min ())/ (df .max () - df .min ())
185
+ return (df - df .min ()) / (df .max () - df .min ())
129
186
130
187
131
188
if __name__ == '__main__' :
132
- df = pd .read_csv ("data/results.csv" )
133
- conductance_means = {'G_K' : (.5 , Positive ),
134
- 'G_b' : (.5 , Positive ),
135
- 'G_K1' : (.5 , Positive ),
136
- 'G_si' : (.5 , Negative ),
137
- 'G_Na' : (.5 , NoEffect ),
138
- 'G_Kp' : (.5 , NoEffect )}
139
- normalised_df = normalise_data (df , columns = list (conductance_means .keys ()))
140
- normalised_df .to_csv ("data/normalised_results.csv" )
141
-
142
- treatment_values = np .linspace (0 , 1 , 20 )
143
- results_dict = {'G_K' : {},
144
- 'G_b' : {},
145
- 'G_K1' : {},
146
- 'G_si' : {},
147
- 'G_Na' : {},
148
- 'G_Kp' : {}}
149
- for conductance_param , mean_and_oracle in conductance_means .items ():
150
- average_treatment_effects = []
151
- confidence_intervals = []
152
- for treatment_value in treatment_values :
153
- mean , oracle = mean_and_oracle
154
- conductance_input = Input (conductance_param , float )
155
- ate , cis = effects_on_APD90 (OBSERVATIONAL_DATA_PATH , conductance_input , 0 , treatment_value , oracle )
156
- average_treatment_effects .append (ate )
157
- confidence_intervals .append (cis )
158
- results_dict [conductance_param ] = {"ate" : average_treatment_effects , "cis" : confidence_intervals }
159
- plot_ates_with_cis (results_dict , treatment_values )
189
+ causal_testing_sensitivity_analysis ()
0 commit comments