@@ -162,7 +162,7 @@ def generate_metamorphic_relations(
162
162
if nodes_to_test is None :
163
163
nodes_to_test = dag .nodes
164
164
165
- if not threads :
165
+ if threads < 2 :
166
166
metamorphic_relations = [
167
167
generate_metamorphic_relation (node_pair , dag , nodes_to_ignore )
168
168
for node_pair in combinations (filter (lambda node : node not in nodes_to_ignore , nodes_to_test ), 2 )
@@ -180,36 +180,25 @@ def generate_metamorphic_relations(
180
180
return [item for items in metamorphic_relations for item in items ]
181
181
182
182
183
- if __name__ == "__main__" : # pragma: no cover
184
- logging .basicConfig (format = "%(levelname)s: %(message)s" , level = logging .INFO )
185
- parser = argparse .ArgumentParser (
186
- description = "A script for generating metamorphic relations to test the causal relationships in a given DAG."
187
- )
188
- parser .add_argument (
189
- "--dag_path" ,
190
- "-d" ,
191
- help = "Specify path to file containing the DAG, normally a .dot file." ,
192
- required = True ,
193
- )
194
- parser .add_argument (
195
- "--output_path" ,
196
- "-o" ,
197
- help = "Specify path where tests should be saved, normally a .json file." ,
198
- required = True ,
199
- )
200
- parser .add_argument (
201
- "--threads" , "-t" , type = int , help = "The number of parallel threads to use." , required = False , default = 0
202
- )
203
- parser .add_argument ("-i" , "--ignore-cycles" , action = "store_true" )
204
- args = parser .parse_args ()
205
-
206
- causal_dag = CausalDAG (args .dag_path , ignore_cycles = args .ignore_cycles )
183
+ def generate_causal_tests (dag_path : str , output_path : str , ignore_cycles : bool = False , threads : int = 0 ):
184
+ """
185
+ Generate and output causal tests for a given DAG.
186
+
187
+ :param dag_path: Path to the DOT file that specifies the causal DAG.
188
+ :param output_path: Path to save the JSON output.
189
+ :param ignore_cycles: Whether to bypass the check that the DAG is actually acyclic. If set to true, tests that
190
+ include variables that are part of a cycle as either treatment, outcome, or adjustment will
191
+ be omitted from the test set.
192
+ :param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in
193
+ serial. This is tylically fine unless the number of tests to be generated is >10000.
194
+ """
195
+ causal_dag = CausalDAG (dag_path , ignore_cycles = ignore_cycles )
207
196
208
197
dag_nodes_to_test = [
209
198
node for node in causal_dag .nodes if nx .get_node_attributes (causal_dag .graph , "test" , default = True )[node ]
210
199
]
211
200
212
- if not causal_dag .is_acyclic () and args . ignore_cycles :
201
+ if not causal_dag .is_acyclic () and ignore_cycles :
213
202
logger .warning (
214
203
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
215
204
"Your causal test suite WILL NOT BE COMPLETE!"
@@ -218,17 +207,17 @@ def generate_metamorphic_relations(
218
207
causal_dag ,
219
208
nodes_to_test = dag_nodes_to_test ,
220
209
nodes_to_ignore = set (causal_dag .cycle_nodes ()),
221
- threads = args . threads ,
210
+ threads = threads ,
222
211
)
223
212
else :
224
- relations = generate_metamorphic_relations (causal_dag , nodes_to_test = dag_nodes_to_test , threads = args . threads )
213
+ relations = generate_metamorphic_relations (causal_dag , nodes_to_test = dag_nodes_to_test , threads = threads )
225
214
226
215
tests = [
227
216
relation .to_json_stub (skip = False )
228
217
for relation in relations
229
218
if len (list (causal_dag .graph .predecessors (relation .base_test_case .outcome_variable ))) > 0
230
219
]
231
220
232
- logger .info (f"Generated { len (tests )} tests. Saving to { args . output_path } ." )
233
- with open (args . output_path , "w" , encoding = "utf-8" ) as f :
221
+ logger .info (f"Generated { len (tests )} tests. Saving to { output_path } ." )
222
+ with open (output_path , "w" , encoding = "utf-8" ) as f :
234
223
json .dump ({"tests" : tests }, f , indent = 2 )
0 commit comments