Skip to content

Commit 3774c44

Browse files
Merge branch 'main' into update-readme
2 parents 35f1ccf + 2b5376d commit 3774c44

File tree

11 files changed

+779
-79
lines changed

11 files changed

+779
-79
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
112112
executions.
113113
"""
114114
control_results_df = self.run_system_with_input_configuration(self.control_input_configuration)
115+
control_results_df.rename(lambda x: f"control_{x}", inplace=True)
115116
treatment_results_df = self.run_system_with_input_configuration(self.treatment_input_configuration)
116-
results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=True)
117+
treatment_results_df.rename(lambda x: f"treatment_{x}", inplace=True)
118+
results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=False)
117119
return results_df
118120

119121
@abstractmethod
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""
2+
This module contains the ShouldCause and ShouldNotCause metamorphic relations as
3+
defined in our ICST paper [https://eprints.whiterose.ac.uk/195317/].
4+
"""
5+
6+
from dataclasses import dataclass
7+
from abc import abstractmethod
8+
from typing import Iterable
9+
from itertools import combinations
10+
import numpy as np
11+
import pandas as pd
12+
import networkx as nx
13+
14+
from causal_testing.specification.causal_specification import CausalDAG, Node
15+
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
16+
17+
18+
@dataclass(order=True)
19+
class MetamorphicRelation:
20+
"""Class representing a metamorphic relation."""
21+
22+
treatment_var: Node
23+
output_var: Node
24+
adjustment_vars: Iterable[Node]
25+
dag: CausalDAG
26+
tests: Iterable = None
27+
28+
def generate_follow_up(self, n_tests: int, min_val: float, max_val: float, seed: int = 0):
29+
"""Generate numerical follow-up input configurations."""
30+
np.random.seed(seed)
31+
32+
# Get set of variables to change, excluding the treatment itself
33+
variables_to_change = {node for node in self.dag.graph.nodes if self.dag.graph.in_degree(node) == 0}
34+
if self.adjustment_vars:
35+
variables_to_change |= set(self.adjustment_vars)
36+
if self.treatment_var in variables_to_change:
37+
variables_to_change.remove(self.treatment_var)
38+
39+
# Assign random numerical values to the variables to change
40+
test_inputs = pd.DataFrame(
41+
np.random.randint(min_val, max_val, size=(n_tests, len(variables_to_change))),
42+
columns=sorted(variables_to_change),
43+
)
44+
45+
# Enumerate the possible source, follow-up pairs for the treatment
46+
candidate_source_follow_up_pairs = np.array(list(combinations(range(int(min_val), int(max_val + 1)), 2)))
47+
48+
# Sample without replacement from the possible source, follow-up pairs
49+
sampled_source_follow_up_indices = np.random.choice(
50+
candidate_source_follow_up_pairs.shape[0], n_tests, replace=False
51+
)
52+
53+
follow_up_input = f"{self.treatment_var}'"
54+
source_follow_up_test_inputs = pd.DataFrame(
55+
candidate_source_follow_up_pairs[sampled_source_follow_up_indices],
56+
columns=sorted([self.treatment_var] + [follow_up_input]),
57+
)
58+
self.tests = [
59+
MetamorphicTest(
60+
source_inputs,
61+
follow_up_inputs,
62+
other_inputs,
63+
self.output_var,
64+
str(self),
65+
)
66+
for source_inputs, follow_up_inputs, other_inputs in zip(
67+
source_follow_up_test_inputs[[self.treatment_var]].to_dict(orient="records"),
68+
source_follow_up_test_inputs[[follow_up_input]]
69+
.rename(columns={follow_up_input: self.treatment_var})
70+
.to_dict(orient="records"),
71+
test_inputs.to_dict(orient="records")
72+
if not test_inputs.empty
73+
else [{}] * len(source_follow_up_test_inputs),
74+
)
75+
]
76+
77+
def execute_tests(self, data_collector: ExperimentalDataCollector):
78+
"""Execute the generated list of metamorphic tests, returning a dictionary of tests that pass and fail.
79+
80+
:param data_collector: An experimental data collector for the system-under-test.
81+
"""
82+
test_results = {"pass": [], "fail": []}
83+
for metamorphic_test in self.tests:
84+
# Update the control and treatment configuration to take generated values for source and follow-up tests
85+
control_input_config = metamorphic_test.source_inputs | metamorphic_test.other_inputs
86+
treatment_input_config = metamorphic_test.follow_up_inputs | metamorphic_test.other_inputs
87+
data_collector.control_input_configuration = control_input_config
88+
data_collector.treatment_input_configuration = treatment_input_config
89+
metamorphic_test_results_df = data_collector.collect_data()
90+
91+
# Apply assertion to control and treatment outputs
92+
control_output = metamorphic_test_results_df.loc["control_0"][metamorphic_test.output]
93+
treatment_output = metamorphic_test_results_df.loc["treatment_0"][metamorphic_test.output]
94+
95+
if not self.assertion(control_output, treatment_output):
96+
test_results["fail"].append(metamorphic_test)
97+
else:
98+
test_results["pass"].append(metamorphic_test)
99+
return test_results
100+
101+
@abstractmethod
102+
def assertion(self, source_output, follow_up_output):
103+
"""An assertion that should be applied to an individual metamorphic test run."""
104+
105+
@abstractmethod
106+
def test_oracle(self, test_results):
107+
"""A test oracle that asserts whether the MR holds or not based on ALL test results.
108+
109+
This method must raise an AssertionError if the MR does not hold, rather than return a bool."""
110+
111+
def __eq__(self, other):
112+
same_type = self.__class__ == other.__class__
113+
same_treatment = self.treatment_var == other.treatment_var
114+
same_output = self.output_var == other.output_var
115+
same_adjustment_set = set(self.adjustment_vars) == set(other.adjustment_vars)
116+
return same_type and same_treatment and same_output and same_adjustment_set
117+
118+
119+
class ShouldCause(MetamorphicRelation):
120+
"""Class representing a should cause metamorphic relation."""
121+
122+
def assertion(self, source_output, follow_up_output):
123+
"""If there is a causal effect, the outputs should not be the same."""
124+
return source_output != follow_up_output
125+
126+
def test_oracle(self, test_results):
127+
"""A single passing test is sufficient to show presence of a causal effect."""
128+
assert len(test_results["fail"]) < len(
129+
self.tests
130+
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
131+
132+
def __str__(self):
133+
formatted_str = f"{self.treatment_var} --> {self.output_var}"
134+
if self.adjustment_vars:
135+
formatted_str += f" | {self.adjustment_vars}"
136+
return formatted_str
137+
138+
139+
class ShouldNotCause(MetamorphicRelation):
140+
"""Class representing a should not cause metamorphic relation."""
141+
142+
def assertion(self, source_output, follow_up_output):
143+
"""If there is no causal effect, the outputs should be the same."""
144+
return source_output == follow_up_output
145+
146+
def test_oracle(self, test_results):
147+
"""A single failing test is sufficient to show presence of a causal effect, so all tests must pass."""
148+
assert (
149+
len(test_results["fail"]) == 0
150+
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
151+
152+
def __str__(self):
153+
formatted_str = f"{self.treatment_var} _||_ {self.output_var}"
154+
if self.adjustment_vars:
155+
formatted_str += f" | {self.adjustment_vars}"
156+
return formatted_str
157+
158+
159+
@dataclass(order=True)
160+
class MetamorphicTest:
161+
"""Class representing a metamorphic test case."""
162+
163+
source_inputs: dict
164+
follow_up_inputs: dict
165+
other_inputs: dict
166+
output: str
167+
relation: str
168+
169+
def __str__(self):
170+
return (
171+
f"Source inputs: {self.source_inputs}\n"
172+
f"Follow-up inputs: {self.follow_up_inputs}\n"
173+
f"Other inputs: {self.other_inputs}\n"
174+
f"Output: {self.output}"
175+
f"Metamorphic Relation: {self.relation}"
176+
)
177+
178+
179+
def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
180+
"""Construct a list of metamorphic relations implied by the Causal DAG.
181+
182+
This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
183+
relation for every (minimal) conditional independence relation implied by the structure of the DAG.
184+
185+
:param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated.
186+
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
187+
"""
188+
metamorphic_relations = []
189+
for node_pair in combinations(dag.graph.nodes, 2):
190+
(u, v) = node_pair
191+
192+
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
193+
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
194+
195+
# Case 1: U --> ... --> V
196+
if u in nx.ancestors(dag.graph, v):
197+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
198+
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
199+
200+
# Case 2: V --> ... --> U
201+
elif v in nx.ancestors(dag.graph, u):
202+
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
203+
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))
204+
205+
# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
206+
# Only make one MR since V _||_ U == U _||_ V
207+
else:
208+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
209+
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
210+
211+
# Create a ShouldCause relation for each edge (u, v) or (v, u)
212+
elif (u, v) in dag.graph.edges:
213+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
214+
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
215+
else:
216+
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
217+
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
218+
219+
return metamorphic_relations

causal_testing/specification/variable.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,78 +77,118 @@ def __init__(self, name: str, datatype: T, distribution: rv_generic = None, hidd
7777
def __repr__(self):
7878
return f"{self.typestring()}: {self.name}::{self.datatype.__name__}"
7979

80-
def __ge__(self, other: Any) -> BoolRef:
81-
"""Create the Z3 expression `other >= self`.
80+
# Thin wrapper for Z3 functions
81+
82+
def __add__(self, other: Any) -> BoolRef:
83+
"""Create the Z3 expression `self + other`.
8284
8385
:param any other: The object to compare against.
84-
:return: The Z3 expression `other >= self`.
86+
:return: The Z3 expression `self + other`.
8587
:rtype: BoolRef
8688
"""
87-
return self.z3.__ge__(_coerce(other))
89+
return self.z3.__add__(_coerce(other))
8890

89-
def __le__(self, other: Any) -> BoolRef:
90-
"""Create the Z3 expression `other <= self`.
91+
def __ge__(self, other: Any) -> BoolRef:
92+
"""Create the Z3 expression `self >= other`.
9193
9294
:param any other: The object to compare against.
93-
:return: The Z3 expression `other >= self`.
95+
:return: The Z3 expression `self >= other`.
9496
:rtype: BoolRef
9597
"""
96-
return self.z3.__le__(_coerce(other))
98+
return self.z3.__ge__(_coerce(other))
9799

98100
def __gt__(self, other: Any) -> BoolRef:
99-
"""Create the Z3 expression `other > self`.
101+
"""Create the Z3 expression `self > other`.
100102
101103
:param any other: The object to compare against.
102-
:return: The Z3 expression `other >= self`.
104+
:return: The Z3 expression `self > other`.
103105
:rtype: BoolRef
104106
"""
105107
return self.z3.__gt__(_coerce(other))
106108

109+
def __le__(self, other: Any) -> BoolRef:
110+
"""Create the Z3 expression `self <= other`.
111+
112+
:param any other: The object to compare against.
113+
:return: The Z3 expression `self <= other`.
114+
:rtype: BoolRef
115+
"""
116+
return self.z3.__le__(_coerce(other))
117+
107118
def __lt__(self, other: Any) -> BoolRef:
108-
"""Create the Z3 expression `other < self`.
119+
"""Create the Z3 expression `self < other`.
109120
110121
:param any other: The object to compare against.
111-
:return: The Z3 expression `other >= self`.
122+
:return: The Z3 expression `self < other`.
112123
:rtype: BoolRef
113124
"""
114125
return self.z3.__lt__(_coerce(other))
115126

127+
def __mod__(self, other: Any) -> BoolRef:
128+
"""Create the Z3 expression `self % other`.
129+
130+
:param any other: The object to compare against.
131+
:return: The Z3 expression `self % other`.
132+
:rtype: BoolRef
133+
"""
134+
return self.z3.__mod__(_coerce(other))
135+
116136
def __mul__(self, other: Any) -> BoolRef:
117-
"""Create the Z3 expression `other * self`.
137+
"""Create the Z3 expression `self * other`.
118138
119139
:param any other: The object to compare against.
120-
:return: The Z3 expression `other >= self`.
140+
:return: The Z3 expression `self * other`.
121141
:rtype: BoolRef
122142
"""
123143
return self.z3.__mul__(_coerce(other))
124144

125-
def __sub__(self, other: Any) -> BoolRef:
126-
"""Create the Z3 expression `other * self`.
145+
def __ne__(self, other: Any) -> BoolRef:
146+
"""Create the Z3 expression `self != other`.
127147
128148
:param any other: The object to compare against.
129-
:return: The Z3 expression `other >= self`.
149+
:return: The Z3 expression `self != other`.
130150
:rtype: BoolRef
131151
"""
132-
return self.z3.__sub__(_coerce(other))
152+
return self.z3.__ne__(_coerce(other))
133153

134-
def __add__(self, other: Any) -> BoolRef:
135-
"""Create the Z3 expression `other * self`.
154+
def __neg__(self) -> BoolRef:
155+
"""Create the Z3 expression `-self`.
136156
137157
:param any other: The object to compare against.
138-
:return: The Z3 expression `other >= self`.
158+
:return: The Z3 expression `-self`.
139159
:rtype: BoolRef
140160
"""
141-
return self.z3.__add__(_coerce(other))
161+
return self.z3.__neg__()
162+
163+
def __pow__(self, other: Any) -> BoolRef:
164+
"""Create the Z3 expression `self ** other`.
165+
166+
:param any other: The object to compare against.
167+
:return: The Z3 expression `self ** other`.
168+
:rtype: BoolRef
169+
"""
170+
return self.z3.__pow__(_coerce(other))
171+
172+
def __sub__(self, other: Any) -> BoolRef:
173+
"""Create the Z3 expression `self - other`.
174+
175+
:param any other: The object to compare against.
176+
:return: The Z3 expression `self - other`.
177+
:rtype: BoolRef
178+
"""
179+
return self.z3.__sub__(_coerce(other))
142180

143181
def __truediv__(self, other: Any) -> BoolRef:
144-
"""Create the Z3 expression `other * self`.
182+
"""Create the Z3 expression `self / other`.
145183
146184
:param any other: The object to compare against.
147-
:return: The Z3 expression `other >= self`.
185+
:return: The Z3 expression `self / other`.
148186
:rtype: BoolRef
149187
"""
150188
return self.z3.__truediv__(_coerce(other))
151189

190+
# End thin wrapper
191+
152192
def cast(self, val: Any) -> T:
153193
"""Cast the supplied value to the datatype T of the variable.
154194

0 commit comments

Comments
 (0)