Skip to content

Commit 736cf6d

Browse files
committed
Removed data collector from metamorphic relation
1 parent bb70338 commit 736cf6d

File tree

3 files changed

+266
-489
lines changed

3 files changed

+266
-489
lines changed

causal_testing/specification/metamorphic_relation.py renamed to causal_testing/testing/metamorphic_relation.py

Lines changed: 19 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@
1313
from multiprocessing import Pool
1414

1515
import networkx as nx
16-
import pandas as pd
17-
import numpy as np
1816

1917
from causal_testing.specification.causal_specification import CausalDAG, Node
20-
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
18+
from causal_testing.testing.base_test_case import BaseTestCase
2119

2220
logger = logging.getLogger(__name__)
2321

@@ -26,91 +24,11 @@
2624
class MetamorphicRelation:
2725
"""Class representing a metamorphic relation."""
2826

29-
treatment_var: Node
30-
output_var: Node
27+
base_test_case: BaseTestCase
3128
adjustment_vars: Iterable[Node]
3229
dag: CausalDAG
3330
tests: Iterable = None
3431

35-
def generate_follow_up(self, n_tests: int, min_val: float, max_val: float, seed: int = 0):
36-
"""Generate numerical follow-up input configurations."""
37-
np.random.seed(seed)
38-
39-
# Get set of variables to change, excluding the treatment itself
40-
variables_to_change = {node for node in self.dag.graph.nodes if self.dag.graph.in_degree(node) == 0}
41-
if self.adjustment_vars:
42-
variables_to_change |= set(self.adjustment_vars)
43-
if self.treatment_var in variables_to_change:
44-
variables_to_change.remove(self.treatment_var)
45-
46-
# Assign random numerical values to the variables to change
47-
test_inputs = pd.DataFrame(
48-
np.random.randint(min_val, max_val, size=(n_tests, len(variables_to_change))),
49-
columns=sorted(variables_to_change),
50-
)
51-
52-
# Enumerate the possible source, follow-up pairs for the treatment
53-
candidate_source_follow_up_pairs = np.array(list(combinations(range(int(min_val), int(max_val + 1)), 2)))
54-
55-
# Sample without replacement from the possible source, follow-up pairs
56-
sampled_source_follow_up_indices = np.random.choice(
57-
candidate_source_follow_up_pairs.shape[0], n_tests, replace=False
58-
)
59-
60-
follow_up_input = f"{self.treatment_var}'"
61-
source_follow_up_test_inputs = pd.DataFrame(
62-
candidate_source_follow_up_pairs[sampled_source_follow_up_indices],
63-
columns=sorted([self.treatment_var] + [follow_up_input]),
64-
)
65-
self.tests = [
66-
MetamorphicTest(
67-
source_inputs,
68-
follow_up_inputs,
69-
other_inputs,
70-
self.output_var,
71-
str(self),
72-
)
73-
for source_inputs, follow_up_inputs, other_inputs in zip(
74-
source_follow_up_test_inputs[[self.treatment_var]].to_dict(orient="records"),
75-
source_follow_up_test_inputs[[follow_up_input]]
76-
.rename(columns={follow_up_input: self.treatment_var})
77-
.to_dict(orient="records"),
78-
(
79-
test_inputs.to_dict(orient="records")
80-
if not test_inputs.empty
81-
else [{}] * len(source_follow_up_test_inputs)
82-
),
83-
)
84-
]
85-
86-
def execute_tests(self, data_collector: ExperimentalDataCollector):
87-
"""Execute the generated list of metamorphic tests, returning a dictionary of tests that pass and fail.
88-
89-
:param data_collector: An experimental data collector for the system-under-test.
90-
"""
91-
test_results = {"pass": [], "fail": []}
92-
for metamorphic_test in self.tests:
93-
# Update the control and treatment configuration to take generated values for source and follow-up tests
94-
control_input_config = metamorphic_test.source_inputs | metamorphic_test.other_inputs
95-
treatment_input_config = metamorphic_test.follow_up_inputs | metamorphic_test.other_inputs
96-
data_collector.control_input_configuration = control_input_config
97-
data_collector.treatment_input_configuration = treatment_input_config
98-
metamorphic_test_results_df = data_collector.collect_data()
99-
100-
# Apply assertion to control and treatment outputs
101-
control_output = metamorphic_test_results_df.loc["control_0"][metamorphic_test.output]
102-
treatment_output = metamorphic_test_results_df.loc["treatment_0"][metamorphic_test.output]
103-
104-
if not self.assertion(control_output, treatment_output):
105-
test_results["fail"].append(metamorphic_test)
106-
else:
107-
test_results["pass"].append(metamorphic_test)
108-
return test_results
109-
110-
@abstractmethod
111-
def assertion(self, source_output, follow_up_output):
112-
"""An assertion that should be applied to an individual metamorphic test run."""
113-
11432
@abstractmethod
11533
def to_json_stub(self, skip=True) -> dict:
11634
"""Convert to a JSON frontend stub string for user customisation"""
@@ -123,10 +41,11 @@ def test_oracle(self, test_results):
12341

12442
def __eq__(self, other):
12543
same_type = self.__class__ == other.__class__
126-
same_treatment = self.treatment_var == other.treatment_var
127-
same_output = self.output_var == other.output_var
44+
same_treatment = self.base_test_case.treatment_variable == other.base_test_case.treatment_variable
45+
same_outcome = self.base_test_case.outcome_variable == other.base_test_case.outcome_variable
46+
same_effect = self.base_test_case.effect == other.base_test_case.effect
12847
same_adjustment_set = set(self.adjustment_vars) == set(other.adjustment_vars)
129-
return same_type and same_treatment and same_output and same_adjustment_set
48+
return same_type and same_treatment and same_outcome and same_effect and same_adjustment_set
13049

13150

13251
class ShouldCause(MetamorphicRelation):
@@ -149,14 +68,14 @@ def to_json_stub(self, skip=True) -> dict:
14968
"estimator": "LinearRegressionEstimator",
15069
"estimate_type": "coefficient",
15170
"effect": "direct",
152-
"mutations": [self.treatment_var],
153-
"expected_effect": {self.output_var: "SomeEffect"},
154-
"formula": f"{self.output_var} ~ {' + '.join([self.treatment_var] + self.adjustment_vars)}",
71+
"mutations": [self.base_test_case.treatment_variable],
72+
"expected_effect": {self.base_test_case.outcome_variable: "SomeEffect"},
73+
"formula": f"{self.base_test_case.outcome_variable} ~ {' + '.join([self.base_test_case.treatment_variable] + self.adjustment_vars)}",
15574
"skip": skip,
15675
}
15776

15877
def __str__(self):
159-
formatted_str = f"{self.treatment_var} --> {self.output_var}"
78+
formatted_str = f"{self.base_test_case.treatment_variable} --> {self.base_test_case.outcome_variable}"
16079
if self.adjustment_vars:
16180
formatted_str += f" | {self.adjustment_vars}"
16281
return formatted_str
@@ -182,40 +101,20 @@ def to_json_stub(self, skip=True) -> dict:
182101
"estimator": "LinearRegressionEstimator",
183102
"estimate_type": "coefficient",
184103
"effect": "direct",
185-
"mutations": [self.treatment_var],
186-
"expected_effect": {self.output_var: "NoEffect"},
187-
"formula": f"{self.output_var} ~ {' + '.join([self.treatment_var] + self.adjustment_vars)}",
104+
"mutations": [self.base_test_case.treatment_variable],
105+
"expected_effect": {self.base_test_case.outcome_variable: "NoEffect"},
106+
"formula": f"{self.base_test_case.outcome_variable} ~ {' + '.join([self.base_test_case.treatment_variable] + self.adjustment_vars)}",
188107
"alpha": 0.05,
189108
"skip": skip,
190109
}
191110

192111
def __str__(self):
193-
formatted_str = f"{self.treatment_var} _||_ {self.output_var}"
112+
formatted_str = f"{self.base_test_case.treatment_variable} _||_ {self.base_test_case.outcome_variable}"
194113
if self.adjustment_vars:
195114
formatted_str += f" | {self.adjustment_vars}"
196115
return formatted_str
197116

198117

199-
@dataclass(order=True)
200-
class MetamorphicTest:
201-
"""Class representing a metamorphic test case."""
202-
203-
source_inputs: dict
204-
follow_up_inputs: dict
205-
other_inputs: dict
206-
output: str
207-
relation: str
208-
209-
def __str__(self):
210-
return (
211-
f"Source inputs: {self.source_inputs}\n"
212-
f"Follow-up inputs: {self.follow_up_inputs}\n"
213-
f"Other inputs: {self.other_inputs}\n"
214-
f"Output: {self.output}"
215-
f"Metamorphic Relation: {self.relation}"
216-
)
217-
218-
219118
def generate_metamorphic_relation(
220119
node_pair: tuple[str, str], dag: CausalDAG, nodes_to_ignore: set = None
221120
) -> MetamorphicRelation:
@@ -241,30 +140,30 @@ def generate_metamorphic_relation(
241140
if u in nx.ancestors(dag.graph, v):
242141
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
243142
if adj_sets:
244-
metamorphic_relations.append(ShouldNotCause(u, v, list(adj_sets[0]), dag))
143+
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0]), dag))
245144

246145
# Case 2: V --> ... --> U
247146
elif v in nx.ancestors(dag.graph, u):
248147
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
249148
if adj_sets:
250-
metamorphic_relations.append(ShouldNotCause(v, u, list(adj_sets[0]), dag))
149+
metamorphic_relations.append(ShouldNotCause(BaseTestCase(v, u), list(adj_sets[0]), dag))
251150

252151
# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
253152
# Only make one MR since V _||_ U == U _||_ V
254153
else:
255154
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
256155
if adj_sets:
257-
metamorphic_relations.append(ShouldNotCause(u, v, list(adj_sets[0]), dag))
156+
metamorphic_relations.append(ShouldNotCause(BaseTestCase(u, v), list(adj_sets[0]), dag))
258157

259158
# Create a ShouldCause relation for each edge (u, v) or (v, u)
260159
elif (u, v) in dag.graph.edges:
261160
adj_sets = dag.direct_effect_adjustment_sets([u], [v], nodes_to_ignore=nodes_to_ignore)
262161
if adj_sets:
263-
metamorphic_relations.append(ShouldCause(u, v, list(adj_sets[0]), dag))
162+
metamorphic_relations.append(ShouldCause(BaseTestCase(u, v), list(adj_sets[0]), dag))
264163
else:
265164
adj_sets = dag.direct_effect_adjustment_sets([v], [u], nodes_to_ignore=nodes_to_ignore)
266165
if adj_sets:
267-
metamorphic_relations.append(ShouldCause(v, u, list(adj_sets[0]), dag))
166+
metamorphic_relations.append(ShouldCause(BaseTestCase(v, u), list(adj_sets[0]), dag))
268167
return metamorphic_relations
269168

270169

0 commit comments

Comments
 (0)