Skip to content

Commit a1eca02

Browse files
author
AndrewC19
committed
ShouldCause metamorphic relation gen and testing works
1 parent e7e02b5 commit a1eca02

File tree

3 files changed

+116
-25
lines changed

3 files changed

+116
-25
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,10 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
106106
executions.
107107
"""
108108
control_results_df = self.run_system_with_input_configuration(self.control_input_configuration)
109+
control_results_df.rename('control_{}'.format, inplace=True)
109110
treatment_results_df = self.run_system_with_input_configuration(self.treatment_input_configuration)
110-
results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=True)
111+
treatment_results_df.rename('treatment_{}'.format, inplace=True)
112+
results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=False)
111113
return results_df
112114

113115
@abstractmethod

causal_testing/specification/metamorphic_relation.py

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77

88
from causal_testing.specification.causal_specification import CausalDAG, Node
9+
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
910

1011
@dataclass(order=True)
1112
class MetamorphicRelation:
@@ -57,43 +58,91 @@ def generate_follow_up(self,
5758
)
5859
source_test_inputs = source_follow_up_test_inputs[[self.treatment_var]]
5960
follow_up_test_inputs = source_follow_up_test_inputs[[follow_up_input]]
60-
follow_up_test_inputs.rename({follow_up_input: self.treatment_var})
61-
62-
# TODO: Add a metamorphic test dataclass that stores these attributes
63-
self.tests = list(
64-
zip(
65-
source_test_inputs.to_dict(orient="records"),
66-
follow_up_test_inputs.to_dict(orient="records"),
67-
test_inputs.to_dict(orient="records") if not test_inputs.empty
68-
else [{}] * len(source_test_inputs),
69-
[self.output_var] * len(source_test_inputs),
70-
[str(self)] * len(source_test_inputs)
71-
)
72-
)
61+
follow_up_test_inputs = follow_up_test_inputs.rename(columns={follow_up_input: self.treatment_var})
62+
source_test_inputs_record = source_test_inputs.to_dict(orient="records")
63+
follow_up_test_inputs_record = follow_up_test_inputs.to_dict(orient="records")
64+
if not test_inputs.empty:
65+
other_test_inputs_record = test_inputs.to_dict(orient="records")
66+
else:
67+
other_test_inputs_record = [{}] * len(source_test_inputs)
68+
metamorphic_tests = []
69+
for i in range(len(source_test_inputs_record)):
70+
metamorphic_test = MetamorphicTest(source_test_inputs_record[i],
71+
follow_up_test_inputs_record[i],
72+
other_test_inputs_record[i],
73+
self.output_var,
74+
str(self)
75+
)
76+
metamorphic_tests.append(metamorphic_test)
77+
self.tests = metamorphic_tests
78+
79+
def execute_tests(self, data_collector: ExperimentalDataCollector):
80+
"""Execute the generated list of metamorphic tests, returning a dictionary of tests that pass and fail.
81+
82+
:param data_collector: An experimental data collector for the system-under-test.
83+
"""
84+
test_results = {"pass": [], "fail": []}
85+
for metamorphic_test in self.tests:
86+
# Update the control and treatment configuration to take generated values for source and follow-up tests
87+
control_input_config = metamorphic_test.source_inputs | metamorphic_test.other_inputs
88+
treatment_input_config = metamorphic_test.follow_up_inputs | metamorphic_test.other_inputs
89+
data_collector.control_input_configuration = control_input_config
90+
data_collector.treatment_input_configuration = treatment_input_config
91+
metamorphic_test_results_df = data_collector.collect_data()
92+
print(metamorphic_test_results_df)
93+
# Compare control and treatment results
94+
control_output = metamorphic_test_results_df.loc["control_0"][metamorphic_test.output]
95+
treatment_output = metamorphic_test_results_df.loc["treatment_0"][metamorphic_test.output]
96+
if not self.assertion(control_output, treatment_output):
97+
test_results["fail"].append(metamorphic_test)
98+
else:
99+
test_results["pass"].append(metamorphic_test)
100+
return test_results
73101

74102
@abstractmethod
75-
def test_oracle(self):
76-
"""A test oracle i.e. a method that checks correctness of a test."""
103+
def assertion(self, source_output, follow_up_output):
104+
"""An assertion that should be applied to an individual metamorphic test run."""
77105
...
78106

79107
@abstractmethod
80-
def execute_test(self):
81-
"""Execute a test for this metamorphic relation."""
108+
def test_oracle(self, test_results):
109+
"""A test oracle that assert whether the MR holds or not based on ALL test results.
110+
111+
This method must raise an assertion, not return a bool."""
82112
...
83113

84114

85115
@dataclass(order=True)
86116
class ShouldCause(MetamorphicRelation):
87117
"""Class representing a should cause metamorphic relation."""
88118

89-
def test_oracle(self):
90-
pass
119+
def assertion(self, source_output, follow_up_output):
120+
"""If there is a causal effect, the outputs should not be the same."""
121+
return source_output != follow_up_output
122+
123+
def test_oracle(self, test_results):
124+
...
91125

92-
def execute_test(self):
93-
pass
94126

95127
def __str__(self):
96128
formatted_str = f"{self.treatment_var} --> {self.output_var}"
97129
if self.adjustment_vars:
98130
formatted_str += f" | {self.adjustment_vars}"
99131
return formatted_str
132+
133+
134+
@dataclass(order=True)
135+
class MetamorphicTest:
136+
"""Class representing a metamorphic test case."""
137+
source_inputs: dict
138+
follow_up_inputs: dict
139+
other_inputs: dict
140+
output: str
141+
relation: str
142+
143+
def __str__(self):
144+
return f"Source inputs: {self.source_inputs}\n" \
145+
f"Follow-up inputs: {self.follow_up_inputs}\n" \
146+
f"Other inputs: {self.other_inputs}\n" \
147+
f"Output: {self.output}" \
148+
f"Metamorphic Relation: {self.relation}"
Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,63 @@
11
import unittest
22
import os
33

4+
import pandas as pd
5+
46
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
57
from causal_testing.specification.causal_dag import CausalDAG
8+
from causal_testing.specification.causal_specification import Scenario
69
from causal_testing.specification.metamorphic_relation import ShouldCause
10+
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
11+
from causal_testing.specification.variable import Input, Output
12+
13+
14+
def program_under_test(X1, X2, X3, Z=None, M=None, Y=None):
15+
if Z is None:
16+
Z = 2*X1 + -3*X2 + 10
17+
if M is None:
18+
M = 3*Z + X3
19+
if Y is None:
20+
Y = M/2
21+
return {'Z': Z, 'M': M, 'Y': Y}
22+
23+
24+
class ProgramUnderTestEDC(ExperimentalDataCollector):
25+
26+
def run_system_with_input_configuration(self, input_configuration: dict) -> pd.DataFrame:
27+
print(input_configuration)
28+
results_dict = program_under_test(**input_configuration)
29+
print(results_dict)
30+
results_df = pd.DataFrame(results_dict, index=[0])
31+
return results_df
32+
733

834
class TestMetamorphicRelation(unittest.TestCase):
935

1036
def setUp(self) -> None:
1137
temp_dir_path = create_temp_dir_if_non_existent()
1238
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
13-
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M; M -> Y; Z -> M; }"""
39+
dag_dot = """digraph DAG { rankdir=LR; X1 -> Z; Z -> M; M -> Y; X1 -> M; X2 -> Z; X3 -> M;}"""
1440
with open(self.dag_dot_path, "w") as f:
1541
f.write(dag_dot)
1642

43+
X1 = Input('X1', float)
44+
X2 = Input('X2', float)
45+
X3 = Input('X3', float)
46+
Z = Output('Z', float)
47+
M = Output('M', float)
48+
Y = Output('Y', float)
49+
scenario = Scenario(variables={X1, X2, X3, Z, M, Y})
50+
default_control_input_config = {'X1': 1, 'X2': 2, 'X3': 3}
51+
default_treatment_input_config = {'X1': 2, 'X2': 3, 'X3': 3}
52+
self.data_collector = ProgramUnderTestEDC(scenario,
53+
default_control_input_config,
54+
default_treatment_input_config)
55+
1756
def test_metamorphic_relation(self):
1857
causal_dag = CausalDAG(self.dag_dot_path)
1958
for edge in causal_dag.graph.edges:
2059
(u, v) = edge
2160
should_cause_MR = ShouldCause(u, v, None, causal_dag)
22-
should_cause_MR.generate_follow_up(1, -10.0, 10.0, 1)
23-
print(should_cause_MR.tests)
61+
should_cause_MR.generate_follow_up(10, -10.0, 10.0, 1)
62+
test_results = should_cause_MR.execute_tests(self.data_collector)
63+
print(test_results)

0 commit comments

Comments
 (0)