@@ -67,17 +67,16 @@ def set_path(self, json_path: str, dag_path: str, data_path: str):
67
67
self .dag_path = Path (dag_path )
68
68
self .data_path = Path (data_path )
69
69
70
- def set_variables (self , inputs : dict , outputs : dict , metas : dict , distributions : dict , populates : dict ):
70
+ def set_variables (self , inputs : dict , outputs : dict , metas : dict ):
71
+
71
72
"""Populate the Causal Variables
72
73
:param inputs:
73
74
:param outputs:
74
75
:param metas:
75
- :param distributions:
76
- :param populates:
77
76
"""
78
- self .inputs = [Input (i ["name" ], i ["type" ], distributions [ i ["distribution" ] ]) for i in inputs ]
77
+ self .inputs = [Input (i ["name" ], i ["type" ], i ["distribution" ]) for i in inputs ]
79
78
self .outputs = [Output (i ["name" ], i ["type" ]) for i in outputs ]
80
- self .metas = [Meta (i ["name" ], i ["type" ], populates [ i ["populate" ] ]) for i in metas ] if metas else []
79
+ self .metas = [Meta (i ["name" ], i ["type" ], i ["populate" ]) for i in metas ] if metas else []
81
80
82
81
def setup (self ):
83
82
"""Function to populate all the necessary parts of the json_class needed to execute tests"""
@@ -89,54 +88,58 @@ def setup(self):
89
88
self ._json_parse ()
90
89
self ._populate_metas ()
91
90
92
- def execute_tests (self , effects : dict , mutates : dict , estimators : dict , f_flag : bool ):
91
+ def _create_abstract_test_case (self , test , mutates , effects ):
92
+ abstract_test = AbstractCausalTestCase (
93
+ scenario = self .modelling_scenario ,
94
+ intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
95
+ treatment_variables = {self .modelling_scenario .variables [v ] for v in test ["mutations" ]},
96
+ expected_causal_effect = {
97
+ self .modelling_scenario .variables [variable ]: effects [effect ]
98
+ for variable , effect in test ["expectedEffect" ].items ()
99
+ },
100
+ effect_modifiers = {self .modelling_scenario .variables [v ] for v in test ["effect_modifiers" ]}
101
+ if "effect_modifiers" in test
102
+ else {},
103
+ estimate_type = test ["estimate_type" ],
104
+ )
105
+ return abstract_test
106
+
107
+ def generate_tests (self , effects : dict , mutates : dict , estimators : dict , f_flag : bool ):
93
108
"""Runs and evaluates each test case specified in the JSON input
94
109
95
110
:param effects: Dictionary mapping effect class instances to string representations.
96
111
:param mutates: Dictionary mapping mutation functions to string representations.
97
112
:param estimators: Dictionary mapping estimator classes to string representations.
98
113
:param f_flag: Failure flag that if True the script will stop executing when a test fails.
99
114
"""
100
- executed_tests = 0
101
115
failures = 0
102
116
for test in self .test_plan ["tests" ]:
103
117
if "skip" in test and test ["skip" ]:
104
118
continue
105
-
106
- abstract_test = AbstractCausalTestCase (
107
- scenario = self .modelling_scenario ,
108
- intervention_constraints = [mutates [v ](k ) for k , v in test ["mutations" ].items ()],
109
- treatment_variables = {self .modelling_scenario .variables [v ] for v in test ["mutations" ]},
110
- expected_causal_effect = {
111
- self .modelling_scenario .variables [variable ]: effects [effect ]
112
- for variable , effect in test ["expectedEffect" ].items ()
113
- },
114
- effect_modifiers = {self .modelling_scenario .variables [v ] for v in test ["effect_modifiers" ]}
115
- if "effect_modifiers" in test
116
- else {},
117
- estimate_type = test ["estimate_type" ],
118
- )
119
+ abstract_test = self ._create_abstract_test_case (test , mutates , effects )
119
120
120
121
concrete_tests , dummy = abstract_test .generate_concrete_tests (5 , 0.05 )
121
122
logger .info ("Executing test: %s" , test ["name" ])
122
123
logger .info (abstract_test )
123
124
logger .info ([(v .name , v .distribution ) for v in abstract_test .treatment_variables ])
124
125
logger .info ("Number of concrete tests for test case: %s" , str (len (concrete_tests )))
125
- for concrete_test in concrete_tests :
126
- executed_tests += 1
127
- failed = self ._execute_test_case (concrete_test , estimators [test ["estimator" ]], f_flag )
128
- if failed :
129
- failures += 1
126
+ failures = self ._execute_tests (concrete_tests , estimators , test , f_flag )
127
+
128
+ logger .info (f"{ failures } /{ len (concrete_tests )} failed" )
130
129
131
- logger .info ("{%d}/{%d} failed" , failures , executed_tests )
130
+ def _execute_tests (self , concrete_tests , estimators , test , f_flag ):
131
+ failures = 0
132
+ for concrete_test in concrete_tests :
133
+ failed = self ._execute_test_case (concrete_test , estimators [test ["estimator" ]], f_flag )
134
+ if failed :
135
+ failures += 1
136
+ return failures
132
137
133
138
def _json_parse (self ):
134
- """Parse a JSON input file into inputs, outputs, metas and a test plan
135
- :param distributions: dictionary of user defined scipy distributions
136
- :param populates: dictionary of user defined populate functions
137
- """
138
- with open (self .json_path , encoding = "UTF-8" ) as file :
139
- self .test_plan = json .load (file )
139
+
140
+ """Parse a JSON input file into inputs, outputs, metas and a test plan"""
141
+ with open (self .json_path ) as f :
142
+ self .test_plan = json .load (f )
140
143
141
144
self .data = pd .read_csv (self .data_path )
142
145
@@ -187,7 +190,9 @@ def _execute_test_case(self, causal_test_case: CausalTestCase, estimator: Estima
187
190
if not test_passes :
188
191
failed = True
189
192
logger .warning (
190
- " FAILED- expected %s, got %s" , causal_test_case .expected_causal_effect , causal_test_result .ate
193
+ " FAILED- expected %s, got %s" ,
194
+ causal_test_case .expected_causal_effect ,
195
+ causal_test_result .ate ,
191
196
)
192
197
return failed
193
198
@@ -235,25 +240,37 @@ def setup_logger(log_path: str):
235
240
setup_log .addHandler (file_handler )
236
241
237
242
@staticmethod
238
- def get_args () -> argparse .Namespace :
243
+ def get_args (test_args = None ) -> argparse .Namespace :
239
244
"""Command-line arguments
240
245
241
246
:return: parsed command line arguments
242
247
"""
243
248
parser = argparse .ArgumentParser (
244
249
description = "A script for parsing json config files for the Causal Testing Framework"
245
250
)
246
- parser .add_argument ("-f" , help = "if included, the script will stop if a test fails" , action = "store_true" )
251
+ parser .add_argument (
252
+ "-f" ,
253
+ help = "if included, the script will stop if a test fails" ,
254
+ action = "store_true" ,
255
+ )
247
256
parser .add_argument (
248
257
"--log_path" ,
249
258
help = "Specify a directory to change the location of the log file" ,
250
259
default = "./json_frontend.log" ,
251
260
)
252
- parser .add_argument ("--data_path" , help = "Specify path to file containing runtime data" , required = True )
253
261
parser .add_argument (
254
- "--dag_path" , help = "Specify path to file containing the DAG, normally a .dot file" , required = True
262
+ "--data_path" ,
263
+ help = "Specify path to file containing runtime data" ,
264
+ required = True ,
265
+ )
266
+ parser .add_argument (
267
+ "--dag_path" ,
268
+ help = "Specify path to file containing the DAG, normally a .dot file" ,
269
+ required = True ,
255
270
)
256
271
parser .add_argument (
257
- "--json_path" , help = "Specify path to file containing JSON tests, normally a .json file" , required = True
272
+ "--json_path" ,
273
+ help = "Specify path to file containing JSON tests, normally a .json file" ,
274
+ required = True ,
258
275
)
259
- return parser .parse_args ()
276
+ return parser .parse_args (test_args )
0 commit comments