diff --git a/README.md b/README.md index 8b37f3fe..5d31817c 100644 --- a/README.md +++ b/README.md @@ -66,12 +66,12 @@ For more information on how to use the Causal Testing Framework, please refer to 2. If you do not already have causal test cases, you can convert your causal DAG to causal tests by running the following command. ``` -python causal_testing/testing/metamorphic_relation.py --dag_path $PATH_TO_DAG --output_path $PATH_TO_TESTS +python -m causal_testing generate --dag_path $PATH_TO_DAG --output $PATH_TO_TESTS ``` 3. You can now execute your tests by running the following command. ``` -python -m causal_testing --dag_path $PATH_TO_DAG --data_paths $PATH_TO_DATA --test_config $PATH_TO_TESTS --output $OUTPUT +python -m causal_testing test --dag_path $PATH_TO_DAG --data_paths $PATH_TO_DATA --test_config $PATH_TO_TESTS --output $OUTPUT ``` The results will be saved for inspection in a JSON file located at `$OUTPUT`. In the future, we hope to add a visualisation tool to assist with this. 
diff --git a/causal_testing/__main__.py b/causal_testing/__main__.py index 3b84cd9c..bfe3fd1f 100644 --- a/causal_testing/__main__.py +++ b/causal_testing/__main__.py @@ -6,7 +6,7 @@ import os from causal_testing.testing.metamorphic_relation import generate_causal_tests -from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework +from .main import setup_logging, parse_args, CausalTestingPaths, CausalTestingFramework, Command def main() -> None: @@ -19,9 +19,18 @@ def main() -> None: # Parse arguments args = parse_args() - if args.generate: + if args.command == Command.GENERATE: logging.info("Generating causal tests") - generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads) + generate_causal_tests( + args.dag_path, + args.output, + args.ignore_cycles, + args.threads, + effect_type=args.effect_type, + estimate_type=args.estimate_type, + estimator=args.estimator, + skip=True, + ) logging.info("Causal test generation completed successfully") return diff --git a/causal_testing/main.py b/causal_testing/main.py index bb972ca6..34f85d65 100644 --- a/causal_testing/main.py +++ b/causal_testing/main.py @@ -3,12 +3,12 @@ import argparse import json import logging +from enum import Enum from dataclasses import dataclass from pathlib import Path from typing import Dict, Any, Optional, List, Union, Sequence -from tqdm import tqdm - +from tqdm import tqdm import pandas as pd import numpy as np @@ -26,6 +26,15 @@ logger = logging.getLogger(__name__) +class Command(Enum): + """ + Enum for supported CTF commands. 
+ """ + + TEST = "test" + GENERATE = "generate" + + @dataclass class CausalTestingPaths: """ @@ -475,35 +484,64 @@ def setup_logging(verbose: bool = False) -> None: def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace: """Parse command line arguments.""" - main_parser = argparse.ArgumentParser(add_help=False, description="Causal Testing Framework") - main_parser.add_argument("-G", "--generate", help="Generate test cases from a DAG", action="store_true") - main_parser.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True) - main_parser.add_argument("-o", "--output", help="Path for output file (.json)", required=True) - main_parser.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False) - main_args, _ = main_parser.parse_known_args() - - parser = argparse.ArgumentParser(parents=[main_parser]) - if main_args.generate: - parser.add_argument( - "--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0 - ) - else: - parser.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True) - parser.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True) - parser.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False) - parser.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str) - parser.add_argument( - "-s", - "--silent", - action="store_true", - help="Do not crash on error. 
If set to true, errors are recorded as test results.", - default=False, - ) - parser.add_argument( - "--batch-size", - type=int, - default=0, - help="Run tests in batches of the specified size (default: 0, which means no batching)", - ) + main_parser = argparse.ArgumentParser(add_help=True, description="Causal Testing Framework") + + subparsers = main_parser.add_subparsers( + help="The action you want to run - call `causal_testing {action} -h` for further details", dest="command" + ) + + # Generation + parser_generate = subparsers.add_parser(Command.GENERATE.value, help="Generate causal tests from a DAG") + parser_generate.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True) + parser_generate.add_argument("-o", "--output", help="Path for output file (.json)", required=True) + parser_generate.add_argument( + "-e", + "--estimator", + help="The name of the estimator class to use when evaluating tests (defaults to LinearRegressionEstimator)", + default="LinearRegressionEstimator", + ) + parser_generate.add_argument( + "-T", + "--effect_type", + help="The effect type to estimate {direct, total}", + default="direct", + ) + parser_generate.add_argument( + "-E", + "--estimate_type", + help="The estimate type to use when evaluating tests (defaults to coefficient)", + default="coefficient", + ) + parser_generate.add_argument( + "-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False + ) + parser_generate.add_argument( + "--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0 + ) + + # Testing + parser_test = subparsers.add_parser(Command.TEST.value, help="Run causal tests") + parser_test.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True) + parser_test.add_argument("-o", "--output", help="Path for output file (.json)", required=True) + parser_test.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", 
default=False) + parser_test.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True) + parser_test.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True) + parser_test.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False) + parser_test.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str) + parser_test.add_argument( + "-s", + "--silent", + action="store_true", + help="Do not crash on error. If set to true, errors are recorded as test results.", + default=False, + ) + parser_test.add_argument( + "--batch-size", + type=int, + default=0, + help="Run tests in batches of the specified size (default: 0, which means no batching)", + ) - return parser.parse_args(args) + args = main_parser.parse_args(args) + args.command = Command(args.command) + return args diff --git a/causal_testing/testing/metamorphic_relation.py b/causal_testing/testing/metamorphic_relation.py index 95bef400..55b4381f 100644 --- a/causal_testing/testing/metamorphic_relation.py +++ b/causal_testing/testing/metamorphic_relation.py @@ -37,13 +37,25 @@ def __eq__(self, other): class ShouldCause(MetamorphicRelation): """Class representing a should cause metamorphic relation.""" - def to_json_stub(self, skip=True) -> dict: - """Convert to a JSON frontend stub string for user customisation""" + def to_json_stub( + self, + skip: bool = True, + estimate_type: str = "coefficient", + effect_type: str = "direct", + estimator: str = "LinearRegressionEstimator", + ) -> dict: + """ + Convert to a JSON frontend stub string for user customisation. 
+ :param skip: Whether to skip the test + :param effect_type: The type of causal effect to consider (total or direct) + :param estimate_type: The estimate type to use when evaluating tests + :param estimator: The name of the estimator class to use when evaluating the test + """ return { "name": str(self), - "estimator": "LinearRegressionEstimator", - "estimate_type": "coefficient", - "effect": "direct", + "estimator": estimator, + "estimate_type": estimate_type, + "effect": effect_type, "treatment_variable": self.base_test_case.treatment_variable, "expected_effect": {self.base_test_case.outcome_variable: "SomeEffect"}, "formula": ( @@ -63,13 +75,25 @@ def __str__(self): class ShouldNotCause(MetamorphicRelation): """Class representing a should cause metamorphic relation.""" - def to_json_stub(self, skip=True) -> dict: - """Convert to a JSON frontend stub string for user customisation""" + def to_json_stub( + self, + skip: bool = True, + estimate_type: str = "coefficient", + effect_type: str = "direct", + estimator: str = "LinearRegressionEstimator", + ) -> dict: + """ + Convert to a JSON frontend stub string for user customisation. 
+ :param skip: Whether to skip the test + :param effect_type: The type of causal effect to consider (total or direct) + :param estimate_type: The estimate type to use when evaluating tests + :param estimator: The name of the estimator class to use when evaluating the test + """ return { "name": str(self), - "estimator": "LinearRegressionEstimator", - "estimate_type": "coefficient", - "effect": "direct", + "estimator": estimator, + "estimate_type": estimate_type, + "effect": effect_type, "treatment_variable": self.base_test_case.treatment_variable, "expected_effect": {self.base_test_case.outcome_variable: "NoEffect"}, "formula": ( @@ -179,7 +203,9 @@ def generate_metamorphic_relations( return [item for items in metamorphic_relations for item in items] -def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0): +def generate_causal_tests( + dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0, **json_stub_kargs +): """ Generate and output causal tests for a given DAG. @@ -190,6 +216,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = be omitted from the test set. :param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in serial. This is tylically fine unless the number of tests to be generated is >10000. + :param json_stub_kargs: Kwargs to pass into `to_json_stub` (see docstring for details.) 
""" causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles) @@ -212,7 +239,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=threads) tests = [ - relation.to_json_stub(skip=False) + relation.to_json_stub(**json_stub_kargs) for relation in relations if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0 ] diff --git a/tests/main_tests/test_main.py b/tests/main_tests/test_main.py index 5a6433ae..36ce4709 100644 --- a/tests/main_tests/test_main.py +++ b/tests/main_tests/test_main.py @@ -305,6 +305,7 @@ def test_parse_args(self): "sys.argv", [ "causal_testing", + "test", "--dag_path", str(self.dag_path), "--data_paths", @@ -323,6 +324,7 @@ def test_parse_args_batches(self): "sys.argv", [ "causal_testing", + "test", "--dag_path", str(self.dag_path), "--data_paths", @@ -344,7 +346,7 @@ def test_parse_args_generation(self): "sys.argv", [ "causal_testing", - "--generate", + "generate", "--dag_path", str(self.dag_path), "--output", @@ -354,6 +356,28 @@ def test_parse_args_generation(self): main() self.assertTrue(os.path.exists(os.path.join(tmp, "tests.json"))) + def test_parse_args_generation_non_default(self): + with tempfile.TemporaryDirectory() as tmp: + with unittest.mock.patch( + "sys.argv", + [ + "causal_testing", + "generate", + "--dag_path", + str(self.dag_path), + "--output", + os.path.join(tmp, "tests_non_default.json"), + "--estimator", + "LogisticRegressionEstimator", + "--estimate_type", + "unit_odds_ratio", + "--effect_type", + "total", + ], + ): + main() + self.assertTrue(os.path.exists(os.path.join(tmp, "tests_non_default.json"))) + def tearDown(self): if self.output_path.parent.exists(): shutil.rmtree(self.output_path.parent) diff --git a/tests/testing_tests/test_metamorphic_relations.py b/tests/testing_tests/test_metamorphic_relations.py index 68f54589..6d838126 100644 --- 
a/tests/testing_tests/test_metamorphic_relations.py +++ b/tests/testing_tests/test_metamorphic_relations.py @@ -50,9 +50,9 @@ def test_should_not_cause_json_stub(self): causal_dag = CausalDAG(self.dag_dot_path) causal_dag.graph.remove_nodes_from(["X2", "X3"]) adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) - should_not_cause_MR = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set) + should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set) self.assertEqual( - should_not_cause_MR.to_json_stub(), + should_not_cause_mr.to_json_stub(), { "effect": "direct", "estimate_type": "coefficient", @@ -66,15 +66,39 @@ def test_should_not_cause_json_stub(self): }, ) + def test_should_not_cause_logistic_json_stub(self): + """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program + and there is only a single input.""" + causal_dag = CausalDAG(self.dag_dot_path) + causal_dag.graph.remove_nodes_from(["X2", "X3"]) + adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) + should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set) + self.assertEqual( + should_not_cause_mr.to_json_stub( + effect_type="total", estimate_type="unit_odds_ratio", estimator="LogisticRegressionEstimator" + ), + { + "effect": "total", + "estimate_type": "unit_odds_ratio", + "estimator": "LogisticRegressionEstimator", + "expected_effect": {"Z": "NoEffect"}, + "treatment_variable": "X1", + "name": "X1 _||_ Z", + "formula": "Z ~ X1", + "alpha": 0.05, + "skip": True, + }, + ) + def test_should_cause_json_stub(self): """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program and there is only a single input.""" causal_dag = CausalDAG(self.dag_dot_path) causal_dag.graph.remove_nodes_from(["X2", "X3"]) adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) - should_cause_MR = ShouldCause(BaseTestCase("X1", "Z"), adj_set) + 
should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set) self.assertEqual( - should_cause_MR.to_json_stub(), + should_cause_mr.to_json_stub(), { "effect": "direct", "estimate_type": "coefficient", @@ -87,6 +111,29 @@ def test_should_cause_json_stub(self): }, ) + def test_should_cause_logistic_json_stub(self): + """Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program + and there is only a single input.""" + causal_dag = CausalDAG(self.dag_dot_path) + causal_dag.graph.remove_nodes_from(["X2", "X3"]) + adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0]) + should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set) + self.assertEqual( + should_cause_mr.to_json_stub( + effect_type="total", estimate_type="unit_odds_ratio", estimator="LogisticRegressionEstimator", skip=True + ), + { + "effect": "total", + "estimate_type": "unit_odds_ratio", + "estimator": "LogisticRegressionEstimator", + "expected_effect": {"Z": "SomeEffect"}, + "formula": "Z ~ X1", + "treatment_variable": "X1", + "name": "X1 --> Z", + "skip": True, + }, + ) + def test_all_metamorphic_relations_implied_by_dag(self): dag = CausalDAG(self.dag_dot_path) dag.add_edge("Z", "Y") # Add a direct path from Z to Y so M becomes a mediator @@ -216,7 +263,7 @@ def test_generate_causal_tests_ignore_cycles(self): tests = json.load(f) expected = list( map( - lambda x: x.to_json_stub(skip=False), + lambda x: x.to_json_stub(skip=True), filter( lambda relation: len(list(dcg.graph.predecessors(relation.base_test_case.outcome_variable))) > 0, @@ -236,7 +283,7 @@ def test_generate_causal_tests(self): tests = json.load(f) expected = list( map( - lambda x: x.to_json_stub(skip=False), + lambda x: x.to_json_stub(skip=True), filter( lambda relation: len(list(dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0,