@@ -123,54 +123,15 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
123
123
:param estimators: Dictionary mapping estimator classes to string representations.
124
124
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
125
125
"""
126
- failures = 0
127
- msg = ""
128
126
for test in self .test_plan ["tests" ]:
129
127
if "skip" in test and test ["skip" ]:
130
128
continue
131
129
test ["estimator" ] = estimators [test ["estimator" ]]
132
130
if "mutations" in test :
133
131
if test ["estimate_type" ] == "coefficient" :
134
- base_test_case = BaseTestCase (
135
- treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
136
- outcome_variable = next (self .scenario .variables [v ] for v in test ["expected_effect" ]),
137
- effect = test .get ("effect" , "direct" ),
138
- )
139
- assert len (test ["expected_effect" ]) == 1 , "Can only have one expected effect."
140
- causal_test_case = CausalTestCase (
141
- base_test_case = base_test_case ,
142
- expected_causal_effect = next (
143
- effects [effect ] for variable , effect in test ["expected_effect" ].items ()
144
- ),
145
- estimate_type = "coefficient" ,
146
- effect_modifier_configuration = {
147
- self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])
148
- },
149
- )
150
- result = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
151
- msg = (
152
- f"Executing test: { test ['name' ]} \n "
153
- + f" { causal_test_case } \n "
154
- + " "
155
- + ("\n " ).join (str (result [1 ]).split ("\n " ))
156
- + "==============\n "
157
- + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
158
- )
159
- print (msg )
132
+ msg = self ._run_coefficient_test (test = test , f_flag = f_flag , effects = effects )
160
133
else :
161
- abstract_test = self ._create_abstract_test_case (test , mutates , effects )
162
- concrete_tests , _ = abstract_test .generate_concrete_tests (5 , 0.05 )
163
- failures , _ = self ._execute_tests (concrete_tests , test , f_flag )
164
-
165
- msg = (
166
- f"Executing test: { test ['name' ]} \n "
167
- + " abstract_test \n "
168
- + f" { abstract_test } \n "
169
- + f" { abstract_test .treatment_variable .name } ,"
170
- + f" { abstract_test .treatment_variable .distribution } \n "
171
- + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
172
- + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
173
- )
134
+ msg = self ._run_ate_test (test = test , f_flag = f_flag , effects = effects , mutates = mutates )
174
135
self ._append_to_file (msg , logging .INFO )
175
136
else :
176
137
outcome_variable = next (
@@ -197,8 +158,74 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
197
158
+ f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
198
159
+ f"Result: { 'FAILED' if failed else 'Passed' } "
199
160
)
161
+ print (msg )
200
162
self ._append_to_file (msg , logging .INFO )
201
163
164
+ def _run_coefficient_test (self , test : dict , f_flag : bool , effects : dict ):
165
+ """Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
166
+
167
+ :param test: Single JSON test definition stored in a mapping (dict)
168
+ :param f_flag: Failure flag that if True the script will stop executing when a test fails.
169
+ :param effects: Dictionary mapping effect class instances to string representations.
170
+ :return: String containing the message to be outputted
171
+ """
172
+ base_test_case = BaseTestCase (
173
+ treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
174
+ outcome_variable = next (self .scenario .variables [v ] for v in test ["expected_effect" ]),
175
+ effect = test .get ("effect" , "direct" ),
176
+ )
177
+ assert len (test ["expected_effect" ]) == 1 , "Can only have one expected effect."
178
+ causal_test_case = CausalTestCase (
179
+ base_test_case = base_test_case ,
180
+ expected_causal_effect = next (effects [effect ] for variable , effect in test ["expected_effect" ].items ()),
181
+ estimate_type = "coefficient" ,
182
+ effect_modifier_configuration = {self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])},
183
+ )
184
+ result = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
185
+ msg = (
186
+ f"Executing test: { test ['name' ]} \n "
187
+ + f" { causal_test_case } \n "
188
+ + " "
189
+ + ("\n " ).join (str (result [1 ]).split ("\n " ))
190
+ + "==============\n "
191
+ + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
192
+ )
193
+ return msg
194
+
195
+ def _run_ate_test (self , test : dict , f_flag : bool , effects : dict , mutates : dict ):
196
+ """Builds structures and runs test case for tests with an estimate_type of 'ate'.
197
+
198
+ :param test: Single JSON test definition stored in a mapping (dict)
199
+ :param f_flag: Failure flag that if True the script will stop executing when a test fails.
200
+ :param effects: Dictionary mapping effect class instances to string representations.
201
+ :param mutates: Dictionary mapping mutation functions to string representations.
202
+ :return: String containing the message to be outputted
203
+ """
204
+ if "sample_size" in test :
205
+ sample_size = test ["sample_size" ]
206
+ else :
207
+ sample_size = 5
208
+ if "target_ks_score" in test :
209
+ target_ks_score = test ["target_ks_score" ]
210
+ else :
211
+ target_ks_score = 0.05
212
+ abstract_test = self ._create_abstract_test_case (test , mutates , effects )
213
+ concrete_tests , _ = abstract_test .generate_concrete_tests (
214
+ sample_size = sample_size , target_ks_score = target_ks_score
215
+ )
216
+ failures , _ = self ._execute_tests (concrete_tests , test , f_flag )
217
+
218
+ msg = (
219
+ f"Executing test: { test ['name' ]} \n "
220
+ + " abstract_test \n "
221
+ + f" { abstract_test } \n "
222
+ + f" { abstract_test .treatment_variable .name } ,"
223
+ + f" { abstract_test .treatment_variable .distribution } \n "
224
+ + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
225
+ + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
226
+ )
227
+ return msg
228
+
202
229
def _execute_tests (self , concrete_tests , test , f_flag ):
203
230
failures = 0
204
231
details = []
0 commit comments