@@ -55,8 +55,6 @@ def __init__(self, output_path: str, output_overwrite: bool = False):
55
55
self .causal_specification = None
56
56
self .output_path = Path (output_path )
57
57
self .check_file_exists (self .output_path , output_overwrite )
58
- self .effects = None
59
- self .mutates = None
60
58
61
59
def set_paths (self , json_path : str , dag_path : str , data_paths : str ):
62
60
"""
@@ -113,17 +111,15 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
113
111
:param estimators: Dictionary mapping estimator classes to string representations.
114
112
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
115
113
"""
116
- self .effects = effects
117
- self .mutates = mutates
118
114
for test in self .test_plan ["tests" ]:
119
115
if "skip" in test and test ["skip" ]:
120
116
continue
121
117
test ["estimator" ] = estimators [test ["estimator" ]]
122
118
if "mutations" in test :
123
119
if test ["estimate_type" ] == "coefficient" :
124
- msg = self .run_coefficient_test (test = test , f_flag = f_flag )
120
+ msg = self ._run_coefficient_test (test = test , f_flag = f_flag , effects = effects )
125
121
else :
126
- msg = self .run_ate_test (test = test , f_flag = f_flag )
122
+ msg = self ._run_ate_test (test = test , f_flag = f_flag , effects = effects , mutates = mutates )
127
123
self ._append_to_file (msg , logging .INFO )
128
124
else :
129
125
outcome_variable = next (
@@ -152,11 +148,12 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
152
148
)
153
149
self ._append_to_file (msg , logging .INFO )
154
150
155
- def run_coefficient_test (self , test : dict , f_flag : bool ):
151
+ def _run_coefficient_test (self , test : dict , f_flag : bool , effects : dict ):
156
152
"""Builds structures and runs test case for tests with an estimate_type of 'coefficient'.
157
153
158
154
:param test: Single JSON test definition stored in a mapping (dict)
159
155
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
156
+ :param effects: Dictionary mapping effect class instances to string representations.
160
157
:return: String containing the message to be outputted
161
158
"""
162
159
base_test_case = BaseTestCase (
@@ -167,7 +164,7 @@ def run_coefficient_test(self, test: dict, f_flag: bool):
167
164
assert len (test ["expected_effect" ]) == 1 , "Can only have one expected effect."
168
165
causal_test_case = CausalTestCase (
169
166
base_test_case = base_test_case ,
170
- expected_causal_effect = next (self . effects [effect ] for variable , effect in test ["expected_effect" ].items ()),
167
+ expected_causal_effect = next (effects [effect ] for variable , effect in test ["expected_effect" ].items ()),
171
168
estimate_type = "coefficient" ,
172
169
effect_modifier_configuration = {self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])},
173
170
)
@@ -182,11 +179,13 @@ def run_coefficient_test(self, test: dict, f_flag: bool):
182
179
)
183
180
return msg
184
181
185
- def run_ate_test (self , test : dict , f_flag : bool ):
182
+ def _run_ate_test (self , test : dict , f_flag : bool , effects : dict , mutates : dict ):
186
183
"""Builds structures and runs test case for tests with an estimate_type of 'ate'.
187
184
188
185
:param test: Single JSON test definition stored in a mapping (dict)
189
186
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
187
+ :param effects: Dictionary mapping effect class instances to string representations.
188
+ :param mutates: Dictionary mapping mutation functions to string representations.
190
189
:return: String containing the message to be outputted
191
190
"""
192
191
if "sample_size" in test :
@@ -197,7 +196,7 @@ def run_ate_test(self, test: dict, f_flag: bool):
197
196
target_ks_score = test ["target_ks_score" ]
198
197
else :
199
198
target_ks_score = 0.05
200
- abstract_test = self ._create_abstract_test_case (test , self . mutates , self . effects )
199
+ abstract_test = self ._create_abstract_test_case (test , mutates , effects )
201
200
concrete_tests , _ = abstract_test .generate_concrete_tests (
202
201
sample_size = sample_size , target_ks_score = target_ks_score
203
202
)
0 commit comments