Skip to content

Commit 5c57f49

Browse files
Author: AndrewC19 (committed)
Commit message: ShouldCause MR test oracle works.
1 parent a1eca02 commit 5c57f49

File tree

2 files changed

+50
-16
lines changed

2 files changed

+50
-16
lines changed

causal_testing/specification/metamorphic_relation.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,15 @@ def execute_tests(self, data_collector: ExperimentalDataCollector):
8989
data_collector.control_input_configuration = control_input_config
9090
data_collector.treatment_input_configuration = treatment_input_config
9191
metamorphic_test_results_df = data_collector.collect_data()
92-
print(metamorphic_test_results_df)
93-
# Compare control and treatment results
92+
93+
# Apply assertion to control and treatment outputs
9494
control_output = metamorphic_test_results_df.loc["control_0"][metamorphic_test.output]
9595
treatment_output = metamorphic_test_results_df.loc["treatment_0"][metamorphic_test.output]
9696
if not self.assertion(control_output, treatment_output):
9797
test_results["fail"].append(metamorphic_test)
9898
else:
9999
test_results["pass"].append(metamorphic_test)
100-
return test_results
100+
return test_results
101101

102102
@abstractmethod
103103
def assertion(self, source_output, follow_up_output):
@@ -121,8 +121,9 @@ def assertion(self, source_output, follow_up_output):
121121
return source_output != follow_up_output
122122

123123
def test_oracle(self, test_results):
124-
...
125-
124+
"""A single passing test is sufficient to show presence of a causal effect."""
125+
assert len(test_results["fail"]) < len(self.tests),\
126+
f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
126127

127128
def __str__(self):
128129
formatted_str = f"{self.treatment_var} --> {self.output_var}"

tests/specification_tests/test_metamorphic_relations.py

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,28 @@ def program_under_test(X1, X2, X3, Z=None, M=None, Y=None):
2121
return {'Z': Z, 'M': M, 'Y': Y}
2222

2323

24+
def buggy_program_under_test(X1, X2, X3, Z=None, M=None, Y=None):
    """Deliberately faulty variant of the program under test.

    Z is set to a constant when not supplied, so X1 and X2 have no effect on
    it. M and Y are derived from Z and X3 unless they are supplied directly.

    :param X1: Input that is accepted but never used.
    :param X2: Input that is accepted but never used.
    :param X3: Input added to ``3*Z`` when computing M.
    :param Z: Optional fixed value for Z; defaults to the constant 2.
    :param M: Optional fixed value for M; defaults to ``3*Z + X3``.
    :param Y: Optional fixed value for Y; defaults to ``M/2``.
    :return: Dict with the keys ``'Z'``, ``'M'`` and ``'Y'``.
    """
    if Z is None:
        Z = 2  # No effect of X1 or X2 on Z
    if M is None:
        M = 3*Z + X3
    if Y is None:
        Y = M/2
    return {'Z': Z, 'M': M, 'Y': Y}
32+
33+
2434
class ProgramUnderTestEDC(ExperimentalDataCollector):
    """Experimental data collector that executes ``program_under_test``."""

    def run_system_with_input_configuration(self, input_configuration: dict) -> pd.DataFrame:
        """Run the program once and return its outputs as a one-row DataFrame.

        :param input_configuration: Mapping expanded into keyword arguments
            for ``program_under_test``.
        :return: DataFrame built from the returned dict, with a single row at
            index 0.
        """
        results_dict = program_under_test(**input_configuration)
        results_df = pd.DataFrame(results_dict, index=[0])
        return results_df
40+
41+
42+
class BuggyProgramUnderTestEDC(ExperimentalDataCollector):
    """Experimental data collector that executes ``buggy_program_under_test``."""

    def run_system_with_input_configuration(self, input_configuration: dict) -> pd.DataFrame:
        """Run the buggy program once and return its outputs as a one-row DataFrame.

        :param input_configuration: Mapping expanded into keyword arguments
            for ``buggy_program_under_test``.
        :return: DataFrame built from the returned dict, with a single row at
            index 0.
        """
        results_dict = buggy_program_under_test(**input_configuration)
        results_df = pd.DataFrame(results_dict, index=[0])
        return results_df
3248

@@ -46,18 +62,35 @@ def setUp(self) -> None:
4662
Z = Output('Z', float)
4763
M = Output('M', float)
4864
Y = Output('Y', float)
49-
scenario = Scenario(variables={X1, X2, X3, Z, M, Y})
50-
default_control_input_config = {'X1': 1, 'X2': 2, 'X3': 3}
51-
default_treatment_input_config = {'X1': 2, 'X2': 3, 'X3': 3}
52-
self.data_collector = ProgramUnderTestEDC(scenario,
53-
default_control_input_config,
54-
default_treatment_input_config)
55-
56-
def test_metamorphic_relation(self):
65+
self.scenario = Scenario(variables={X1, X2, X3, Z, M, Y})
66+
self.default_control_input_config = {'X1': 1, 'X2': 2, 'X3': 3}
67+
self.default_treatment_input_config = {'X1': 2, 'X2': 3, 'X3': 3}
68+
self.data_collector = ProgramUnderTestEDC(self.scenario,
69+
self.default_control_input_config,
70+
self.default_treatment_input_config)
71+
72+
def test_should_cause_metamorphic_relations_should_pass(self):
    """Run the ShouldCause oracle for every edge of the causal DAG.

    For each edge a ShouldCause relation is built, follow-up tests are
    generated and executed with ``self.data_collector``, and the oracle is
    applied to the results; an AssertionError from the oracle fails the test.
    """
    causal_dag = CausalDAG(self.dag_dot_path)
    for u, v in causal_dag.graph.edges:
        should_cause_MR = ShouldCause(u, v, None, causal_dag)
        should_cause_MR.generate_follow_up(10, -10.0, 10.0, 1)
        test_results = should_cause_MR.execute_tests(self.data_collector)
        should_cause_MR.test_oracle(test_results)
80+
81+
def test_should_cause_metamorphic_relation_missing_relationship(self):
    """Test whether the ShouldCause MR catches missing relationships in the DAG.

    The data collector is swapped for BuggyProgramUnderTestEDC, and the
    oracle is expected to raise AssertionError for both the X1 --> Z and the
    X2 --> Z relations.
    """
    causal_dag = CausalDAG(self.dag_dot_path)
    self.data_collector = BuggyProgramUnderTestEDC(self.scenario,
                                                   self.default_control_input_config,
                                                   self.default_treatment_input_config)
    relations = [ShouldCause(treatment, 'Z', None, causal_dag) for treatment in ('X1', 'X2')]
    # Keep the stages separate (generate all, then execute all, then check
    # all) so the call order matches the hand-written X1/X2 sequence.
    for relation in relations:
        relation.generate_follow_up(10, -100, 100, 1)
    all_results = [relation.execute_tests(self.data_collector) for relation in relations]
    for relation, test_results in zip(relations, all_results):
        self.assertRaises(AssertionError, relation.test_oracle, test_results)
95+
96+

0 commit comments

Comments
 (0)