From 086d1c9445b8a704166b5eaf725c0a3d96dba5c6 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 13 Jun 2025 09:58:55 +0100 Subject: [PATCH 1/3] Moved test generation from DAG files into main functionality --- causal_testing/__main__.py | 7 +++ causal_testing/main.py | 51 +++++++++++-------- .../testing/metamorphic_relation.py | 49 +++++++----------- 3 files changed, 56 insertions(+), 51 deletions(-) diff --git a/causal_testing/__main__.py b/causal_testing/__main__.py index a02d8b49..3b84cd9c 100644 --- a/causal_testing/__main__.py +++ b/causal_testing/__main__.py @@ -5,6 +5,7 @@ import json import os +from causal_testing.testing.metamorphic_relation import generate_causal_tests from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework @@ -18,6 +19,12 @@ def main() -> None: # Parse arguments args = parse_args() + if args.generate: + logging.info("Generating causal tests") + generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads) + logging.info("Causal test generation completed successfully") + return + # Setup logging setup_logging(args.verbose) diff --git a/causal_testing/main.py b/causal_testing/main.py index aed161a8..bb972ca6 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -475,26 +475,35 @@ def setup_logging(verbose: bool = False) -> None: def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace: """Parse command line arguments.""" - parser = argparse.ArgumentParser(description="Causal Testing Framework") - parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True) - parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True) - parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True) - parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True) - parser.add_argument("-v", "--verbose", help="Enable verbose 
logging", action="store_true", default=False) - parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False) - parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str) - parser.add_argument( - "-s", - "--silent", - action="store_true", - help="Do not crash on error. If set to true, errors are recorded as test results.", - default=False, - ) - parser.add_argument( - "--batch-size", - type=int, - default=0, - help="Run tests in batches of the specified size (default: 0, which means no batching)", - ) + main_parser = argparse.ArgumentParser(add_help=False, description="Causal Testing Framework") + main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true") + main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True) + main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True) + main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False) + main_args, _ = main_parser.parse_known_args() + + parser = argparse.ArgumentParser(parents=[main_parser]) + if main_args.generate: + parser.add_argument( + "--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0 + ) + else: + parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True) + parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True) + parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False) + parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str) + parser.add_argument( + "-s", + "--silent", + action="store_true", + help="Do not crash on error. 
If set to true, errors are recorded as test results.", + default=False, + ) + parser.add_argument( + "--batch-size", + type=int, + default=0, + help="Run tests in batches of the specified size (default: 0, which means no batching)", + ) return parser.parse_args(args) diff --git a/causal_testing/testing/metamorphic_relation.py b/causal_testing/testing/metamorphic_relation.py index c901ff01..4459ddd5 100644 --- a/causal_testing/testing/metamorphic_relation.py +++ b/causal_testing/testing/metamorphic_relation.py @@ -162,7 +162,7 @@ def generate_metamorphic_relations( if nodes_to_test is None: nodes_to_test = dag.nodes - if not threads: + if threads < 2: metamorphic_relations = [ generate_metamorphic_relation(node_pair, dag, nodes_to_ignore) for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2) @@ -180,36 +180,25 @@ def generate_metamorphic_relations( return [item for items in metamorphic_relations for item in items] -if __name__ == "__main__": # pragma: no cover - logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO) - parser = argparse.ArgumentParser( - description="A script for generating metamorphic relations to test the causal relationships in a given DAG." 
- ) - parser.add_argument( - "--dag_path", - "-d", - help="Specify path to file containing the DAG, normally a .dot file.", - required=True, - ) - parser.add_argument( - "--output_path", - "-o", - help="Specify path where tests should be saved, normally a .json file.", - required=True, - ) - parser.add_argument( - "--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0 - ) - parser.add_argument("-i", "--ignore-cycles", action="store_true") - args = parser.parse_args() - - causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles) +def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0): + """ + Generate and output causal tests for a given DAG. + + :param dag_path: Path to the DOT file that specifies the causal DAG. + :param output_path: Path to save the JSON output. + :param ignore_cycles: Whether to bypass the check that the DAG is actually acyclic. If set to true, tests that + include variables that are part of a cycle as either treatment, outcome, or adjustment will + be omitted from the test set. + :param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in + serial. This is typically fine unless the number of tests to be generated is >10000. + """ + causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles) dag_nodes_to_test = [ node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node] ] - if not causal_dag.is_acyclic() and args.ignore_cycles: + if not causal_dag.is_acyclic() and ignore_cycles: logger.warning( "Ignoring cycles by removing causal tests that reference any node within a cycle. " "Your causal test suite WILL NOT BE COMPLETE!" 
@@ -218,10 +207,10 @@ def generate_metamorphic_relations( causal_dag, nodes_to_test=dag_nodes_to_test, nodes_to_ignore=set(causal_dag.cycle_nodes()), - threads=args.threads, + threads=threads, ) else: - relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads) + relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=threads) tests = [ relation.to_json_stub(skip=False) @@ -229,6 +218,6 @@ def generate_metamorphic_relations( if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0 ] - logger.info(f"Generated {len(tests)} tests. Saving to {args.output_path}.") - with open(args.output_path, "w", encoding="utf-8") as f: + logger.info(f"Generated {len(tests)} tests. Saving to {output_path}.") + with open(output_path, "w", encoding="utf-8") as f: json.dump({"tests": tests}, f, indent=2) From a4f1ae79fda08b193fde6adf4d04ed279d771798 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 13 Jun 2025 10:01:44 +0100 Subject: [PATCH 2/3] Removed unused argparse import --- causal_testing/testing/metamorphic_relation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/causal_testing/testing/metamorphic_relation.py b/causal_testing/testing/metamorphic_relation.py index 4459ddd5..95bef400 100644 --- a/causal_testing/testing/metamorphic_relation.py +++ b/causal_testing/testing/metamorphic_relation.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from typing import Iterable from itertools import combinations -import argparse import logging import json from multiprocessing import Pool From a51b66ddc28348277769d2cbbe1d33c0c6e29050 Mon Sep 17 00:00:00 2001 From: Michael Foster Date: Fri, 13 Jun 2025 10:29:58 +0100 Subject: [PATCH 3/3] codecov --- tests/main_tests/test_main.py | 16 +++++++ .../test_metamorphic_relations.py | 47 ++++++++++++++++++- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/tests/main_tests/test_main.py 
b/tests/main_tests/test_main.py index 07d864c4..5a6433ae 100644 --- a/tests/main_tests/test_main.py +++ b/tests/main_tests/test_main.py @@ -338,6 +338,22 @@ def test_parse_args_batches(self): main() self.assertTrue((self.output_path.parent / "main_batch.json").exists()) + def test_parse_args_generation(self): + with tempfile.TemporaryDirectory() as tmp: + with unittest.mock.patch( + "sys.argv", + [ + "causal_testing", + "--generate", + "--dag_path", + str(self.dag_path), + "--output", + os.path.join(tmp, "tests.json"), + ], + ): + main() + self.assertTrue(os.path.exists(os.path.join(tmp, "tests.json"))) + def tearDown(self): if self.output_path.parent.exists(): shutil.rmtree(self.output_path.parent) diff --git a/tests/testing_tests/test_metamorphic_relations.py b/tests/testing_tests/test_metamorphic_relations.py index 723285b4..68f54589 100644 --- a/tests/testing_tests/test_metamorphic_relations.py +++ b/tests/testing_tests/test_metamorphic_relations.py @@ -3,6 +3,8 @@ import shutil, tempfile import pandas as pd from itertools import combinations +import tempfile +import json from causal_testing.specification.causal_dag import CausalDAG from causal_testing.specification.causal_specification import Scenario @@ -11,6 +13,7 @@ ShouldNotCause, generate_metamorphic_relations, generate_metamorphic_relation, + generate_causal_tests, ) from causal_testing.specification.variable import Input, Output from causal_testing.testing.base_test_case import BaseTestCase @@ -177,8 +180,8 @@ def test_all_metamorphic_relations_implied_by_dag_parallel(self): self.assertEqual(missing_snc_relations, []) def test_all_metamorphic_relations_implied_by_dag_ignore_cycles(self): - dag = CausalDAG(self.dcg_dot_path, ignore_cycles=True) - metamorphic_relations = generate_metamorphic_relations(dag, threads=2, nodes_to_ignore=set(dag.cycle_nodes())) + dcg = CausalDAG(self.dcg_dot_path, ignore_cycles=True) + metamorphic_relations = generate_metamorphic_relations(dcg, threads=2, 
nodes_to_ignore=set(dcg.cycle_nodes())) should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)] should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)] @@ -203,6 +206,46 @@ def test_generate_metamorphic_relation_(self): ShouldCause(BaseTestCase("X1", "Z"), []), ) + def test_generate_causal_tests_ignore_cycles(self): + dcg = CausalDAG(self.dcg_dot_path, ignore_cycles=True) + relations = generate_metamorphic_relations(dcg, nodes_to_ignore=set(dcg.cycle_nodes())) + with tempfile.TemporaryDirectory() as tmp: + tests_file = os.path.join(tmp, "causal_tests.json") + generate_causal_tests(self.dcg_dot_path, tests_file, ignore_cycles=True) + with open(tests_file, encoding="utf8") as f: + tests = json.load(f) + expected = list( + map( + lambda x: x.to_json_stub(skip=False), + filter( + lambda relation: len(list(dcg.graph.predecessors(relation.base_test_case.outcome_variable))) + > 0, + relations, + ), + ) + ) + self.assertEqual(tests["tests"], expected) + + def test_generate_causal_tests(self): + dag = CausalDAG(self.dag_dot_path) + relations = generate_metamorphic_relations(dag) + with tempfile.TemporaryDirectory() as tmp: + tests_file = os.path.join(tmp, "causal_tests.json") + generate_causal_tests(self.dag_dot_path, tests_file) + with open(tests_file, encoding="utf8") as f: + tests = json.load(f) + expected = list( + map( + lambda x: x.to_json_stub(skip=False), + filter( + lambda relation: len(list(dag.graph.predecessors(relation.base_test_case.outcome_variable))) + > 0, + relations, + ), + ) + ) + self.assertEqual(tests["tests"], expected) + def test_shoud_cause_string(self): sc_mr = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"]) self.assertEqual(str(sc_mr), "X --> Y | ['A', 'B', 'C']")