Skip to content

Changed main run command to facilitate test generation #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ For more information on how to use the Causal Testing Framework, please refer to
2. If you do not already have causal test cases, you can convert your causal DAG to causal tests by running the following command.

```
python causal_testing/testing/metamorphic_relation.py --dag_path $PATH_TO_DAG --output_path $PATH_TO_TESTS
python -m causal_testing generate --dag_path $PATH_TO_DAG --output $PATH_TO_TESTS
```

3. You can now execute your tests by running the following command.
```
python -m causal_testing --dag_path $PATH_TO_DAG --data_paths $PATH_TO_DATA --test_config $PATH_TO_TESTS --output $OUTPUT
python -m causal_testing test --dag_path $PATH_TO_DAG --data_paths $PATH_TO_DATA --test_config $PATH_TO_TESTS --output $OUTPUT
```
The results will be saved for inspection in a JSON file located at `$OUTPUT`.
In the future, we hope to add a visualisation tool to assist with this.
Expand Down
15 changes: 12 additions & 3 deletions causal_testing/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os

from causal_testing.testing.metamorphic_relation import generate_causal_tests
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework
from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework, Command


def main() -> None:
Expand All @@ -19,9 +19,18 @@ def main() -> None:
# Parse arguments
args = parse_args()

if args.generate:
if args.command == Command.GENERATE:
logging.info("Generating causal tests")
generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads)
generate_causal_tests(
args.dag_path,
args.output,
args.ignore_cycles,
args.threads,
effect_type=args.effect_type,
estimate_type=args.estimate_type,
estimator=args.estimator,
skip=True,
)
logging.info("Causal test generation completed successfully")
return

Expand Down
104 changes: 71 additions & 33 deletions causal_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import argparse
import json
import logging
from enum import Enum
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any, Optional, List, Union, Sequence
from tqdm import tqdm


from tqdm import tqdm
import pandas as pd
import numpy as np

Expand All @@ -26,6 +26,15 @@
logger = logging.getLogger(__name__)


class Command(Enum):
    """
    Enum for supported CTF commands.

    Each member's value is the name of the command-line sub-command that selects it, so the string stored by
    argparse in ``args.command`` can be converted directly with ``Command(args.command)``.
    """

    # Run causal tests.
    TEST = "test"
    # Generate causal tests from a DAG.
    GENERATE = "generate"


@dataclass
class CausalTestingPaths:
"""
Expand Down Expand Up @@ -475,35 +484,64 @@ def setup_logging(verbose: bool = False) -> None:

def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
"""Parse command line arguments."""
main_parser = argparse.ArgumentParser(add_help=False, description="Causal Testing Framework")
main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true")
main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
main_args, _ = main_parser.parse_known_args()

parser = argparse.ArgumentParser(parents=[main_parser])
if main_args.generate:
parser.add_argument(
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
)
else:
parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
parser.add_argument(
"-s",
"--silent",
action="store_true",
help="Do not crash on error. If set to true, errors are recorded as test results.",
default=False,
)
parser.add_argument(
"--batch-size",
type=int,
default=0,
help="Run tests in batches of the specified size (default: 0, which means no batching)",
)
main_parser = argparse.ArgumentParser(add_help=True, description="Causal Testing Framework")

subparsers = main_parser.add_subparsers(
help="The action you want to run - call `causal_testing {action} -h` for further details", dest="command"
)

# Generation
parser_generate = subparsers.add_parser(Command.GENERATE.value, help="Generate causal tests from a DAG")
parser_generate.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
parser_generate.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
parser_generate.add_argument(
"-e",
"--estimator",
help="The name of the estimator class to use when evaluating tests (defaults to LinearRegressionEstimator)",
default="LinearRegressionEstimator",
)
parser_generate.add_argument(
"-T",
"--effect_type",
help="The effect type to estimate {direct, total}",
default="direct",
)
parser_generate.add_argument(
"-E",
"--estimate_type",
help="The estimate type to use when evaluating tests (defaults to coefficient)",
default="coefficient",
)
parser_generate.add_argument(
"-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False
)
parser_generate.add_argument(
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
)

# Testing
parser_test = subparsers.add_parser(Command.TEST.value, help="Run causal tests")
parser_test.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
parser_test.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
parser_test.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
parser_test.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
parser_test.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
parser_test.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
parser_test.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
parser_test.add_argument(
"-s",
"--silent",
action="store_true",
help="Do not crash on error. If set to true, errors are recorded as test results.",
default=False,
)
parser_test.add_argument(
"--batch-size",
type=int,
default=0,
help="Run tests in batches of the specified size (default: 0, which means no batching)",
)

return parser.parse_args(args)
args = main_parser.parse_args(args)
args.command = Command(args.command)
return args
51 changes: 39 additions & 12 deletions causal_testing/testing/metamorphic_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,25 @@ def __eq__(self, other):
class ShouldCause(MetamorphicRelation):
"""Class representing a should cause metamorphic relation."""

def to_json_stub(self, skip=True) -> dict:
"""Convert to a JSON frontend stub string for user customisation"""
def to_json_stub(
self,
skip: bool = True,
estimate_type: str = "coefficient",
effect_type: str = "direct",
estimator: str = "LinearRegressionEstimator",
) -> dict:
"""
Convert to a JSON frontend stub string for user customisation.
:param skip: Whether to skip the test
:param effect_type: The type of causal effect to consider (total or direct)
:param estimate_type: The estimate type to use when evaluating tests
:param estimator: The name of the estimator class to use when evaluating the test
"""
return {
"name": str(self),
"estimator": "LinearRegressionEstimator",
"estimate_type": "coefficient",
"effect": "direct",
"estimator": estimator,
"estimate_type": estimate_type,
"effect": effect_type,
"treatment_variable": self.base_test_case.treatment_variable,
"expected_effect": {self.base_test_case.outcome_variable: "SomeEffect"},
"formula": (
Expand All @@ -63,13 +75,25 @@ def __str__(self):
class ShouldNotCause(MetamorphicRelation):
"""Class representing a should not cause metamorphic relation."""

def to_json_stub(self, skip=True) -> dict:
"""Convert to a JSON frontend stub string for user customisation"""
def to_json_stub(
self,
skip: bool = True,
estimate_type: str = "coefficient",
effect_type: str = "direct",
estimator: str = "LinearRegressionEstimator",
) -> dict:
"""
Convert to a JSON frontend stub string for user customisation.
:param skip: Whether to skip the test
:param effect_type: The type of causal effect to consider (total or direct)
:param estimate_type: The estimate type to use when evaluating tests
:param estimator: The name of the estimator class to use when evaluating the test
"""
return {
"name": str(self),
"estimator": "LinearRegressionEstimator",
"estimate_type": "coefficient",
"effect": "direct",
"estimator": estimator,
"estimate_type": estimate_type,
"effect": effect_type,
"treatment_variable": self.base_test_case.treatment_variable,
"expected_effect": {self.base_test_case.outcome_variable: "NoEffect"},
"formula": (
Expand Down Expand Up @@ -179,7 +203,9 @@ def generate_metamorphic_relations(
return [item for items in metamorphic_relations for item in items]


def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0):
def generate_causal_tests(
dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0, **json_stub_kargs
):
"""
Generate and output causal tests for a given DAG.

Expand All @@ -190,6 +216,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool =
be omitted from the test set.
:param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in
serial. This is typically fine unless the number of tests to be generated is >10000.
:param json_stub_kargs: Kwargs to pass into `to_json_stub` (see docstring for details.)
"""
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)

Expand All @@ -212,7 +239,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool =
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=threads)

tests = [
relation.to_json_stub(skip=False)
relation.to_json_stub(**json_stub_kargs)
for relation in relations
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
]
Expand Down
26 changes: 25 additions & 1 deletion tests/main_tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ def test_parse_args(self):
"sys.argv",
[
"causal_testing",
"test",
"--dag_path",
str(self.dag_path),
"--data_paths",
Expand All @@ -323,6 +324,7 @@ def test_parse_args_batches(self):
"sys.argv",
[
"causal_testing",
"test",
"--dag_path",
str(self.dag_path),
"--data_paths",
Expand All @@ -344,7 +346,7 @@ def test_parse_args_generation(self):
"sys.argv",
[
"causal_testing",
"--generate",
"generate",
"--dag_path",
str(self.dag_path),
"--output",
Expand All @@ -354,6 +356,28 @@ def test_parse_args_generation(self):
main()
self.assertTrue(os.path.exists(os.path.join(tmp, "tests.json")))

def test_parse_args_generation_non_default(self):
    """
    Run the `generate` command with a non-default estimator, estimate type and effect type, and check that the
    requested output file is written.
    """
    with tempfile.TemporaryDirectory() as tmp:
        generated_tests_path = os.path.join(tmp, "tests_non_default.json")
        cli_args = [
            "causal_testing",
            "generate",
            "--dag_path",
            str(self.dag_path),
            "--output",
            generated_tests_path,
            "--estimator",
            "LogisticRegressionEstimator",
            "--estimate_type",
            "unit_odds_ratio",
            "--effect_type",
            "total",
        ]
        # main() reads its arguments from sys.argv, so substitute the command line for the duration of the call.
        with unittest.mock.patch("sys.argv", cli_args):
            main()
            self.assertTrue(os.path.exists(generated_tests_path))

def tearDown(self):
if self.output_path.parent.exists():
shutil.rmtree(self.output_path.parent)
Loading
Loading