@@ -144,15 +144,21 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
144
144
failed , _ = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
145
145
146
146
msg = (
147
- f"Executing concrete test: { test ['name' ]} \n "
148
- + f"treatment variable: { test ['treatment_variable' ]} \n "
149
- + f"outcome_variable = { outcome_variable } \n "
150
- + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
151
- + f"Result: { 'FAILED' if failed else 'Passed' } "
147
+ f"Executing concrete test: { test ['name' ]} \n "
148
+ + f"treatment variable: { test ['treatment_variable' ]} \n "
149
+ + f"outcome_variable = { outcome_variable } \n "
150
+ + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
151
+ + f"Result: { 'FAILED' if failed else 'Passed' } "
152
152
)
153
153
self ._append_to_file (msg , logging .INFO )
154
154
155
- def run_coefficient_test (self , test , f_flag ):
155
+ def run_coefficient_test (self , test : dict , f_flag : bool ):
156
+ """Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
157
+
158
+ :param test: Single JSON test definition stored in a mapping (dict)
159
+ :param f_flag: Failure flag that if True the script will stop executing when a test fails.
160
+ :return: String containing the message to be outputted
161
+ """
156
162
base_test_case = BaseTestCase (
157
163
treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
158
164
outcome_variable = next (self .scenario .variables [v ] for v in test ["expected_effect" ]),
@@ -161,26 +167,28 @@ def run_coefficient_test(self, test, f_flag):
161
167
assert len (test ["expected_effect" ]) == 1 , "Can only have one expected effect."
162
168
causal_test_case = CausalTestCase (
163
169
base_test_case = base_test_case ,
164
- expected_causal_effect = next (
165
- self .effects [effect ] for variable , effect in test ["expected_effect" ].items ()
166
- ),
170
+ expected_causal_effect = next (self .effects [effect ] for variable , effect in test ["expected_effect" ].items ()),
167
171
estimate_type = "coefficient" ,
168
- effect_modifier_configuration = {
169
- self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])
170
- },
172
+ effect_modifier_configuration = {self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])},
171
173
)
172
174
result = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
173
175
msg = (
174
- f"Executing test: { test ['name' ]} \n "
175
- + f" { causal_test_case } \n "
176
- + " "
177
- + ("\n " ).join (str (result [1 ]).split ("\n " ))
178
- + "==============\n "
179
- + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
176
+ f"Executing test: { test ['name' ]} \n "
177
+ + f" { causal_test_case } \n "
178
+ + " "
179
+ + ("\n " ).join (str (result [1 ]).split ("\n " ))
180
+ + "==============\n "
181
+ + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
180
182
)
181
183
return msg
182
184
183
- def run_ate_test (self , test , f_flag ):
185
+ def run_ate_test (self , test : dict , f_flag : bool ):
186
+ """Builds structures and runs test case for tests with an estimate_type of 'ate'.
187
+
188
+ :param test: Single JSON test definition stored in a mapping (dict)
189
+ :param f_flag: Failure flag that if True the script will stop executing when a test fails.
190
+ :return: String containing the message to be outputted
191
+ """
184
192
if "sample_size" in test :
185
193
sample_size = test ["sample_size" ]
186
194
else :
@@ -190,17 +198,19 @@ def run_ate_test(self, test, f_flag):
190
198
else :
191
199
target_ks_score = 0.05
192
200
abstract_test = self ._create_abstract_test_case (test , self .mutates , self .effects )
193
- concrete_tests , _ = abstract_test .generate_concrete_tests (sample_size = sample_size , target_ks_score = target_ks_score )
201
+ concrete_tests , _ = abstract_test .generate_concrete_tests (
202
+ sample_size = sample_size , target_ks_score = target_ks_score
203
+ )
194
204
failures , _ = self ._execute_tests (concrete_tests , test , f_flag )
195
205
196
206
msg = (
197
- f"Executing test: { test ['name' ]} \n "
198
- + " abstract_test \n "
199
- + f" { abstract_test } \n "
200
- + f" { abstract_test .treatment_variable .name } ,"
201
- + f" { abstract_test .treatment_variable .distribution } \n "
202
- + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
203
- + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
207
+ f"Executing test: { test ['name' ]} \n "
208
+ + " abstract_test \n "
209
+ + f" { abstract_test } \n "
210
+ + f" { abstract_test .treatment_variable .name } ,"
211
+ + f" { abstract_test .treatment_variable .distribution } \n "
212
+ + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
213
+ + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
204
214
)
205
215
return msg
206
216
@@ -233,7 +243,7 @@ def _populate_metas(self):
233
243
meta .populate (self .data )
234
244
235
245
def _execute_test_case (
236
- self , causal_test_case : CausalTestCase , test : Iterable [Mapping ], f_flag : bool
246
+ self , causal_test_case : CausalTestCase , test : Iterable [Mapping ], f_flag : bool
237
247
) -> (bool , CausalTestResult ):
238
248
"""Executes a singular test case, prints the results and returns the test case result
239
249
:param causal_test_case: The concrete test case to be executed
@@ -273,11 +283,11 @@ def _execute_test_case(
273
283
return failed , causal_test_result
274
284
275
285
def _setup_test (
276
- self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ] = None
286
+ self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ] = None
277
287
) -> tuple [CausalTestEngine , Estimator ]:
278
288
"""Create the necessary inputs for a single test case
279
289
:param causal_test_case: The concrete test case to be executed
280
- :param test: Single JSON test definition stored in a mapping (dict)
290
+ ` :param test: Single JSON test definition stored in a mapping (dict)`
281
291
:param conditions: A list of conditions which should be applied to the
282
292
data. Conditions should be in the query format detailed at
283
293
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
@@ -358,7 +368,7 @@ def get_args(test_args=None) -> argparse.Namespace:
358
368
parser .add_argument (
359
369
"-w" ,
360
370
help = "Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
361
- "careful" ,
371
+ "careful" ,
362
372
action = "store_true" ,
363
373
)
364
374
parser .add_argument (
0 commit comments