Skip to content

Commit b097a30

Browse files
authored
Merge branch 'main' into jmafoster1/114-validate-treatment-outcome
2 parents 431ac7b + 708e8f7 commit b097a30

File tree

6 files changed

+118
-56
lines changed

6 files changed

+118
-56
lines changed

causal_testing/__main__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import os
77

8+
from causal_testing.testing.metamorphic_relation import generate_causal_tests
89
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework
910

1011

@@ -18,6 +19,12 @@ def main() -> None:
1819
# Parse arguments
1920
args = parse_args()
2021

22+
if args.generate:
23+
logging.info("Generating causal tests")
24+
generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads)
25+
logging.info("Causal test generation completed successfully")
26+
return
27+
2128
# Setup logging
2229
setup_logging(args.verbose)
2330

causal_testing/main.py

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -475,26 +475,35 @@ def setup_logging(verbose: bool = False) -> None:
475475

476476
def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
477477
"""Parse command line arguments."""
478-
parser = argparse.ArgumentParser(description="Causal Testing Framework")
479-
parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
480-
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
481-
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
482-
parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
483-
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
484-
parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
485-
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
486-
parser.add_argument(
487-
"-s",
488-
"--silent",
489-
action="store_true",
490-
help="Do not crash on error. If set to true, errors are recorded as test results.",
491-
default=False,
492-
)
493-
parser.add_argument(
494-
"--batch-size",
495-
type=int,
496-
default=0,
497-
help="Run tests in batches of the specified size (default: 0, which means no batching)",
498-
)
478+
main_parser = argparse.ArgumentParser(add_help=False, description="Causal Testing Framework")
479+
main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true")
480+
main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
481+
main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
482+
main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
483+
main_args, _ = main_parser.parse_known_args()
484+
485+
parser = argparse.ArgumentParser(parents=[main_parser])
486+
if main_args.generate:
487+
parser.add_argument(
488+
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
489+
)
490+
else:
491+
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
492+
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
493+
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
494+
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
495+
parser.add_argument(
496+
"-s",
497+
"--silent",
498+
action="store_true",
499+
help="Do not crash on error. If set to true, errors are recorded as test results.",
500+
default=False,
501+
)
502+
parser.add_argument(
503+
"--batch-size",
504+
type=int,
505+
default=0,
506+
help="Run tests in batches of the specified size (default: 0, which means no batching)",
507+
)
499508

500509
return parser.parse_args(args)

causal_testing/testing/metamorphic_relation.py

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from dataclasses import dataclass
77
from typing import Iterable
88
from itertools import combinations
9-
import argparse
109
import logging
1110
import json
1211
from multiprocessing import Pool
@@ -162,7 +161,7 @@ def generate_metamorphic_relations(
162161
if nodes_to_test is None:
163162
nodes_to_test = dag.nodes
164163

165-
if not threads:
164+
if threads < 2:
166165
metamorphic_relations = [
167166
generate_metamorphic_relation(node_pair, dag, nodes_to_ignore)
168167
for node_pair in combinations(filter(lambda node: node not in nodes_to_ignore, nodes_to_test), 2)
@@ -180,36 +179,25 @@ def generate_metamorphic_relations(
180179
return [item for items in metamorphic_relations for item in items]
181180

182181

183-
if __name__ == "__main__": # pragma: no cover
184-
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
185-
parser = argparse.ArgumentParser(
186-
description="A script for generating metamorphic relations to test the causal relationships in a given DAG."
187-
)
188-
parser.add_argument(
189-
"--dag_path",
190-
"-d",
191-
help="Specify path to file containing the DAG, normally a .dot file.",
192-
required=True,
193-
)
194-
parser.add_argument(
195-
"--output_path",
196-
"-o",
197-
help="Specify path where tests should be saved, normally a .json file.",
198-
required=True,
199-
)
200-
parser.add_argument(
201-
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
202-
)
203-
parser.add_argument("-i", "--ignore-cycles", action="store_true")
204-
args = parser.parse_args()
205-
206-
causal_dag = CausalDAG(args.dag_path, ignore_cycles=args.ignore_cycles)
182+
def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0):
183+
"""
184+
Generate and output causal tests for a given DAG.
185+
186+
:param dag_path: Path to the DOT file that specifies the causal DAG.
187+
:param output_path: Path to save the JSON output.
188+
:param ignore_cycles: Whether to bypass the check that the DAG is actually acyclic. If set to true, tests that
189+
include variables that are part of a cycle as either treatment, outcome, or adjustment will
190+
be omitted from the test set.
191+
:param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in
192+
serial. This is typically fine unless the number of tests to be generated is >10000.
193+
"""
194+
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)
207195

208196
dag_nodes_to_test = [
209197
node for node in causal_dag.nodes if nx.get_node_attributes(causal_dag.graph, "test", default=True)[node]
210198
]
211199

212-
if not causal_dag.is_acyclic() and args.ignore_cycles:
200+
if not causal_dag.is_acyclic() and ignore_cycles:
213201
logger.warning(
214202
"Ignoring cycles by removing causal tests that reference any node within a cycle. "
215203
"Your causal test suite WILL NOT BE COMPLETE!"
@@ -218,17 +206,17 @@ def generate_metamorphic_relations(
218206
causal_dag,
219207
nodes_to_test=dag_nodes_to_test,
220208
nodes_to_ignore=set(causal_dag.cycle_nodes()),
221-
threads=args.threads,
209+
threads=threads,
222210
)
223211
else:
224-
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=args.threads)
212+
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=threads)
225213

226214
tests = [
227215
relation.to_json_stub(skip=False)
228216
for relation in relations
229217
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
230218
]
231219

232-
logger.info(f"Generated {len(tests)} tests. Saving to {args.output_path}.")
233-
with open(args.output_path, "w", encoding="utf-8") as f:
220+
logger.info(f"Generated {len(tests)} tests. Saving to {output_path}.")
221+
with open(output_path, "w", encoding="utf-8") as f:
234222
json.dump({"tests": tests}, f, indent=2)

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,13 @@ dependencies = [
2222
"numpy~=1.26",
2323
"pandas>=2.1",
2424
"scikit_learn~=1.4",
25-
"scipy~=1.7",
25+
"scipy>=1.12.0,<1.14.0",
2626
"statsmodels~=0.14",
2727
"tabulate~=0.9",
2828
"pydot~=2.0",
2929
"pygad~=3.3",
3030
"deap~=1.4.1",
3131
"sympy~=1.13.1",
32-
"deap~=1.4.1",
3332
"pyarrow~=19.0.1",
3433
"fastparquet~=2024.11.0",
3534
]

tests/main_tests/test_main.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,22 @@ def test_parse_args_batches(self):
338338
main()
339339
self.assertTrue((self.output_path.parent / "main_batch.json").exists())
340340

341+
def test_parse_args_generation(self):
342+
with tempfile.TemporaryDirectory() as tmp:
343+
with unittest.mock.patch(
344+
"sys.argv",
345+
[
346+
"causal_testing",
347+
"--generate",
348+
"--dag_path",
349+
str(self.dag_path),
350+
"--output",
351+
os.path.join(tmp, "tests.json"),
352+
],
353+
):
354+
main()
355+
self.assertTrue(os.path.exists(os.path.join(tmp, "tests.json")))
356+
341357
def tearDown(self):
342358
if self.output_path.parent.exists():
343359
shutil.rmtree(self.output_path.parent)

tests/testing_tests/test_metamorphic_relations.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import shutil, tempfile
44
import pandas as pd
55
from itertools import combinations
6+
import tempfile
7+
import json
68

79
from causal_testing.specification.causal_dag import CausalDAG
810
from causal_testing.specification.causal_specification import Scenario
@@ -11,6 +13,7 @@
1113
ShouldNotCause,
1214
generate_metamorphic_relations,
1315
generate_metamorphic_relation,
16+
generate_causal_tests,
1417
)
1518
from causal_testing.specification.variable import Input, Output
1619
from causal_testing.testing.base_test_case import BaseTestCase
@@ -177,8 +180,8 @@ def test_all_metamorphic_relations_implied_by_dag_parallel(self):
177180
self.assertEqual(missing_snc_relations, [])
178181

179182
def test_all_metamorphic_relations_implied_by_dag_ignore_cycles(self):
180-
dag = CausalDAG(self.dcg_dot_path, ignore_cycles=True)
181-
metamorphic_relations = generate_metamorphic_relations(dag, threads=2, nodes_to_ignore=set(dag.cycle_nodes()))
183+
dcg = CausalDAG(self.dcg_dot_path, ignore_cycles=True)
184+
metamorphic_relations = generate_metamorphic_relations(dcg, threads=2, nodes_to_ignore=set(dcg.cycle_nodes()))
182185
should_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldCause)]
183186
should_not_cause_relations = [mr for mr in metamorphic_relations if isinstance(mr, ShouldNotCause)]
184187

@@ -203,6 +206,46 @@ def test_generate_metamorphic_relation_(self):
203206
ShouldCause(BaseTestCase("X1", "Z"), []),
204207
)
205208

209+
def test_generate_causal_tests_ignore_cycles(self):
210+
dcg = CausalDAG(self.dcg_dot_path, ignore_cycles=True)
211+
relations = generate_metamorphic_relations(dcg, nodes_to_ignore=set(dcg.cycle_nodes()))
212+
with tempfile.TemporaryDirectory() as tmp:
213+
tests_file = os.path.join(tmp, "causal_tests.json")
214+
generate_causal_tests(self.dcg_dot_path, tests_file, ignore_cycles=True)
215+
with open(tests_file, encoding="utf8") as f:
216+
tests = json.load(f)
217+
expected = list(
218+
map(
219+
lambda x: x.to_json_stub(skip=False),
220+
filter(
221+
lambda relation: len(list(dcg.graph.predecessors(relation.base_test_case.outcome_variable)))
222+
> 0,
223+
relations,
224+
),
225+
)
226+
)
227+
self.assertEqual(tests["tests"], expected)
228+
229+
def test_generate_causal_tests(self):
230+
dag = CausalDAG(self.dag_dot_path)
231+
relations = generate_metamorphic_relations(dag)
232+
with tempfile.TemporaryDirectory() as tmp:
233+
tests_file = os.path.join(tmp, "causal_tests.json")
234+
generate_causal_tests(self.dag_dot_path, tests_file)
235+
with open(tests_file, encoding="utf8") as f:
236+
tests = json.load(f)
237+
expected = list(
238+
map(
239+
lambda x: x.to_json_stub(skip=False),
240+
filter(
241+
lambda relation: len(list(dag.graph.predecessors(relation.base_test_case.outcome_variable)))
242+
> 0,
243+
relations,
244+
),
245+
)
246+
)
247+
self.assertEqual(tests["tests"], expected)
248+
206249
def test_shoud_cause_string(self):
207250
sc_mr = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"])
208251
self.assertEqual(str(sc_mr), "X --> Y | ['A', 'B', 'C']")

0 commit comments

Comments
 (0)