19
19
from causal_testing .specification .causal_specification import CausalSpecification
20
20
from causal_testing .specification .scenario import Scenario
21
21
from causal_testing .specification .variable import Input , Meta , Output
22
+ from causal_testing .testing .base_test_case import BaseTestCase
22
23
from causal_testing .testing .causal_test_case import CausalTestCase
23
24
from causal_testing .testing .causal_test_engine import CausalTestEngine
24
25
from causal_testing .testing .estimators import Estimator
@@ -73,6 +74,60 @@ def setup(self, scenario: Scenario):
73
74
self ._json_parse ()
74
75
self ._populate_metas ()
75
76
77
+ def run_json_tests (self , effects : dict , estimators : dict , f_flag : bool = False , mutates : dict = None ):
78
+ """Runs and evaluates each test case specified in the JSON input
79
+
80
+ :param effects: Dictionary mapping effect class instances to string representations.
81
+ :param mutates: Dictionary mapping mutation functions to string representations.
82
+ :param estimators: Dictionary mapping estimator classes to string representations.
83
+ :param f_flag: Failure flag that if True the script will stop executing when a test fails.
84
+ """
85
+ failures = 0
86
+ for test in self .test_plan ["tests" ]:
87
+ if "skip" in test and test ["skip" ]:
88
+ continue
89
+ if "mutates" in test :
90
+ abstract_test = self ._create_abstract_test_case (test , mutates , effects )
91
+
92
+ concrete_tests , dummy = abstract_test .generate_concrete_tests (5 , 0.05 )
93
+ failures = self ._execute_tests (concrete_tests , estimators , test , f_flag )
94
+ msg = (
95
+ f"Executing test: { test ['name' ]} \n "
96
+ + "abstract_test\n "
97
+ + f"{ abstract_test } \n "
98
+ + f"{ abstract_test .treatment_variable .name } ,{ abstract_test .treatment_variable .distribution } \n "
99
+ + f"Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
100
+ + f"{ failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
101
+ )
102
+ self ._append_to_file (msg , logging .INFO )
103
+ else :
104
+ outcome_variable = next (iter (test ['expectedEffect' ])) # Take first key from dictionary of expected effect
105
+ expected_effect = effects [test ['expectedEffect' ][outcome_variable ]]
106
+ base_test_case = BaseTestCase (treatment_variable = self .variables ["inputs" ][test ["treatment_variable" ]],
107
+ outcome_variable = self .variables ["outputs" ][outcome_variable ])
108
+
109
+ causal_test_case = CausalTestCase (base_test_case = base_test_case ,
110
+ expected_causal_effect = expected_effect ,
111
+ control_value = test ["control_value" ],
112
+ treatment_value = test ["treatment_value" ],
113
+ estimate_type = test ["estimate_type" ])
114
+
115
+
116
+ if self ._execute_test_case (causal_test_case = causal_test_case ,
117
+ estimator = estimators [test ["estimator" ]],
118
+ f_flag = f_flag ):
119
+ result = "failed"
120
+ else :
121
+ result = "passed"
122
+
123
+ msg = (
124
+ f"Executing test: { test ['name' ]} \n "
125
+ + f"treatment variable: { test ['treatment_variable' ]} \n "
126
+ + f"outcome_variable = { outcome_variable } \n "
127
+ + f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
128
+ + f"result - { result } \n "
129
+ )
130
+ self ._append_to_file (msg , logging .INFO )
76
131
def _create_abstract_test_case (self , test , mutates , effects ):
77
132
assert len (test ["mutations" ]) == 1
78
133
abstract_test = AbstractCausalTestCase (
@@ -91,32 +146,6 @@ def _create_abstract_test_case(self, test, mutates, effects):
91
146
)
92
147
return abstract_test
93
148
94
- def generate_tests (self , effects : dict , mutates : dict , estimators : dict , f_flag : bool ):
95
- """Runs and evaluates each test case specified in the JSON input
96
-
97
- :param effects: Dictionary mapping effect class instances to string representations.
98
- :param mutates: Dictionary mapping mutation functions to string representations.
99
- :param estimators: Dictionary mapping estimator classes to string representations.
100
- :param f_flag: Failure flag that if True the script will stop executing when a test fails.
101
- """
102
- failures = 0
103
- for test in self .test_plan ["tests" ]:
104
- if "skip" in test and test ["skip" ]:
105
- continue
106
- abstract_test = self ._create_abstract_test_case (test , mutates , effects )
107
-
108
- concrete_tests , dummy = abstract_test .generate_concrete_tests (5 , 0.05 )
109
- failures = self ._execute_tests (concrete_tests , estimators , test , f_flag )
110
- msg = (
111
- f"Executing test: { test ['name' ]} \n "
112
- + "abstract_test \n "
113
- + f"{ abstract_test } \n "
114
- + f"{ abstract_test .treatment_variable .name } ,{ abstract_test .treatment_variable .distribution } \n "
115
- + f"Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
116
- + f"{ failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
117
- )
118
- self ._append_to_file (msg , logging .INFO )
119
-
120
149
def _execute_tests (self , concrete_tests , estimators , test , f_flag ):
121
150
failures = 0
122
151
for concrete_test in concrete_tests :
@@ -157,7 +186,6 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
157
186
:rtype: bool
158
187
"""
159
188
failed = False
160
-
161
189
causal_test_engine , estimation_model = self ._setup_test (causal_test_case , estimator )
162
190
causal_test_result = causal_test_engine .execute_test (
163
191
estimation_model , causal_test_case , estimate_type = causal_test_case .estimate_type
@@ -228,7 +256,7 @@ def _append_to_file(self, line: str, log_level: int = None):
228
256
"""
229
257
with open (self .output_path , "a" , encoding = "utf-8" ) as f :
230
258
f .write (
231
- line + " \n " ,
259
+ line ,
232
260
)
233
261
if log_level :
234
262
logger .log (level = log_level , msg = line )
0 commit comments