Skip to content

Commit 0db9294

Browse files
authored
Merge pull request #346 from CITCOM-project/jmafoster1/344-docker-image
Changed main run command to facilitate test generation
2 parents b162d8f + 7cce369 commit 0db9294

File tree

6 files changed

+202
-57
lines changed

6 files changed

+202
-57
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ For more information on how to use the Causal Testing Framework, please refer to
6666
2. If you do not already have causal test cases, you can convert your causal DAG to causal tests by running the following command.
6767

6868
```
69-
python causal_testing/testing/metamorphic_relation.py --dag_path $PATH_TO_DAG --output_path $PATH_TO_TESTS
69+
python -m causal_testing generate --dag_path $PATH_TO_DAG --output $PATH_TO_TESTS
7070
```
7171

7272
3. You can now execute your tests by running the following command.
7373
```
74-
python -m causal_testing --dag_path $PATH_TO_DAG --data_paths $PATH_TO_DATA --test_config $PATH_TO_TESTS --output $OUTPUT
74+
python -m causal_testing test --dag_path $PATH_TO_DAG --data_paths $PATH_TO_DATA --test_config $PATH_TO_TESTS --output $OUTPUT
7575
```
7676
The results will be saved for inspection in a JSON file located at `$OUTPUT`.
7777
In the future, we hope to add a visualisation tool to assist with this.

causal_testing/__main__.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77

88
from causal_testing.testing.metamorphic_relation import generate_causal_tests
9-
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework
9+
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework, Command
1010

1111

1212
def main() -> None:
@@ -19,9 +19,18 @@ def main() -> None:
1919
# Parse arguments
2020
args = parse_args()
2121

22-
if args.generate:
22+
if args.command == Command.GENERATE:
2323
logging.info("Generating causal tests")
24-
generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads)
24+
generate_causal_tests(
25+
args.dag_path,
26+
args.output,
27+
args.ignore_cycles,
28+
args.threads,
29+
effect_type=args.effect_type,
30+
estimate_type=args.estimate_type,
31+
estimator=args.estimator,
32+
skip=True,
33+
)
2534
logging.info("Causal test generation completed successfully")
2635
return
2736

causal_testing/main.py

Lines changed: 71 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33
import argparse
44
import json
55
import logging
6+
from enum import Enum
67
from dataclasses import dataclass
78
from pathlib import Path
89
from typing import Dict, Any, Optional, List, Union, Sequence
9-
from tqdm import tqdm
10-
1110

11+
from tqdm import tqdm
1212
import pandas as pd
1313
import numpy as np
1414

@@ -26,6 +26,15 @@
2626
logger = logging.getLogger(__name__)
2727

2828

29+
class Command(Enum):
30+
"""
31+
Enum for supported CTF commands.
32+
"""
33+
34+
TEST = "test"
35+
GENERATE = "generate"
36+
37+
2938
@dataclass
3039
class CausalTestingPaths:
3140
"""
@@ -475,35 +484,64 @@ def setup_logging(verbose: bool = False) -> None:
475484

476485
def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
477486
"""Parse command line arguments."""
478-
main_parser = argparse.ArgumentParser(add_help=False, description="Causal Testing Framework")
479-
main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true")
480-
main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
481-
main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
482-
main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
483-
main_args, _ = main_parser.parse_known_args()
484-
485-
parser = argparse.ArgumentParser(parents=[main_parser])
486-
if main_args.generate:
487-
parser.add_argument(
488-
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
489-
)
490-
else:
491-
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
492-
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
493-
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
494-
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
495-
parser.add_argument(
496-
"-s",
497-
"--silent",
498-
action="store_true",
499-
help="Do not crash on error. If set to true, errors are recorded as test results.",
500-
default=False,
501-
)
502-
parser.add_argument(
503-
"--batch-size",
504-
type=int,
505-
default=0,
506-
help="Run tests in batches of the specified size (default: 0, which means no batching)",
507-
)
487+
main_parser = argparse.ArgumentParser(add_help=True, description="Causal Testing Framework")
488+
489+
subparsers = main_parser.add_subparsers(
490+
help="The action you want to run - call `causal_testing {action} -h` for further details", dest="command"
491+
)
492+
493+
# Generation
494+
parser_generate = subparsers.add_parser(Command.GENERATE.value, help="Generate causal tests from a DAG")
495+
parser_generate.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
496+
parser_generate.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
497+
parser_generate.add_argument(
498+
"-e",
499+
"--estimator",
500+
help="The name of the estimator class to use when evaluating tests (defaults to LinearRegressionEstimator)",
501+
default="LinearRegressionEstimator",
502+
)
503+
parser_generate.add_argument(
504+
"-T",
505+
"--effect_type",
506+
help="The effect type to estimate {direct, total}",
507+
default="direct",
508+
)
509+
parser_generate.add_argument(
510+
"-E",
511+
"--estimate_type",
512+
help="The estimate type to use when evaluating tests (defaults to coefficient)",
513+
default="coefficient",
514+
)
515+
parser_generate.add_argument(
516+
"-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False
517+
)
518+
parser_generate.add_argument(
519+
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
520+
)
521+
522+
# Testing
523+
parser_test = subparsers.add_parser(Command.TEST.value, help="Run causal tests")
524+
parser_test.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
525+
parser_test.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
526+
parser_test.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
527+
parser_test.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
528+
parser_test.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
529+
parser_test.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
530+
parser_test.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
531+
parser_test.add_argument(
532+
"-s",
533+
"--silent",
534+
action="store_true",
535+
help="Do not crash on error. If set to true, errors are recorded as test results.",
536+
default=False,
537+
)
538+
parser_test.add_argument(
539+
"--batch-size",
540+
type=int,
541+
default=0,
542+
help="Run tests in batches of the specified size (default: 0, which means no batching)",
543+
)
508544

509-
return parser.parse_args(args)
545+
args = main_parser.parse_args(args)
546+
args.command = Command(args.command)
547+
return args

causal_testing/testing/metamorphic_relation.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,25 @@ def __eq__(self, other):
3737
class ShouldCause(MetamorphicRelation):
3838
"""Class representing a should cause metamorphic relation."""
3939

40-
def to_json_stub(self, skip=True) -> dict:
41-
"""Convert to a JSON frontend stub string for user customisation"""
40+
def to_json_stub(
41+
self,
42+
skip: bool = True,
43+
estimate_type: str = "coefficient",
44+
effect_type: str = "direct",
45+
estimator: str = "LinearRegressionEstimator",
46+
) -> dict:
47+
"""
48+
Convert to a JSON frontend stub string for user customisation.
49+
:param skip: Whether to skip the test
50+
:param effect_type: The type of causal effect to consider (total or direct)
51+
:param estimate_type: The estimate type to use when evaluating tests
52+
:param estimator: The name of the estimator class to use when evaluating the test
53+
"""
4254
return {
4355
"name": str(self),
44-
"estimator": "LinearRegressionEstimator",
45-
"estimate_type": "coefficient",
46-
"effect": "direct",
56+
"estimator": estimator,
57+
"estimate_type": estimate_type,
58+
"effect": effect_type,
4759
"treatment_variable": self.base_test_case.treatment_variable,
4860
"expected_effect": {self.base_test_case.outcome_variable: "SomeEffect"},
4961
"formula": (
@@ -63,13 +75,25 @@ def __str__(self):
6375
class ShouldNotCause(MetamorphicRelation):
6476
"""Class representing a should not cause metamorphic relation."""
6577

66-
def to_json_stub(self, skip=True) -> dict:
67-
"""Convert to a JSON frontend stub string for user customisation"""
78+
def to_json_stub(
79+
self,
80+
skip: bool = True,
81+
estimate_type: str = "coefficient",
82+
effect_type: str = "direct",
83+
estimator: str = "LinearRegressionEstimator",
84+
) -> dict:
85+
"""
86+
Convert to a JSON frontend stub string for user customisation.
87+
:param skip: Whether to skip the test
88+
:param effect_type: The type of causal effect to consider (total or direct)
89+
:param estimate_type: The estimate type to use when evaluating tests
90+
:param estimator: The name of the estimator class to use when evaluating the test
91+
"""
6892
return {
6993
"name": str(self),
70-
"estimator": "LinearRegressionEstimator",
71-
"estimate_type": "coefficient",
72-
"effect": "direct",
94+
"estimator": estimator,
95+
"estimate_type": estimate_type,
96+
"effect": effect_type,
7397
"treatment_variable": self.base_test_case.treatment_variable,
7498
"expected_effect": {self.base_test_case.outcome_variable: "NoEffect"},
7599
"formula": (
@@ -179,7 +203,9 @@ def generate_metamorphic_relations(
179203
return [item for items in metamorphic_relations for item in items]
180204

181205

182-
def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0):
206+
def generate_causal_tests(
207+
dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0, **json_stub_kargs
208+
):
183209
"""
184210
Generate and output causal tests for a given DAG.
185211
@@ -190,6 +216,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool =
190216
be omitted from the test set.
191217
:param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in
192218
serial. This is typically fine unless the number of tests to be generated is >10000.
219+
:param json_stub_kargs: Keyword arguments forwarded to each relation's `to_json_stub` call (see that method's docstring for details).
193220
"""
194221
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)
195222

@@ -212,7 +239,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool =
212239
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=threads)
213240

214241
tests = [
215-
relation.to_json_stub(skip=False)
242+
relation.to_json_stub(**json_stub_kargs)
216243
for relation in relations
217244
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
218245
]

tests/main_tests/test_main.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def test_parse_args(self):
305305
"sys.argv",
306306
[
307307
"causal_testing",
308+
"test",
308309
"--dag_path",
309310
str(self.dag_path),
310311
"--data_paths",
@@ -323,6 +324,7 @@ def test_parse_args_batches(self):
323324
"sys.argv",
324325
[
325326
"causal_testing",
327+
"test",
326328
"--dag_path",
327329
str(self.dag_path),
328330
"--data_paths",
@@ -344,7 +346,7 @@ def test_parse_args_generation(self):
344346
"sys.argv",
345347
[
346348
"causal_testing",
347-
"--generate",
349+
"generate",
348350
"--dag_path",
349351
str(self.dag_path),
350352
"--output",
@@ -354,6 +356,28 @@ def test_parse_args_generation(self):
354356
main()
355357
self.assertTrue(os.path.exists(os.path.join(tmp, "tests.json")))
356358

359+
def test_parse_args_generation_non_default(self):
360+
with tempfile.TemporaryDirectory() as tmp:
361+
with unittest.mock.patch(
362+
"sys.argv",
363+
[
364+
"causal_testing",
365+
"generate",
366+
"--dag_path",
367+
str(self.dag_path),
368+
"--output",
369+
os.path.join(tmp, "tests_non_default.json"),
370+
"--estimator",
371+
"LogisticRegressionEstimator",
372+
"--estimate_type",
373+
"unit_odds_ratio",
374+
"--effect_type",
375+
"total",
376+
],
377+
):
378+
main()
379+
self.assertTrue(os.path.exists(os.path.join(tmp, "tests_non_default.json")))
380+
357381
def tearDown(self):
358382
if self.output_path.parent.exists():
359383
shutil.rmtree(self.output_path.parent)

0 commit comments

Comments
 (0)