@@ -87,6 +87,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
87
87
for test in self .test_plan ["tests" ]:
88
88
if "skip" in test and test ["skip" ]:
89
89
continue
90
+ test ["estimator" ] = estimators [test ["estimator" ]]
90
91
if "mutations" in test :
91
92
abstract_test = self ._create_abstract_test_case (test , mutates , effects )
92
93
@@ -117,9 +118,8 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
117
118
treatment_value = test ["treatment_value" ],
118
119
estimate_type = test ["estimate_type" ],
119
120
)
120
-
121
121
if self ._execute_test_case (
122
- causal_test_case = causal_test_case , estimator = estimators [ test [ "estimator" ]] , f_flag = f_flag
122
+ causal_test_case = causal_test_case , test = test , f_flag = f_flag
123
123
):
124
124
result = "failed"
125
125
else :
@@ -130,7 +130,7 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
130
130
+ f"treatment variable: { test ['treatment_variable' ]} \n "
131
131
+ f"outcome_variable = { outcome_variable } \n "
132
132
+ f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
133
- + f"result - { result } \n "
133
+ + f"result - { result } "
134
134
)
135
135
self ._append_to_file (msg , logging .INFO )
136
136
@@ -154,7 +154,6 @@ def _create_abstract_test_case(self, test, mutates, effects):
154
154
155
155
def _execute_tests (self , concrete_tests , estimators , test , f_flag ):
156
156
failures = 0
157
- test ["estimator" ] = estimators [test ["estimator" ]]
158
157
if "formula" in test :
159
158
self ._append_to_file (f"Estimator formula used for test: { test ['formula' ]} " )
160
159
for concrete_test in concrete_tests :
@@ -203,7 +202,6 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
203
202
204
203
test_passes = causal_test_case .expected_causal_effect .apply (causal_test_result )
205
204
206
- result_string = str ()
207
205
if causal_test_result .ci_low () and causal_test_result .ci_high ():
208
206
result_string = (
209
207
f"{ causal_test_result .ci_low ()} < { causal_test_result .test_value .value } < "
@@ -248,7 +246,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> tuple[
248
246
}
249
247
if "formula" in test :
250
248
estimator_kwargs ["formula" ] = test ["formula" ]
251
-
252
249
estimation_model = test ["estimator" ](** estimator_kwargs )
253
250
return causal_test_engine , estimation_model
254
251
@@ -261,7 +258,7 @@ def _append_to_file(self, line: str, log_level: int = None):
261
258
"""
262
259
with open (self .output_path , "a" , encoding = "utf-8" ) as f :
263
260
f .write (
264
- line + "\n " ,
261
+ line + "\n "
265
262
)
266
263
if log_level :
267
264
logger .log (level = log_level , msg = line )
0 commit comments