@@ -104,48 +104,7 @@ def _create_abstract_test_case(self, test, mutates, effects):
104
104
effect = test .get ("effect" , "total" ),
105
105
)
106
106
return abstract_test
107
- def run_coefficient_test (self , test , f_flag ):
108
- base_test_case = BaseTestCase (
109
- treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
110
- outcome_variable = next (self .scenario .variables [v ] for v in test ["expected_effect" ]),
111
- effect = test .get ("effect" , "direct" ),
112
- )
113
- assert len (test ["expected_effect" ]) == 1 , "Can only have one expected effect."
114
- causal_test_case = CausalTestCase (
115
- base_test_case = base_test_case ,
116
- expected_causal_effect = next (
117
- self .effects [effect ] for variable , effect in test ["expected_effect" ].items ()
118
- ),
119
- estimate_type = "coefficient" ,
120
- effect_modifier_configuration = {
121
- self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])
122
- },
123
- )
124
- result = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
125
- msg = (
126
- f"Executing test: { test ['name' ]} \n "
127
- + f" { causal_test_case } \n "
128
- + " "
129
- + ("\n " ).join (str (result [1 ]).split ("\n " ))
130
- + "==============\n "
131
- + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
132
- )
133
- return msg
134
- def run_ate_test (self , test , f_flag ):
135
- abstract_test = self ._create_abstract_test_case (test , self .mutates , self .effects )
136
- concrete_tests , _ = abstract_test .generate_concrete_tests (sample_size = 5 , target_ks_score = 0.05 )
137
- failures , _ = self ._execute_tests (concrete_tests , test , f_flag )
138
107
139
- msg = (
140
- f"Executing test: { test ['name' ]} \n "
141
- + " abstract_test \n "
142
- + f" { abstract_test } \n "
143
- + f" { abstract_test .treatment_variable .name } ,"
144
- + f" { abstract_test .treatment_variable .distribution } \n "
145
- + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
146
- + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
147
- )
148
- return msg
149
108
def run_json_tests (self , effects : dict , estimators : dict , f_flag : bool = False , mutates : dict = None ):
150
109
"""Runs and evaluates each test case specified in the JSON input
151
110
@@ -164,7 +123,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
164
123
if test ["estimate_type" ] == "coefficient" :
165
124
msg = self .run_coefficient_test (test = test , f_flag = f_flag )
166
125
else :
167
- msg = self .run_ate_test (test = test , f_flag = f_flag )
126
+ msg = self .run_ate_test (test = test , f_flag = f_flag )
168
127
self ._append_to_file (msg , logging .INFO )
169
128
else :
170
129
outcome_variable = next (
@@ -185,14 +144,66 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
185
144
failed , _ = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
186
145
187
146
msg = (
188
- f"Executing concrete test: { test ['name' ]} \n "
189
- + f"treatment variable: { test ['treatment_variable' ]} \n "
190
- + f"outcome_variable = { outcome_variable } \n "
191
- + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
192
- + f"Result: { 'FAILED' if failed else 'Passed' } "
147
+ f"Executing concrete test: { test ['name' ]} \n "
148
+ + f"treatment variable: { test ['treatment_variable' ]} \n "
149
+ + f"outcome_variable = { outcome_variable } \n "
150
+ + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
151
+ + f"Result: { 'FAILED' if failed else 'Passed' } "
193
152
)
194
153
self ._append_to_file (msg , logging .INFO )
195
154
155
+ def run_coefficient_test (self , test , f_flag ):
156
+ base_test_case = BaseTestCase (
157
+ treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
158
+ outcome_variable = next (self .scenario .variables [v ] for v in test ["expected_effect" ]),
159
+ effect = test .get ("effect" , "direct" ),
160
+ )
161
+ assert len (test ["expected_effect" ]) == 1 , "Can only have one expected effect."
162
+ causal_test_case = CausalTestCase (
163
+ base_test_case = base_test_case ,
164
+ expected_causal_effect = next (
165
+ self .effects [effect ] for variable , effect in test ["expected_effect" ].items ()
166
+ ),
167
+ estimate_type = "coefficient" ,
168
+ effect_modifier_configuration = {
169
+ self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])
170
+ },
171
+ )
172
+ result = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
173
+ msg = (
174
+ f"Executing test: { test ['name' ]} \n "
175
+ + f" { causal_test_case } \n "
176
+ + " "
177
+ + ("\n " ).join (str (result [1 ]).split ("\n " ))
178
+ + "==============\n "
179
+ + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
180
+ )
181
+ return msg
182
+
183
+ def run_ate_test (self , test , f_flag ):
184
+ if "sample_size" in test :
185
+ sample_size = test ["sample_size" ]
186
+ else :
187
+ sample_size = 5
188
+ if "target_ks_score" in test :
189
+ target_ks_score = test ["target_ks_score" ]
190
+ else :
191
+ target_ks_score = 0.05
192
+ abstract_test = self ._create_abstract_test_case (test , self .mutates , self .effects )
193
+ concrete_tests , _ = abstract_test .generate_concrete_tests (sample_size = sample_size , target_ks_score = target_ks_score )
194
+ failures , _ = self ._execute_tests (concrete_tests , test , f_flag )
195
+
196
+ msg = (
197
+ f"Executing test: { test ['name' ]} \n "
198
+ + " abstract_test \n "
199
+ + f" { abstract_test } \n "
200
+ + f" { abstract_test .treatment_variable .name } ,"
201
+ + f" { abstract_test .treatment_variable .distribution } \n "
202
+ + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
203
+ + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
204
+ )
205
+ return msg
206
+
196
207
def _execute_tests (self , concrete_tests , test , f_flag ):
197
208
failures = 0
198
209
details = []
@@ -222,7 +233,7 @@ def _populate_metas(self):
222
233
meta .populate (self .data )
223
234
224
235
def _execute_test_case (
225
- self , causal_test_case : CausalTestCase , test : Iterable [Mapping ], f_flag : bool
236
+ self , causal_test_case : CausalTestCase , test : Iterable [Mapping ], f_flag : bool
226
237
) -> (bool , CausalTestResult ):
227
238
"""Executes a singular test case, prints the results and returns the test case result
228
239
:param causal_test_case: The concrete test case to be executed
@@ -262,7 +273,7 @@ def _execute_test_case(
262
273
return failed , causal_test_result
263
274
264
275
def _setup_test (
265
- self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ] = None
276
+ self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ] = None
266
277
) -> tuple [CausalTestEngine , Estimator ]:
267
278
"""Create the necessary inputs for a single test case
268
279
:param causal_test_case: The concrete test case to be executed
@@ -347,7 +358,7 @@ def get_args(test_args=None) -> argparse.Namespace:
347
358
parser .add_argument (
348
359
"-w" ,
349
360
help = "Specify to overwrite any existing output files. This can lead to the loss of existing outputs if not "
350
- "careful" ,
361
+ "careful" ,
351
362
action = "store_true" ,
352
363
)
353
364
parser .add_argument (
0 commit comments