Skip to content

Commit 086d1c9

Browse files
committed
Moved test generation from DAG files into main functionality
1 parent e7b8577 commit 086d1c9

File tree

3 files changed

+56
-51
lines changed

3 files changed

+56
-51
lines changed

causal_testing/__main__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import os
77

8+
from causal_testing.testing.metamorphic_relation import generate_causal_tests
89
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework
910

1011

@@ -18,6 +19,12 @@ def main() -> None:
1819
# Parse arguments
1920
args = parse_args()
2021

22+
if args.generate:
23+
logging.info("Generating causal tests")
24+
generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads)
25+
logging.info("Causal test generation completed successfully")
26+
return
27+
2128
# Setup logging
2229
setup_logging(args.verbose)
2330

causal_testing/main.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -475,26 +475,35 @@ def setup_logging(verbose: bool = False) -> None:
475475

476476
def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
477477
"""Parse command line arguments."""
478-
parser = argparse.ArgumentParser(description="Causal Testing Framework")
479-
parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
480-
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
481-
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
482-
parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
483-
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
484-
parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
485-
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
486-
parser.add_argument(
487-
"-s",
488-
"--silent",
489-
action="store_true",
490-
help="Do not crash on error. If set to true, errors are recorded as test results.",
491-
default=False,
492-
)
493-
parser.add_argument(
494-
"--batch-size",
495-
type=int,
496-
default=0,
497-
help="Run tests in batches of the specified size (default: 0, which means no batching)",
498-
)
478+
main_parser = argparse.ArgumentParser(add_help=False, description="Causal Testing Framework")
479+
main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true")
480+
main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
481+
main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
482+
main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
483+
main_args, _ = main_parser.parse_known_args()
484+
485+
parser = argparse.ArgumentParser(parents=[main_parser])
486+
if main_args.generate:
487+
parser.add_argument(
488+
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
489+
)
490+
else:
491+
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
492+
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
493+
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
494+
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
495+
parser.add_argument(
496+
"-s",
497+
"--silent",
498+
action="store_true",
499+
help="Do not crash on error. If set to true, errors are recorded as test results.",
500+
default=False,
501+
)
502+
parser.add_argument(
503+
"--batch-size",
504+
type=int,
505+
default=0,
506+
help="Run tests in batches of the specified size (default: 0, which means no batching)",
507+
)
499508

500509
return parser.parse_args(args)

causal_testing/testing/metamorphic_relation.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def generate_metamorphic_relations(
162162
if nodes_to_test is None:
163163
nodes_to_test = dag.nodes
164164

165-
if not threads:
165+
if threads < 2:
166166
metamorphic_relations = [
167167
generate_metamorphic_relation(node_pair, dag, nodes_to_ignore)
168168
for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2)
@@ -180,36 +180,25 @@ def generate_metamorphic_relations(
180180
return [item for items in metamorphic_relations for item in items]
181181

182182

183-
if __name__ == "__main__": # pragma: no cover
184-
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
185-
parser = argparse.ArgumentParser(
186-
description="A script for generating metamorphic relations to test the causal relationships in a given DAG."
187-
)
188-
parser.add_argument(
189-
"--dag_path",
190-
"-d",
191-
help="Specify path to file containing the DAG, normally a .dot file.",
192-
required=True,
193-
)
194-
parser.add_argument(
195-
"--output_path",
196-
"-o",
197-
help="Specify path where tests should be saved, normally a .json file.",
198-
required=True,
199-
)
200-
parser.add_argument(
201-
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
202-
)
203-
parser.add_argument("-i", "--ignore-cycles", action="store_true")
204-
args = parser.parse_args()
205-
206-
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)
183+
def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0):
184+
"""
185+
Generate and output causal tests for a given DAG.
186+
187+
:param dag_path: Path to the DOT file that specifies the causal DAG.
188+
:param output_path: Path to save the JSON output.
189+
:param ignore_cycles: Whether to bypass the check that the DAG is actually acyclic. If set to true, tests that
190+
include variables that are part of a cycle as either treatment, outcome, or adjustment will
191+
be omitted from the test set.
192+
:param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in
193+
serial. This is tylically fine unless the number of tests to be generated is >10000.
194+
"""
195+
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)
207196

208197
dag_nodes_to_test = [
209198
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node]
210199
]
211200

212-
if not causal_dag.is_acyclic() and args.ignore_cycles:
201+
if not causal_dag.is_acyclic() and ignore_cycles:
213202
logger.warning(
214203
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
215204
"Your causal test suite WILL NOT BE COMPLETE!"
@@ -218,17 +207,17 @@ def generate_metamorphic_relations(
218207
causal_dag,
219208
nodes_to_test=dag_nodes_to_test,
220209
nodes_to_ignore=set(causal_dag.cycle_nodes()),
221-
threads=args.threads,
210+
threads=threads,
222211
)
223212
else:
224-
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads)
213+
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=threads)
225214

226215
tests = [
227216
relation.to_json_stub(skip=False)
228217
for relation in relations
229218
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
230219
]
231220

232-
logger.info(f"Generated {len(tests)} tests. Saving to {args.output_path}.")
233-
with open(args.output_path, "w", encoding="utf-8") as f:
221+
logger.info(f"Generated {len(tests)} tests. Saving to {output_path}.")
222+
with open(output_path, "w", encoding="utf-8") as f:
234223
json.dump({"tests": tests}, f, indent=2)

0 commit comments

Comments
 (0)