Skip to content

Commit f78a6d2

Browse files
committed
Switched to using subcommands rather than a boolean flag.
1 parent 80f16c8 commit f78a6d2

File tree

4 files changed

+57
-37
lines changed

4 files changed

+57
-37
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ For more information on how to use the Causal Testing Framework, please refer to
6666
2. If you do not already have causal test cases, you can convert your causal DAG to causal tests by running the following command.
6767

6868
```
69-
python -m causal_testing -G --dag_path $PATH_TO_DAG --output_path $PATH_TO_TESTS
69+
python -m causal_testing generate --dag_path $PATH_TO_DAG --output_path $PATH_TO_TESTS
7070
```
7171

7272
3. You can now execute your tests by running the following command.
7373
```
74-
python -m causal_testing --dag_path $PATH_TO_DAG --data_paths $PATH_TO_DATA --test_config $PATH_TO_TESTS --output $OUTPUT
74+
python -m causal_testing test --dag_path $PATH_TO_DAG --data_paths $PATH_TO_DATA --test_config $PATH_TO_TESTS --output $OUTPUT
7575
```
7676
The results will be saved for inspection in a JSON file located at `$OUTPUT`.
7777
In the future, we hope to add a visualisation tool to assist with this.

causal_testing/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77

88
from causal_testing.testing.metamorphic_relation import generate_causal_tests
9-
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework
9+
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework, Command
1010

1111

1212
def main() -> None:
@@ -19,7 +19,7 @@ def main() -> None:
1919
# Parse arguments
2020
args = parse_args()
2121

22-
if args.generate:
22+
if args.command == Command.GENERATE:
2323
logging.info("Generating causal tests")
2424
generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads)
2525
logging.info("Causal test generation completed successfully")

causal_testing/main.py

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pathlib import Path
88
from typing import Dict, Any, Optional, List, Union, Sequence
99
from tqdm import tqdm
10-
10+
from enum import Enum
1111

1212
import pandas as pd
1313
import numpy as np
@@ -26,6 +26,15 @@
2626
logger = logging.getLogger(__name__)
2727

2828

29+
class Command(Enum):
30+
"""
31+
Enum for supported CTF commands.
32+
"""
33+
34+
TEST = "test"
35+
GENERATE = "generate"
36+
37+
2938
@dataclass
3039
class CausalTestingPaths:
3140
"""
@@ -475,35 +484,44 @@ def setup_logging(verbose: bool = False) -> None:
475484

476485
def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
477486
"""Parse command line arguments."""
478-
main_parser = argparse.ArgumentParser(description="Causal Testing Framework")
479-
main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true")
480-
main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
481-
main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
482-
main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
483-
main_args, _ = main_parser.parse_known_args()
484-
485-
parser = argparse.ArgumentParser(parents=[main_parser])
486-
if main_args.generate:
487-
parser.add_argument(
488-
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
489-
)
490-
else:
491-
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
492-
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
493-
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
494-
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
495-
parser.add_argument(
496-
"-s",
497-
"--silent",
498-
action="store_true",
499-
help="Do not crash on error. If set to true, errors are recorded as test results.",
500-
default=False,
501-
)
502-
parser.add_argument(
503-
"--batch-size",
504-
type=int,
505-
default=0,
506-
help="Run tests in batches of the specified size (default: 0, which means no batching)",
507-
)
487+
main_parser = argparse.ArgumentParser(add_help=True, description="Causal Testing Framework")
488+
489+
subparsers = main_parser.add_subparsers(
490+
help="The action you want to run - call `causal_testing {action} -h` for further details", dest="command"
491+
)
492+
493+
parser_generate = subparsers.add_parser(Command.GENERATE.value, help="Generate causal tests from a DAG")
494+
parser_generate.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
495+
parser_generate.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
496+
parser_generate.add_argument(
497+
"-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False
498+
)
499+
parser_generate.add_argument(
500+
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
501+
)
502+
503+
parser_test = subparsers.add_parser(Command.TEST.value, help="Run causal tests")
504+
parser_test.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
505+
parser_test.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
506+
parser_test.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
507+
parser_test.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
508+
parser_test.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
509+
parser_test.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
510+
parser_test.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
511+
parser_test.add_argument(
512+
"-s",
513+
"--silent",
514+
action="store_true",
515+
help="Do not crash on error. If set to true, errors are recorded as test results.",
516+
default=False,
517+
)
518+
parser_test.add_argument(
519+
"--batch-size",
520+
type=int,
521+
default=0,
522+
help="Run tests in batches of the specified size (default: 0, which means no batching)",
523+
)
508524

509-
return parser.parse_args(args)
525+
args = main_parser.parse_args(args)
526+
args.command = Command(args.command)
527+
return args

tests/main_tests/test_main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,7 @@ def test_parse_args(self):
305305
"sys.argv",
306306
[
307307
"causal_testing",
308+
"test",
308309
"--dag_path",
309310
str(self.dag_path),
310311
"--data_paths",
@@ -323,6 +324,7 @@ def test_parse_args_batches(self):
323324
"sys.argv",
324325
[
325326
"causal_testing",
327+
"test",
326328
"--dag_path",
327329
str(self.dag_path),
328330
"--data_paths",
@@ -344,7 +346,7 @@ def test_parse_args_generation(self):
344346
"sys.argv",
345347
[
346348
"causal_testing",
347-
"--generate",
349+
"generate",
348350
"--dag_path",
349351
str(self.dag_path),
350352
"--output",

0 commit comments

Comments
 (0)