|
3 | 3 | import argparse
|
4 | 4 | import json
|
5 | 5 | import logging
|
| 6 | +from enum import Enum |
6 | 7 | from dataclasses import dataclass
|
7 | 8 | from pathlib import Path
|
8 | 9 | from typing import Dict, Any, Optional, List, Union, Sequence
|
9 |
| -from tqdm import tqdm |
10 |
| - |
11 | 10 |
|
| 11 | +from tqdm import tqdm |
12 | 12 | import pandas as pd
|
13 | 13 | import numpy as np
|
14 | 14 |
|
|
26 | 26 | logger = logging.getLogger(__name__)
|
27 | 27 |
|
28 | 28 |
|
| 29 | +class Command(Enum): |
| 30 | + """ |
| 31 | + Enum for supported CTF commands. |
| 32 | + """ |
| 33 | + |
| 34 | + TEST = "test" |
| 35 | + GENERATE = "generate" |
| 36 | + |
| 37 | + |
29 | 38 | @dataclass
|
30 | 39 | class CausalTestingPaths:
|
31 | 40 | """
|
@@ -475,35 +484,64 @@ def setup_logging(verbose: bool = False) -> None:
|
475 | 484 |
|
476 | 485 | def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
|
477 | 486 | """Parse command line arguments."""
|
478 |
| - main_parser = argparse.ArgumentParser(add_help=False, description="Causal Testing Framework") |
479 |
| - main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true") |
480 |
| - main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True) |
481 |
| - main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True) |
482 |
| - main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False) |
483 |
| - main_args, _ = main_parser.parse_known_args() |
484 |
| - |
485 |
| - parser = argparse.ArgumentParser(parents=[main_parser]) |
486 |
| - if main_args.generate: |
487 |
| - parser.add_argument( |
488 |
| - "--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0 |
489 |
| - ) |
490 |
| - else: |
491 |
| - parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True) |
492 |
| - parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True) |
493 |
| - parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False) |
494 |
| - parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str) |
495 |
| - parser.add_argument( |
496 |
| - "-s", |
497 |
| - "--silent", |
498 |
| - action="store_true", |
499 |
| - help="Do not crash on error. If set to true, errors are recorded as test results.", |
500 |
| - default=False, |
501 |
| - ) |
502 |
| - parser.add_argument( |
503 |
| - "--batch-size", |
504 |
| - type=int, |
505 |
| - default=0, |
506 |
| - help="Run tests in batches of the specified size (default: 0, which means no batching)", |
507 |
| - ) |
| 487 | + main_parser = argparse.ArgumentParser(add_help=True, description="Causal Testing Framework") |
| 488 | + |
| 489 | + subparsers = main_parser.add_subparsers( |
| 490 | + help="The action you want to run - call `causal_testing {action} -h` for further details", dest="command" |
| 491 | + ) |
| 492 | + |
| 493 | + # Generation |
| 494 | + parser_generate = subparsers.add_parser(Command.GENERATE.value, help="Generate causal tests from a DAG") |
| 495 | + parser_generate.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True) |
| 496 | + parser_generate.add_argument("-o", "--output", help="Path for output file (.json)", required=True) |
| 497 | + parser_generate.add_argument( |
| 498 | + "-e", |
| 499 | + "--estimator", |
| 500 | + help="The name of the estimator class to use when evaluating tests (defaults to LinearRegressionEstimator)", |
| 501 | + default="LinearRegressionEstimator", |
| 502 | + ) |
| 503 | + parser_generate.add_argument( |
| 504 | + "-T", |
| 505 | + "--effect_type", |
| 506 | + help="The effect type to estimate {direct, total}", |
| 507 | + default="direct", |
| 508 | + ) |
| 509 | + parser_generate.add_argument( |
| 510 | + "-E", |
| 511 | + "--estimate_type", |
| 512 | + help="The estimate type to use when evaluating tests (defaults to coefficient)", |
| 513 | + default="coefficient", |
| 514 | + ) |
| 515 | + parser_generate.add_argument( |
| 516 | + "-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False |
| 517 | + ) |
| 518 | + parser_generate.add_argument( |
| 519 | + "--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0 |
| 520 | + ) |
| 521 | + |
| 522 | + # Testing |
| 523 | + parser_test = subparsers.add_parser(Command.TEST.value, help="Run causal tests") |
| 524 | + parser_test.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True) |
| 525 | + parser_test.add_argument("-o", "--output", help="Path for output file (.json)", required=True) |
| 526 | + parser_test.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False) |
| 527 | + parser_test.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True) |
| 528 | + parser_test.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True) |
| 529 | + parser_test.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False) |
| 530 | + parser_test.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str) |
| 531 | + parser_test.add_argument( |
| 532 | + "-s", |
| 533 | + "--silent", |
| 534 | + action="store_true", |
| 535 | + help="Do not crash on error. If set to true, errors are recorded as test results.", |
| 536 | + default=False, |
| 537 | + ) |
| 538 | + parser_test.add_argument( |
| 539 | + "--batch-size", |
| 540 | + type=int, |
| 541 | + default=0, |
| 542 | + help="Run tests in batches of the specified size (default: 0, which means no batching)", |
| 543 | + ) |
508 | 544 |
|
509 |
| - return parser.parse_args(args) |
| 545 | + args = main_parser.parse_args(args) |
| 546 | + args.command = Command(args.command) |
| 547 | + return args |
0 commit comments