@@ -74,27 +74,19 @@ def set_variables(self, inputs: dict, outputs: dict, metas: dict):
74
74
"""
75
75
self .inputs = [Input (i ["name" ], i ["type" ], i ["distribution" ]) for i in inputs ]
76
76
self .outputs = [Output (i ["name" ], i ["type" ]) for i in outputs ]
77
- self .metas = (
78
- [Meta (i ["name" ], i ["type" ], i ["populate" ]) for i in metas ]
79
- if metas
80
- else []
81
- )
77
+ self .metas = [Meta (i ["name" ], i ["type" ], i ["populate" ]) for i in metas ] if metas else []
82
78
83
79
def setup (self ):
84
80
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
85
- self .modelling_scenario = Scenario (
86
- self .inputs + self .outputs + self .metas , None
87
- )
81
+ self .modelling_scenario = Scenario (self .inputs + self .outputs + self .metas , None )
88
82
self .modelling_scenario .setup_treatment_variables ()
89
83
self .causal_specification = CausalSpecification (
90
84
scenario = self .modelling_scenario , causal_dag = CausalDAG (self .dag_path )
91
85
)
92
86
self ._json_parse ()
93
87
self ._populate_metas ()
94
88
95
- def execute_tests (
96
- self , effects : dict , mutates : dict , estimators : dict , f_flag : bool
97
- ):
89
+ def execute_tests (self , effects : dict , mutates : dict , estimators : dict , f_flag : bool ):
98
90
"""Runs and evaluates each test case specified in the JSON input
99
91
100
92
:param effects: Dictionary mapping effect class instances to string representations.
@@ -110,20 +102,13 @@ def execute_tests(
110
102
111
103
abstract_test = AbstractCausalTestCase (
112
104
scenario = self .modelling_scenario ,
113
- intervention_constraints = [
114
- mutates [v ](k ) for k , v in test ["mutations" ].items ()
115
- ],
116
- treatment_variables = {
117
- self .modelling_scenario .variables [v ] for v in test ["mutations" ]
118
- },
105
+ intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
106
+ treatment_variables = {self .modelling_scenario .variables [v ] for v in test ["mutations" ]},
119
107
expected_causal_effect = {
120
108
self .modelling_scenario .variables [variable ]: effects [effect ]
121
109
for variable , effect in test ["expectedEffect" ].items ()
122
110
},
123
- effect_modifiers = {
124
- self .modelling_scenario .variables [v ]
125
- for v in test ["effect_modifiers" ]
126
- }
111
+ effect_modifiers = {self .modelling_scenario .variables [v ] for v in test ["effect_modifiers" ]}
127
112
if "effect_modifiers" in test
128
113
else {},
129
114
estimate_type = test ["estimate_type" ],
@@ -132,17 +117,11 @@ def execute_tests(
132
117
concrete_tests , dummy = abstract_test .generate_concrete_tests (5 , 0.05 )
133
118
logger .info ("Executing test: %s" , test ["name" ])
134
119
logger .info (abstract_test )
135
- logger .info (
136
- [(v .name , v .distribution ) for v in abstract_test .treatment_variables ]
137
- )
138
- logger .info (
139
- "Number of concrete tests for test case: %s" , str (len (concrete_tests ))
140
- )
120
+ logger .info ([(v .name , v .distribution ) for v in abstract_test .treatment_variables ])
121
+ logger .info ("Number of concrete tests for test case: %s" , str (len (concrete_tests )))
141
122
for concrete_test in concrete_tests :
142
123
executed_tests += 1
143
- failed = self ._execute_test_case (
144
- concrete_test , estimators [test ["estimator" ]], f_flag
145
- )
124
+ failed = self ._execute_test_case (concrete_test , estimators [test ["estimator" ]], f_flag )
146
125
if failed :
147
126
failures += 1
148
127
@@ -170,9 +149,7 @@ def _populate_metas(self):
170
149
var .distribution = getattr (scipy .stats , dist )(** params )
171
150
logger .info (var .name + f"{ dist } ({ params } )" )
172
151
173
- def _execute_test_case (
174
- self , causal_test_case : CausalTestCase , estimator : Estimator , f_flag : bool
175
- ) -> bool :
152
+ def _execute_test_case (self , causal_test_case : CausalTestCase , estimator : Estimator , f_flag : bool ) -> bool :
176
153
"""Executes a singular test case, prints the results and returns the test case result
177
154
:param causal_test_case: The concrete test case to be executed
178
155
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
@@ -181,9 +158,7 @@ def _execute_test_case(
181
158
"""
182
159
failed = False
183
160
184
- causal_test_engine , estimation_model = self ._setup_test (
185
- causal_test_case , estimator
186
- )
161
+ causal_test_engine , estimation_model = self ._setup_test (causal_test_case , estimator )
187
162
causal_test_result = causal_test_engine .execute_test (
188
163
estimation_model , estimate_type = causal_test_case .estimate_type
189
164
)
@@ -192,7 +167,9 @@ def _execute_test_case(
192
167
193
168
result_string = str ()
194
169
if causal_test_result .ci_low () and causal_test_result .ci_high ():
195
- result_string = f"{ causal_test_result .ci_low ()} < { causal_test_result .ate } < { causal_test_result .ci_high ()} "
170
+ result_string = (
171
+ f"{ causal_test_result .ci_low ()} < { causal_test_result .ate } < { causal_test_result .ci_high ()} "
172
+ )
196
173
else :
197
174
result_string = causal_test_result .ate
198
175
if f_flag :
@@ -209,34 +186,22 @@ def _execute_test_case(
209
186
)
210
187
return failed
211
188
212
- def _setup_test (
213
- self , causal_test_case : CausalTestCase , estimator : Estimator
214
- ) -> tuple [CausalTestEngine , Estimator ]:
189
+ def _setup_test (self , causal_test_case : CausalTestCase , estimator : Estimator ) -> tuple [CausalTestEngine , Estimator ]:
215
190
"""Create the necessary inputs for a single test case
216
191
:param causal_test_case: The concrete test case to be executed
217
192
:returns:
218
193
- causal_test_engine - Test Engine instance for the test being run
219
194
- estimation_model - Estimator instance for the test being run
220
195
"""
221
- data_collector = ObservationalDataCollector (
222
- self .modelling_scenario , self .data_path
223
- )
224
- causal_test_engine = CausalTestEngine (
225
- causal_test_case , self .causal_specification , data_collector
226
- )
196
+ data_collector = ObservationalDataCollector (self .modelling_scenario , self .data_path )
197
+ causal_test_engine = CausalTestEngine (causal_test_case , self .causal_specification , data_collector )
227
198
minimal_adjustment_set = causal_test_engine .load_data (index_col = 0 )
228
199
treatment_vars = list (causal_test_case .treatment_input_configuration )
229
- minimal_adjustment_set = minimal_adjustment_set - {
230
- v .name for v in treatment_vars
231
- }
200
+ minimal_adjustment_set = minimal_adjustment_set - {v .name for v in treatment_vars }
232
201
estimation_model = estimator (
233
202
(list (treatment_vars )[0 ].name ,),
234
- [causal_test_case .treatment_input_configuration [v ] for v in treatment_vars ][
235
- 0
236
- ],
237
- [causal_test_case .control_input_configuration [v ] for v in treatment_vars ][
238
- 0
239
- ],
203
+ [causal_test_case .treatment_input_configuration [v ] for v in treatment_vars ][0 ],
204
+ [causal_test_case .control_input_configuration [v ] for v in treatment_vars ][0 ],
240
205
minimal_adjustment_set ,
241
206
(list (causal_test_case .outcome_variables )[0 ].name ,),
242
207
causal_test_engine .scenario_execution_data_df ,
0 commit comments