20
20
from causal_testing .specification .causal_specification import CausalSpecification
21
21
from causal_testing .specification .scenario import Scenario
22
22
from causal_testing .specification .variable import Input , Meta , Output
23
- from causal_testing .testing .base_test_case import BaseTestCase
24
23
from causal_testing .testing .causal_test_case import CausalTestCase
24
+ from causal_testing .testing .causal_test_result import CausalTestResult
25
25
from causal_testing .testing .causal_test_engine import CausalTestEngine
26
26
from causal_testing .testing .estimators import Estimator
27
+ from causal_testing .testing .base_test_case import BaseTestCase
27
28
28
29
logger = logging .getLogger (__name__ )
29
30
@@ -41,7 +42,7 @@ class JsonUtility:
41
42
:attr {Meta} metas: Causal variables representing metavariables.
42
43
:attr {pd.DataFrame}: Pandas DataFrame containing runtime data.
43
44
:attr {dict} test_plan: Dictionary containing the key value pairs from the loaded json test plan.
44
- :attr {Scenario} modelling_scenario :
45
+ :attr {Scenario} scenario :
45
46
:attr {CausalSpecification} causal_specification:
46
47
"""
47
48
@@ -75,6 +76,33 @@ def setup(self, scenario: Scenario):
75
76
self ._json_parse ()
76
77
self ._populate_metas ()
77
78
79
+ def _create_abstract_test_case (self , test , mutates , effects ):
80
+ assert len (test ["mutations" ]) == 1
81
+ treatment_var = next (self .scenario .variables [v ] for v in test ["mutations" ])
82
+
83
+ if not treatment_var .distribution :
84
+ fitter = Fitter (self .data [treatment_var .name ], distributions = get_common_distributions ())
85
+ fitter .fit ()
86
+ (dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
87
+ treatment_var .distribution = getattr (scipy .stats , dist )(** params )
88
+ self ._append_to_file (treatment_var .name + f" { dist } ({ params } )" , logging .INFO )
89
+
90
+ abstract_test = AbstractCausalTestCase (
91
+ scenario = self .scenario ,
92
+ intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
93
+ treatment_variable = treatment_var ,
94
+ expected_causal_effect = {
95
+ self .scenario .variables [variable ]: effects [effect ]
96
+ for variable , effect in test ["expected_effect" ].items ()
97
+ },
98
+ effect_modifiers = {self .scenario .variables [v ] for v in test ["effect_modifiers" ]}
99
+ if "effect_modifiers" in test
100
+ else {},
101
+ estimate_type = test ["estimate_type" ],
102
+ effect = test .get ("effect" , "total" ),
103
+ )
104
+ return abstract_test
105
+
78
106
def run_json_tests (self , effects : dict , estimators : dict , f_flag : bool = False , mutates : dict = None ):
79
107
"""Runs and evaluates each test case specified in the JSON input
80
108
@@ -84,23 +112,52 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
84
112
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
85
113
"""
86
114
failures = 0
115
+ msg = ""
87
116
for test in self .test_plan ["tests" ]:
88
117
if "skip" in test and test ["skip" ]:
89
118
continue
90
119
test ["estimator" ] = estimators [test ["estimator" ]]
91
120
if "mutations" in test :
92
- abstract_test = self ._create_abstract_test_case (test , mutates , effects )
93
-
94
- concrete_tests , dummy = abstract_test .generate_concrete_tests (5 , 0.05 )
95
- failures = self ._execute_tests (concrete_tests , test , f_flag )
96
- msg = (
97
- f"Executing test: { test ['name' ]} \n "
98
- + "abstract_test\n "
99
- + f"{ abstract_test } \n "
100
- + f"{ abstract_test .treatment_variable .name } ,{ abstract_test .treatment_variable .distribution } \n "
101
- + f"Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
102
- + f"{ failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
103
- )
121
+ if test ["estimate_type" ] == "coefficient" :
122
+ base_test_case = BaseTestCase (
123
+ treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
124
+ outcome_variable = next (self .scenario .variables [v ] for v in test ["expected_effect" ]),
125
+ effect = test .get ("effect" , "direct" ),
126
+ )
127
+ assert len (test ["expected_effect" ]) == 1 , "Can only have one expected effect."
128
+ causal_test_case = CausalTestCase (
129
+ base_test_case = base_test_case ,
130
+ expected_causal_effect = next (
131
+ effects [effect ] for variable , effect in test ["expected_effect" ].items ()
132
+ ),
133
+ estimate_type = "coefficient" ,
134
+ effect_modifier_configuration = {
135
+ self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])
136
+ },
137
+ )
138
+ result = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
139
+ msg = (
140
+ f"Executing test: { test ['name' ]} \n "
141
+ + f" { causal_test_case } \n "
142
+ + " "
143
+ + ("\n " ).join (str (result [1 ]).split ("\n " ))
144
+ + "==============\n "
145
+ + f" Result: { 'FAILED' if result [0 ] else 'Passed' } "
146
+ )
147
+ else :
148
+ abstract_test = self ._create_abstract_test_case (test , mutates , effects )
149
+ concrete_tests , _ = abstract_test .generate_concrete_tests (5 , 0.05 )
150
+ failures , _ = self ._execute_tests (concrete_tests , test , f_flag )
151
+
152
+ msg = (
153
+ f"Executing test: { test ['name' ]} \n "
154
+ + " abstract_test \n "
155
+ + f" { abstract_test } \n "
156
+ + f" { abstract_test .treatment_variable .name } ,"
157
+ + f" { abstract_test .treatment_variable .distribution } \n "
158
+ + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
159
+ + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
160
+ )
104
161
self ._append_to_file (msg , logging .INFO )
105
162
else :
106
163
outcome_variable = next (
@@ -118,47 +175,28 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
118
175
treatment_value = test ["treatment_value" ],
119
176
estimate_type = test ["estimate_type" ],
120
177
)
121
- if self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag ):
122
- result = "failed"
123
- else :
124
- result = "passed"
178
+ failed , _ = self ._execute_test_case (causal_test_case = causal_test_case , test = test , f_flag = f_flag )
125
179
126
180
msg = (
127
181
f"Executing concrete test: { test ['name' ]} \n "
128
182
+ f"treatment variable: { test ['treatment_variable' ]} \n "
129
183
+ f"outcome_variable = { outcome_variable } \n "
130
184
+ f"control value = { test ['control_value' ]} , treatment value = { test ['treatment_value' ]} \n "
131
- + f"result - { result } "
185
+ + f"Result: { 'FAILED' if failed else 'Passed' } "
132
186
)
133
187
self ._append_to_file (msg , logging .INFO )
134
188
135
- def _create_abstract_test_case (self , test , mutates , effects ):
136
- assert len (test ["mutations" ]) == 1
137
- abstract_test = AbstractCausalTestCase (
138
- scenario = self .scenario ,
139
- intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
140
- treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
141
- expected_causal_effect = {
142
- self .scenario .variables [variable ]: effects [effect ]
143
- for variable , effect in test ["expected_effect" ].items ()
144
- },
145
- effect_modifiers = {self .scenario .variables [v ] for v in test ["effect_modifiers" ]}
146
- if "effect_modifiers" in test
147
- else {},
148
- estimate_type = test ["estimate_type" ],
149
- effect = test .get ("effect" , "total" ),
150
- )
151
- return abstract_test
152
-
153
189
def _execute_tests (self , concrete_tests , test , f_flag ):
154
190
failures = 0
191
+ details = []
155
192
if "formula" in test :
156
193
self ._append_to_file (f"Estimator formula used for test: { test ['formula' ]} " )
157
194
for concrete_test in concrete_tests :
158
- failed = self ._execute_test_case (concrete_test , test , f_flag )
195
+ failed , result = self ._execute_test_case (concrete_test , test , f_flag )
196
+ details .append (result )
159
197
if failed :
160
198
failures += 1
161
- return failures
199
+ return failures , details
162
200
163
201
def _json_parse (self ):
164
202
"""Parse a JSON input file into inputs, outputs, metas and a test plan"""
@@ -175,15 +213,10 @@ def _populate_metas(self):
175
213
"""
176
214
for meta in self .scenario .variables_of_type (Meta ):
177
215
meta .populate (self .data )
178
- for var in self .scenario .variables_of_type (Meta ).union (self .scenario .variables_of_type (Output )):
179
- if not var .distribution :
180
- fitter = Fitter (self .data [var .name ], distributions = get_common_distributions ())
181
- fitter .fit ()
182
- (dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
183
- var .distribution = getattr (scipy .stats , dist )(** params )
184
- self ._append_to_file (var .name + f" { dist } ({ params } )" , logging .INFO )
185
-
186
- def _execute_test_case (self , causal_test_case : CausalTestCase , test : Iterable [Mapping ], f_flag : bool ) -> bool :
216
+
217
+ def _execute_test_case (
218
+ self , causal_test_case : CausalTestCase , test : Iterable [Mapping ], f_flag : bool
219
+ ) -> (bool , CausalTestResult ):
187
220
"""Executes a singular test case, prints the results and returns the test case result
188
221
:param causal_test_case: The concrete test case to be executed
189
222
:param test: Single JSON test definition stored in a mapping (dict)
@@ -193,7 +226,10 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
193
226
:rtype: bool
194
227
"""
195
228
failed = False
196
- causal_test_engine , estimation_model = self ._setup_test (causal_test_case , test )
229
+
230
+ causal_test_engine , estimation_model = self ._setup_test (
231
+ causal_test_case , test , test ["conditions" ] if "conditions" in test else None
232
+ )
197
233
causal_test_result = causal_test_engine .execute_test (
198
234
estimation_model , causal_test_case , estimate_type = causal_test_case .estimate_type
199
235
)
@@ -216,18 +252,25 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
216
252
)
217
253
failed = True
218
254
logger .warning (" FAILED- expected %s, got %s" , causal_test_case .expected_causal_effect , result_string )
219
- return failed
255
+ return failed , causal_test_result
220
256
221
- def _setup_test (self , causal_test_case : CausalTestCase , test : Mapping ) -> tuple [CausalTestEngine , Estimator ]:
257
+ def _setup_test (
258
+ self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ] = None
259
+ ) -> tuple [CausalTestEngine , Estimator ]:
222
260
"""Create the necessary inputs for a single test case
223
261
:param causal_test_case: The concrete test case to be executed
224
262
:param test: Single JSON test definition stored in a mapping (dict)
263
+ :param conditions: A list of conditions which should be applied to the
264
+ data. Conditions should be in the query format detailed at
265
+ https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
225
266
:returns:
226
267
- causal_test_engine - Test Engine instance for the test being run
227
268
- estimation_model - Estimator instance for the test being run
228
269
"""
229
270
230
- data_collector = ObservationalDataCollector (self .scenario , self .data )
271
+ data_collector = ObservationalDataCollector (
272
+ self .scenario , self .data .query (" & " .join (conditions )) if conditions else self .data
273
+ )
231
274
causal_test_engine = CausalTestEngine (self .causal_specification , data_collector , index_col = 0 )
232
275
233
276
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
0 commit comments