20
20
from causal_testing .specification .causal_specification import CausalSpecification
21
21
from causal_testing .specification .scenario import Scenario
22
22
from causal_testing .specification .variable import Input , Meta , Output
23
- from causal_testing .testing .base_test_case import BaseTestCase
24
23
from causal_testing .testing .causal_test_case import CausalTestCase
25
24
from causal_testing .testing .causal_test_engine import CausalTestEngine
26
25
from causal_testing .testing .estimators import Estimator
26
+ from causal_testing .testing .base_test_case import BaseTestCase
27
27
28
28
logger = logging .getLogger (__name__ )
29
29
@@ -41,7 +41,7 @@ class JsonUtility:
41
41
:attr {Meta} metas: Causal variables representing metavariables.
42
42
:attr {pd.DataFrame} data: Pandas DataFrame containing runtime data.
43
43
:attr {dict} test_plan: Dictionary containing the key value pairs from the loaded json test plan.
44
- :attr {Scenario} modelling_scenario :
44
+ :attr {Scenario} scenario :
45
45
:attr {CausalSpecification} causal_specification:
46
46
"""
47
47
@@ -75,6 +75,32 @@ def setup(self, scenario: Scenario):
75
75
self ._json_parse ()
76
76
self ._populate_metas ()
77
77
78
def _create_abstract_test_case(self, test, mutates, effects):
    """Build the abstract causal test case described by one JSON test entry.

    If the treatment variable has no distribution yet, one is fitted to the
    matching column of ``self.data`` and the selected fit is recorded through
    ``self._append_to_file``.

    :param test: Single JSON test definition stored in a mapping (dict). The
                 keys "mutations", "expected_effect" and "estimate_type" are
                 read; "effect_modifiers" and "effect" are optional ("effect"
                 falls back to "total").
    :param mutates: Mapping from the mutation name used in the JSON file to a
                    callable; each callable is invoked with the mutated
                    variable's name and its result is passed on as an
                    intervention constraint.
    :param effects: Mapping from the effect name used in the JSON file to the
                    object passed on as the expected causal effect.
    :returns: The constructed AbstractCausalTestCase.
    :raises AssertionError: If the test does not define exactly one mutation.
    """
    assert len(test["mutations"]) == 1, "Can only have one mutation per test."
    treatment_var = next(self.scenario.variables[v] for v in test["mutations"])

    # Only fit a distribution when the scenario did not already supply one.
    if not treatment_var.distribution:
        fitter = Fitter(self.data[treatment_var.name], distributions=get_common_distributions())
        fitter.fit()
        # get_best() yields a single-entry mapping {distribution_name: parameters}.
        (dist, params) = next(iter(fitter.get_best(method="sumsquare_error").items()))
        treatment_var.distribution = getattr(scipy.stats, dist)(**params)
        self._append_to_file(treatment_var.name + f" {dist} ({params} )", logging.INFO)

    # Always hand over a set: the previous fallback of ``{}`` was an empty
    # dict, which is inconsistent with the set built when the key is present.
    effect_modifiers = {self.scenario.variables[v] for v in test.get("effect_modifiers", ())}

    return AbstractCausalTestCase(
        scenario=self.scenario,
        intervention_constraints=[mutates[v](k) for k, v in test["mutations"].items()],
        treatment_variable=treatment_var,
        expected_causal_effect={
            self.scenario.variables[variable]: effects[effect]
            for variable, effect in test["expected_effect"].items()
        },
        effect_modifiers=effect_modifiers,
        estimate_type=test["estimate_type"],
        effect=test.get("effect", "total"),
    )
103
+
78
104
def run_json_tests (self , effects : dict , estimators : dict , f_flag : bool = False , mutates : dict = None ):
79
105
"""Runs and evaluates each test case specified in the JSON input
80
106
@@ -84,23 +110,51 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
84
110
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
85
111
"""
86
112
failures = 0
113
+ msg = ""
87
114
for test in self .test_plan ["tests" ]:
88
115
if "skip" in test and test ["skip" ]:
89
116
continue
90
117
test ["estimator" ] = estimators [test ["estimator" ]]
91
118
if "mutations" in test :
92
- abstract_test = self ._create_abstract_test_case (test , mutates , effects )
93
-
94
- concrete_tests , dummy = abstract_test .generate_concrete_tests (5 , 0.05 )
95
- failures = self ._execute_tests (concrete_tests , test , f_flag )
96
- msg = (
97
- f"Executing test: { test ['name' ]} \n "
98
- + "abstract_test\n "
99
- + f"{ abstract_test } \n "
100
- + f"{ abstract_test .treatment_variable .name } ,{ abstract_test .treatment_variable .distribution } \n "
101
- + f"Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
102
- + f"{ failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
103
- )
119
+ if test ["estimate_type" ] == "coefficient" :
120
+ base_test_case = BaseTestCase (
121
+ treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
122
+ outcome_variable = next (self .scenario .variables [v ] for v in test ["expected_effect" ]),
123
+ effect = test .get ("effect" , "direct" ),
124
+ )
125
+ assert len (test ["expected_effect" ]) == 1 , "Can only have one expected effect."
126
+ concrete_tests = [
127
+ CausalTestCase (
128
+ base_test_case = base_test_case ,
129
+ expected_causal_effect = next (
130
+ effects [effect ] for variable , effect in test ["expected_effect" ].items ()
131
+ ),
132
+ estimate_type = "coefficient" ,
133
+ effect_modifier_configuration = {
134
+ self .scenario .variables [v ] for v in test .get ("effect_modifiers" , [])
135
+ },
136
+ )
137
+ ]
138
+ failures = self ._execute_tests (concrete_tests , test , f_flag )
139
+ msg = (
140
+ f"Executing test: { test ['name' ]} \n "
141
+ + f" { concrete_tests [0 ]} \n "
142
+ + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
143
+ )
144
+ else :
145
+ abstract_test = self ._create_abstract_test_case (test , mutates , effects )
146
+ concrete_tests , dummy = abstract_test .generate_concrete_tests (5 , 0.05 )
147
+ failures = self ._execute_tests (concrete_tests , test , f_flag )
148
+
149
+ msg = (
150
+ f"Executing test: { test ['name' ]} \n "
151
+ + " abstract_test \n "
152
+ + f" { abstract_test } \n "
153
+ + f" { abstract_test .treatment_variable .name } ,"
154
+ + f" { abstract_test .treatment_variable .distribution } \n "
155
+ + f" Number of concrete tests for test case: { str (len (concrete_tests ))} \n "
156
+ + f" { failures } /{ len (concrete_tests )} failed for { test ['name' ]} "
157
+ )
104
158
self ._append_to_file (msg , logging .INFO )
105
159
else :
106
160
outcome_variable = next (
@@ -132,24 +186,6 @@ def run_json_tests(self, effects: dict, estimators: dict, f_flag: bool = False,
132
186
)
133
187
self ._append_to_file (msg , logging .INFO )
134
188
135
- def _create_abstract_test_case (self , test , mutates , effects ):
136
- assert len (test ["mutations" ]) == 1
137
- abstract_test = AbstractCausalTestCase (
138
- scenario = self .scenario ,
139
- intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
140
- treatment_variable = next (self .scenario .variables [v ] for v in test ["mutations" ]),
141
- expected_causal_effect = {
142
- self .scenario .variables [variable ]: effects [effect ]
143
- for variable , effect in test ["expected_effect" ].items ()
144
- },
145
- effect_modifiers = {self .scenario .variables [v ] for v in test ["effect_modifiers" ]}
146
- if "effect_modifiers" in test
147
- else {},
148
- estimate_type = test ["estimate_type" ],
149
- effect = test .get ("effect" , "total" ),
150
- )
151
- return abstract_test
152
-
153
189
def _execute_tests (self , concrete_tests , test , f_flag ):
154
190
failures = 0
155
191
if "formula" in test :
@@ -175,13 +211,6 @@ def _populate_metas(self):
175
211
"""
176
212
for meta in self .scenario .variables_of_type (Meta ):
177
213
meta .populate (self .data )
178
- for var in self .scenario .variables_of_type (Meta ).union (self .scenario .variables_of_type (Output )):
179
- if not var .distribution :
180
- fitter = Fitter (self .data [var .name ], distributions = get_common_distributions ())
181
- fitter .fit ()
182
- (dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
183
- var .distribution = getattr (scipy .stats , dist )(** params )
184
- self ._append_to_file (var .name + f" { dist } ({ params } )" , logging .INFO )
185
214
186
215
def _execute_test_case (self , causal_test_case : CausalTestCase , test : Iterable [Mapping ], f_flag : bool ) -> bool :
187
216
"""Executes a singular test case, prints the results and returns the test case result
@@ -193,6 +222,15 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
193
222
:rtype: bool
194
223
"""
195
224
failed = False
225
+
226
+ for var in self .scenario .variables_of_type (Meta ).union (self .scenario .variables_of_type (Output )):
227
+ if not var .distribution :
228
+ fitter = Fitter (self .data [var .name ], distributions = get_common_distributions ())
229
+ fitter .fit ()
230
+ (dist , params ) = list (fitter .get_best (method = "sumsquare_error" ).items ())[0 ]
231
+ var .distribution = getattr (scipy .stats , dist )(** params )
232
+ self ._append_to_file (var .name + f" { dist } ({ params } )" , logging .INFO )
233
+
196
234
causal_test_engine , estimation_model = self ._setup_test (causal_test_case , test )
197
235
causal_test_result = causal_test_engine .execute_test (
198
236
estimation_model , causal_test_case , estimate_type = causal_test_case .estimate_type
@@ -218,16 +256,23 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, test: Iterable[Ma
218
256
logger .warning (" FAILED- expected %s, got %s" , causal_test_case .expected_causal_effect , result_string )
219
257
return failed
220
258
221
- def _setup_test (self , causal_test_case : CausalTestCase , test : Mapping ) -> tuple [CausalTestEngine , Estimator ]:
259
+ def _setup_test (
260
+ self , causal_test_case : CausalTestCase , test : Mapping , conditions : list [str ] = None
261
+ ) -> tuple [CausalTestEngine , Estimator ]:
222
262
"""Create the necessary inputs for a single test case
223
263
:param causal_test_case: The concrete test case to be executed
224
264
:param test: Single JSON test definition stored in a mapping (dict)
265
+ :param conditions: A list of conditions which should be applied to the
266
+ data. Conditions should be in the query format detailed at
267
+ https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
225
268
:returns:
226
269
- causal_test_engine - Test Engine instance for the test being run
227
270
- estimation_model - Estimator instance for the test being run
228
271
"""
229
272
230
- data_collector = ObservationalDataCollector (self .scenario , self .data )
273
+ data_collector = ObservationalDataCollector (
274
+ self .scenario , self .data .query (" & " .join (conditions )) if conditions else self .data
275
+ )
231
276
causal_test_engine = CausalTestEngine (self .causal_specification , data_collector , index_col = 0 )
232
277
233
278
minimal_adjustment_set = self .causal_specification .causal_dag .identification (causal_test_case .base_test_case )
0 commit comments