Skip to content

Commit cefa853

Browse files
committed
Made generation options configurable
1 parent a722a60 commit cefa853

File tree

5 files changed

+149
-17
lines changed

5 files changed

+149
-17
lines changed

causal_testing/__main__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,15 @@ def main() -> None:
2121

2222
if args.command == Command.GENERATE:
2323
logging.info("Generating causal tests")
24-
generate_causal_tests(args.dag_path, args.output, args.ignore_cycles, args.threads)
24+
generate_causal_tests(
25+
args.dag_path,
26+
args.output,
27+
args.ignore_cycles,
28+
args.threads,
29+
effect_type=args.effect_type,
30+
estimate_type=args.estimate_type,
31+
estimator=args.estimator,
32+
)
2533
logging.info("Causal test generation completed successfully")
2634
return
2735

causal_testing/main.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,16 +490,36 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
490490
help="The action you want to run - call `causal_testing {action} -h` for further details", dest="command"
491491
)
492492

493+
# Generation
493494
parser_generate = subparsers.add_parser(Command.GENERATE.value, help="Generate causal tests from a DAG")
494495
parser_generate.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
495496
parser_generate.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
497+
parser_generate.add_argument(
498+
"-e",
499+
"--estimator",
500+
help="The name of the estimator class to use when evaluating tests (defaults to LinearRegressionEstimator)",
501+
default="LinearRegressionEstimator",
502+
)
503+
parser_generate.add_argument(
504+
"-T",
505+
"--effect_type",
506+
help="The effect type to estimate {direct, total}",
507+
default="direct",
508+
)
509+
parser_generate.add_argument(
510+
"-E",
511+
"--estimate_type",
512+
help="The estimate type to use when evaluating tests (defaults to coefficient)",
513+
default="coefficient",
514+
)
496515
parser_generate.add_argument(
497516
"-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False
498517
)
499518
parser_generate.add_argument(
500519
"--threads", "-t", type=int, help="The number of parallel threads to use.", required=False, default=0
501520
)
502521

522+
# Testing
503523
parser_test = subparsers.add_parser(Command.TEST.value, help="Run causal tests")
504524
parser_test.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
505525
parser_test.add_argument("-o", "--output", help="Path for output file (.json)", required=True)

causal_testing/testing/metamorphic_relation.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,25 @@ def __eq__(self, other):
3737
class ShouldCause(MetamorphicRelation):
3838
"""Class representing a should cause metamorphic relation."""
3939

40-
def to_json_stub(self, skip=True) -> dict:
41-
"""Convert to a JSON frontend stub string for user customisation"""
40+
def to_json_stub(
41+
self,
42+
skip: bool = True,
43+
estimate_type: str = "coefficient",
44+
effect_type: str = "direct",
45+
estimator: str = "LinearRegressionEstimator",
46+
) -> dict:
47+
"""
48+
Convert to a JSON frontend stub string for user customisation.
49+
:param skip: Whether to skip the test
50+
:param effect_type: The type of causal effect to consider (total or direct)
51+
:param estimate_type: The estimate type to use when evaluating tests
52+
:param estimator: The name of the estimator class to use when evaluating the test
53+
"""
4254
return {
4355
"name": str(self),
44-
"estimator": "LinearRegressionEstimator",
45-
"estimate_type": "coefficient",
46-
"effect": "direct",
56+
"estimator": estimator,
57+
"estimate_type": estimate_type,
58+
"effect": effect_type,
4759
"treatment_variable": self.base_test_case.treatment_variable,
4860
"expected_effect": {self.base_test_case.outcome_variable: "SomeEffect"},
4961
"formula": (
@@ -63,13 +75,25 @@ def __str__(self):
6375
class ShouldNotCause(MetamorphicRelation):
6476
"""Class representing a should cause metamorphic relation."""
6577

66-
def to_json_stub(self, skip=True) -> dict:
67-
"""Convert to a JSON frontend stub string for user customisation"""
78+
def to_json_stub(
79+
self,
80+
skip: bool = True,
81+
estimate_type: str = "coefficient",
82+
effect_type: str = "direct",
83+
estimator: str = "LinearRegressionEstimator",
84+
) -> dict:
85+
"""
86+
Convert to a JSON frontend stub string for user customisation.
87+
:param skip: Whether to skip the test
88+
:param effect_type: The type of causal effect to consider (total or direct)
89+
:param estimate_type: The estimate type to use when evaluating tests
90+
:param estimator: The name of the estimator class to use when evaluating the test
91+
"""
6892
return {
6993
"name": str(self),
70-
"estimator": "LinearRegressionEstimator",
71-
"estimate_type": "coefficient",
72-
"effect": "direct",
94+
"estimator": estimator,
95+
"estimate_type": estimate_type,
96+
"effect": effect_type,
7397
"treatment_variable": self.base_test_case.treatment_variable,
7498
"expected_effect": {self.base_test_case.outcome_variable: "NoEffect"},
7599
"formula": (
@@ -179,7 +203,15 @@ def generate_metamorphic_relations(
179203
return [item for items in metamorphic_relations for item in items]
180204

181205

182-
def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0):
206+
def generate_causal_tests(
207+
dag_path: str,
208+
output_path: str,
209+
ignore_cycles: bool = False,
210+
threads: int = 0,
211+
estimate_type: str = "coefficient",
212+
effect_type: str = "direct",
213+
estimator: str = "LinearRegressionEstimator",
214+
):
183215
"""
184216
Generate and output causal tests for a given DAG.
185217
@@ -190,6 +222,9 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool =
190222
be omitted from the test set.
191223
:param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in
192224
serial. This is tylically fine unless the number of tests to be generated is >10000.
225+
:param effect_type: The type of causal effect to consider (total or direct)
226+
:param estimate_type: The estimate type to use when evaluating tests
227+
:param estimator: The name of the estimator class to use when evaluating the test
193228
"""
194229
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)
195230

@@ -212,7 +247,7 @@ def generate_causal_tests(dag_path: str, output_path: str, ignore_cycles: bool =
212247
relations = generate_metamorphic_relations(causal_dag, nodes_to_test=dag_nodes_to_test, threads=threads)
213248

214249
tests = [
215-
relation.to_json_stub(skip=False)
250+
relation.to_json_stub(skip=False, estimate_type=estimate_type, effect_type=effect_type, estimator=estimator)
216251
for relation in relations
217252
if len(list(causal_dag.graph.predecessors(relation.base_test_case.outcome_variable))) > 0
218253
]

tests/main_tests/test_main.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,28 @@ def test_parse_args_generation(self):
356356
main()
357357
self.assertTrue(os.path.exists(os.path.join(tmp, "tests.json")))
358358

359+
def test_parse_args_generation_non_default(self):
360+
with tempfile.TemporaryDirectory() as tmp:
361+
with unittest.mock.patch(
362+
"sys.argv",
363+
[
364+
"causal_testing",
365+
"generate",
366+
"--dag_path",
367+
str(self.dag_path),
368+
"--output",
369+
os.path.join(tmp, "tests_non_default.json"),
370+
"--estimator",
371+
"LogisticRegressionEstimator",
372+
"--estimate_type",
373+
"unit_odds_ratio",
374+
"--effect_type",
375+
"total",
376+
],
377+
):
378+
main()
379+
self.assertTrue(os.path.exists(os.path.join(tmp, "tests_non_default.json")))
380+
359381
def tearDown(self):
360382
if self.output_path.parent.exists():
361383
shutil.rmtree(self.output_path.parent)

tests/testing_tests/test_metamorphic_relations.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def test_should_not_cause_json_stub(self):
5050
causal_dag = CausalDAG(self.dag_dot_path)
5151
causal_dag.graph.remove_nodes_from(["X2", "X3"])
5252
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
53-
should_not_cause_MR = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
53+
should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
5454
self.assertEqual(
55-
should_not_cause_MR.to_json_stub(),
55+
should_not_cause_mr.to_json_stub(),
5656
{
5757
"effect": "direct",
5858
"estimate_type": "coefficient",
@@ -66,15 +66,39 @@ def test_should_not_cause_json_stub(self):
6666
},
6767
)
6868

69+
def test_should_not_cause_logistic_json_stub(self):
70+
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
71+
and there is only a single input."""
72+
causal_dag = CausalDAG(self.dag_dot_path)
73+
causal_dag.graph.remove_nodes_from(["X2", "X3"])
74+
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
75+
should_not_cause_mr = ShouldNotCause(BaseTestCase("X1", "Z"), adj_set)
76+
self.assertEqual(
77+
should_not_cause_mr.to_json_stub(
78+
effect_type="total", estimate_type="unit_odds_ratio", estimator="LogisticRegressionEstimator"
79+
),
80+
{
81+
"effect": "total",
82+
"estimate_type": "unit_odds_ratio",
83+
"estimator": "LogisticRegressionEstimator",
84+
"expected_effect": {"Z": "NoEffect"},
85+
"treatment_variable": "X1",
86+
"name": "X1 _||_ Z",
87+
"formula": "Z ~ X1",
88+
"alpha": 0.05,
89+
"skip": True,
90+
},
91+
)
92+
6993
def test_should_cause_json_stub(self):
7094
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
7195
and there is only a single input."""
7296
causal_dag = CausalDAG(self.dag_dot_path)
7397
causal_dag.graph.remove_nodes_from(["X2", "X3"])
7498
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
75-
should_cause_MR = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
99+
should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
76100
self.assertEqual(
77-
should_cause_MR.to_json_stub(),
101+
should_cause_mr.to_json_stub(),
78102
{
79103
"effect": "direct",
80104
"estimate_type": "coefficient",
@@ -87,6 +111,29 @@ def test_should_cause_json_stub(self):
87111
},
88112
)
89113

114+
def test_should_cause_logistic_json_stub(self):
115+
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program
116+
and there is only a single input."""
117+
causal_dag = CausalDAG(self.dag_dot_path)
118+
causal_dag.graph.remove_nodes_from(["X2", "X3"])
119+
adj_set = list(causal_dag.direct_effect_adjustment_sets(["X1"], ["Z"])[0])
120+
should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
121+
self.assertEqual(
122+
should_cause_mr.to_json_stub(
123+
effect_type="total", estimate_type="unit_odds_ratio", estimator="LogisticRegressionEstimator"
124+
),
125+
{
126+
"effect": "total",
127+
"estimate_type": "unit_odds_ratio",
128+
"estimator": "LogisticRegressionEstimator",
129+
"expected_effect": {"Z": "SomeEffect"},
130+
"formula": "Z ~ X1",
131+
"treatment_variable": "X1",
132+
"name": "X1 --> Z",
133+
"skip": True,
134+
},
135+
)
136+
90137
def test_all_metamorphic_relations_implied_by_dag(self):
91138
dag = CausalDAG(self.dag_dot_path)
92139
dag.add_edge("Z", "Y") # Add a direct path from Z to Y so M becomes a mediator

0 commit comments

Comments
 (0)