Skip to content

Moved test generation from DAG files into main functionality #330

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions causal_testing/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import os

from causal_testing.testing.metamorphic_relation import generate_causal_tests
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework


Expand All @@ -18,6 +19,12 @@ def main() -> None:
# Parse arguments
args = parse_args()

if args.generate:
logging.info("Generating causal tests")
generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads)
logging.info("Causal test generation completed successfully")
return

# Setup logging
setup_logging(args.verbose)

Expand Down
51 changes: 30 additions & 21 deletions causal_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,26 +475,35 @@ def setup_logging(verbose: bool = False) -> None:

def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
"""Parse command line arguments."""
parser = argparse.ArgumentParser(description="Causal Testing Framework")
parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
parser.add_argument(
"-s",
"--silent",
action="store_true",
help="Do not crash on error. If set to true, errors are recorded as test results.",
default=False,
)
parser.add_argument(
"--batch-size",
type=int,
default=0,
help="Run tests in batches of the specified size (default: 0, which means no batching)",
)
main_parser = argparse.ArgumentParser(add_help=False, description="Causal Testing Framework")
main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true")
main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
main_args, _ = main_parser.parse_known_args()

parser = argparse.ArgumentParser(parents=[main_parser])
if main_args.generate:
parser.add_argument(
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
)
else:
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
parser.add_argument(
"-s",
"--silent",
action="store_true",
help="Do not crash on error. If set to true, errors are recorded as test results.",
default=False,
)
parser.add_argument(
"--batch-size",
type=int,
default=0,
help="Run tests in batches of the specified size (default: 0, which means no batching)",
)

return parser.parse_args(args)
50 changes: 19 additions & 31 deletions causal_testing/testing/metamorphic_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from dataclasses import dataclass
from typing import Iterable
from itertools import combinations
import argparse
import logging
import json
from multiprocessing import Pool
Expand Down Expand Up @@ -162,7 +161,7 @@ def generate_metamorphic_relations(
if nodes_to_test is None:
nodes_to_test = dag.nodes

if not threads:
if threads < 2:
metamorphic_relations = [
generate_metamorphic_relation(node_pair, dag, nodes_to_ignore)
for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2)
Expand All @@ -180,36 +179,25 @@ def generate_metamorphic_relations(
return [item for items in metamorphic_relations for item in items]


if __name__ == "__main__": # pragma: no cover
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
parser = argparse.ArgumentParser(
description="A script for generating metamorphic relations to test the causal relationships in a given DAG."
)
parser.add_argument(
"--dag_path",
"-d",
help="Specify path to file containing the DAG, normally a .dot file.",
required=True,
)
parser.add_argument(
"--output_path",
"-o",
help="Specify path where tests should be saved, normally a .json file.",
required=True,
)
parser.add_argument(
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
)
parser.add_argument("-i", "--ignore-cycles", action="store_true")
args = parser.parse_args()

causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)
def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0):
"""
Generate and output causal tests for a given DAG.

:param dag_path: Path to the DOT file that specifies the causal DAG.
:param output_path: Path to save the JSON output.
:param ignore_cycles: Whether to bypass the check that the DAG is actually acyclic. If set to true, tests that
include variables that are part of a cycle as either treatment, outcome, or adjustment will
be omitted from the test set.
:param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in
serial. This is tylically fine unless the number of tests to be generated is >10000.
"""
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)

dag_nodes_to_test = [
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node]
]

if not causal_dag.is_acyclic() and args.ignore_cycles:
if not causal_dag.is_acyclic() and ignore_cycles:
logger.warning(
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
"Your causal test suite WILL NOT BE COMPLETE!"
Expand All @@ -218,17 +206,17 @@ def generate_metamorphic_relations(
causal_dag,
nodes_to_test=dag_nodes_to_test,
nodes_to_ignore=set(causal_dag.cycle_nodes()),
threads=args.threads,
threads=threads,
)
else:
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads)
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=threads)

tests = [
relation.to_json_stub(skip=False)
for relation in relations
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
]

logger.info(f"Generated {len(tests)} tests. Saving to {args.output_path}.")
with open(args.output_path, "w", encoding="utf-8") as f:
logger.info(f"Generated {len(tests)} tests. Saving to {output_path}.")
with open(output_path, "w", encoding="utf-8") as f:
json.dump({"tests": tests}, f, indent=2)
16 changes: 16 additions & 0 deletions tests/main_tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,22 @@ def test_parse_args_batches(self):
main()
self.assertTrue((self.output_path.parent / "main_batch.json").exists())

def test_parse_args_generation(self):
with tempfile.TemporaryDirectory() as tmp:
with unittest.mock.patch(
"sys.argv",
[
"causal_testing",
"--generate",
"--dag_path",
str(self.dag_path),
"--output",
os.path.join(tmp, "tests.json"),
],
):
main()
self.assertTrue(os.path.exists(os.path.join(tmp, "tests.json")))

def tearDown(self):
if self.output_path.parent.exists():
shutil.rmtree(self.output_path.parent)
47 changes: 45 additions & 2 deletions tests/testing_tests/test_metamorphic_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import shutil, tempfile
import pandas as pd
from itertools import combinations
import tempfile
import json

from causal_testing.specification.causal_dag import CausalDAG
from causal_testing.specification.causal_specification import Scenario
Expand All @@ -11,6 +13,7 @@
ShouldNotCause,
generate_metamorphic_relations,
generate_metamorphic_relation,
generate_causal_tests,
)
from causal_testing.specification.variable import Input, Output
from causal_testing.testing.base_test_case import BaseTestCase
Expand Down Expand Up @@ -177,8 +180,8 @@ def test_all_metamorphic_relations_implied_by_dag_parallel(self):
self.assertEqual(missing_snc_relations, [])

def test_all_metamorphic_relations_implied_by_dag_ignore_cycles(self):
dag = CausalDAG(self.dcg_dot_path, ignore_cycles=True)
metamorphic_relations = generate_metamorphic_relations(dag, threads=2, nodes_to_ignore=set(dag.cycle_nodes()))
dcg = CausalDAG(self.dcg_dot_path, ignore_cycles=True)
metamorphic_relations = generate_metamorphic_relations(dcg, threads=2, nodes_to_ignore=set(dcg.cycle_nodes()))
should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)]
should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)]

Expand All @@ -203,6 +206,46 @@ def test_generate_metamorphic_relation_(self):
ShouldCause(BaseTestCase("X1", "Z"), []),
)

def test_generate_causal_tests_ignore_cycles(self):
dcg = CausalDAG(self.dcg_dot_path, ignore_cycles=True)
relations = generate_metamorphic_relations(dcg, nodes_to_ignore=set(dcg.cycle_nodes()))
with tempfile.TemporaryDirectory() as tmp:
tests_file = os.path.join(tmp, "causal_tests.json")
generate_causal_tests(self.dcg_dot_path, tests_file, ignore_cycles=True)
with open(tests_file, encoding="utf8") as f:
tests = json.load(f)
expected = list(
map(
lambda x: x.to_json_stub(skip=False),
filter(
lambda relation: len(list(dcg.graph.predecessors(relation.base_test_case.outcome_variable)))
> 0,
relations,
),
)
)
self.assertEqual(tests["tests"], expected)

def test_generate_causal_tests(self):
dag = CausalDAG(self.dag_dot_path)
relations = generate_metamorphic_relations(dag)
with tempfile.TemporaryDirectory() as tmp:
tests_file = os.path.join(tmp, "causal_tests.json")
generate_causal_tests(self.dag_dot_path, tests_file)
with open(tests_file, encoding="utf8") as f:
tests = json.load(f)
expected = list(
map(
lambda x: x.to_json_stub(skip=False),
filter(
lambda relation: len(list(dag.graph.predecessors(relation.base_test_case.outcome_variable)))
> 0,
relations,
),
)
)
self.assertEqual(tests["tests"], expected)

def test_shoud_cause_string(self):
sc_mr = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"])
self.assertEqual(str(sc_mr), "X --> Y | ['A', 'B', 'C']")
Expand Down