Skip to content

Commit d042d4f

Browse files
authored
Merge branch 'main' into z3_operations
2 parents 203be8a + a187fec commit d042d4f

File tree

9 files changed

+677
-46
lines changed

9 files changed

+677
-46
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,10 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
112112
executions.
113113
"""
114114
control_results_df = self.run_system_with_input_configuration(self.control_input_configuration)
115+
control_results_df.rename(lambda x: f"control_{x}", inplace=True)
115116
treatment_results_df = self.run_system_with_input_configuration(self.treatment_input_configuration)
116-
results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=True)
117+
treatment_results_df.rename(lambda x: f"treatment_{x}", inplace=True)
118+
results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=False)
117119
return results_df
118120

119121
@abstractmethod
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""
2+
This module contains the ShouldCause and ShouldNotCause metamorphic relations as
3+
defined in our ICST paper [https://eprints.whiterose.ac.uk/195317/].
4+
"""
5+
6+
from dataclasses import dataclass
7+
from abc import abstractmethod
8+
from typing import Iterable
9+
from itertools import combinations
10+
import numpy as np
11+
import pandas as pd
12+
import networkx as nx
13+
14+
from causal_testing.specification.causal_specification import CausalDAG, Node
15+
from causal_testing.data_collection.data_collector import ExperimentalDataCollector
16+
17+
18+
@dataclass(order=True)
19+
class MetamorphicRelation:
20+
"""Class representing a metamorphic relation."""
21+
22+
treatment_var: Node
23+
output_var: Node
24+
adjustment_vars: Iterable[Node]
25+
dag: CausalDAG
26+
tests: Iterable = None
27+
28+
def generate_follow_up(self, n_tests: int, min_val: float, max_val: float, seed: int = 0):
29+
"""Generate numerical follow-up input configurations."""
30+
np.random.seed(seed)
31+
32+
# Get set of variables to change, excluding the treatment itself
33+
variables_to_change = {node for node in self.dag.graph.nodes if self.dag.graph.in_degree(node) == 0}
34+
if self.adjustment_vars:
35+
variables_to_change |= set(self.adjustment_vars)
36+
if self.treatment_var in variables_to_change:
37+
variables_to_change.remove(self.treatment_var)
38+
39+
# Assign random numerical values to the variables to change
40+
test_inputs = pd.DataFrame(
41+
np.random.randint(min_val, max_val, size=(n_tests, len(variables_to_change))),
42+
columns=sorted(variables_to_change),
43+
)
44+
45+
# Enumerate the possible source, follow-up pairs for the treatment
46+
candidate_source_follow_up_pairs = np.array(list(combinations(range(int(min_val), int(max_val + 1)), 2)))
47+
48+
# Sample without replacement from the possible source, follow-up pairs
49+
sampled_source_follow_up_indices = np.random.choice(
50+
candidate_source_follow_up_pairs.shape[0], n_tests, replace=False
51+
)
52+
53+
follow_up_input = f"{self.treatment_var}'"
54+
source_follow_up_test_inputs = pd.DataFrame(
55+
candidate_source_follow_up_pairs[sampled_source_follow_up_indices],
56+
columns=sorted([self.treatment_var] + [follow_up_input]),
57+
)
58+
self.tests = [
59+
MetamorphicTest(
60+
source_inputs,
61+
follow_up_inputs,
62+
other_inputs,
63+
self.output_var,
64+
str(self),
65+
)
66+
for source_inputs, follow_up_inputs, other_inputs in zip(
67+
source_follow_up_test_inputs[[self.treatment_var]].to_dict(orient="records"),
68+
source_follow_up_test_inputs[[follow_up_input]]
69+
.rename(columns={follow_up_input: self.treatment_var})
70+
.to_dict(orient="records"),
71+
test_inputs.to_dict(orient="records")
72+
if not test_inputs.empty
73+
else [{}] * len(source_follow_up_test_inputs),
74+
)
75+
]
76+
77+
def execute_tests(self, data_collector: ExperimentalDataCollector):
78+
"""Execute the generated list of metamorphic tests, returning a dictionary of tests that pass and fail.
79+
80+
:param data_collector: An experimental data collector for the system-under-test.
81+
"""
82+
test_results = {"pass": [], "fail": []}
83+
for metamorphic_test in self.tests:
84+
# Update the control and treatment configuration to take generated values for source and follow-up tests
85+
control_input_config = metamorphic_test.source_inputs | metamorphic_test.other_inputs
86+
treatment_input_config = metamorphic_test.follow_up_inputs | metamorphic_test.other_inputs
87+
data_collector.control_input_configuration = control_input_config
88+
data_collector.treatment_input_configuration = treatment_input_config
89+
metamorphic_test_results_df = data_collector.collect_data()
90+
91+
# Apply assertion to control and treatment outputs
92+
control_output = metamorphic_test_results_df.loc["control_0"][metamorphic_test.output]
93+
treatment_output = metamorphic_test_results_df.loc["treatment_0"][metamorphic_test.output]
94+
95+
if not self.assertion(control_output, treatment_output):
96+
test_results["fail"].append(metamorphic_test)
97+
else:
98+
test_results["pass"].append(metamorphic_test)
99+
return test_results
100+
101+
@abstractmethod
102+
def assertion(self, source_output, follow_up_output):
103+
"""An assertion that should be applied to an individual metamorphic test run."""
104+
105+
@abstractmethod
106+
def test_oracle(self, test_results):
107+
"""A test oracle that assert whether the MR holds or not based on ALL test results.
108+
109+
This method must raise an assertion, not return a bool."""
110+
111+
def __eq__(self, other):
112+
same_type = self.__class__ == other.__class__
113+
same_treatment = self.treatment_var == other.treatment_var
114+
same_output = self.output_var == other.output_var
115+
same_adjustment_set = set(self.adjustment_vars) == set(other.adjustment_vars)
116+
return same_type and same_treatment and same_output and same_adjustment_set
117+
118+
119+
class ShouldCause(MetamorphicRelation):
120+
"""Class representing a should cause metamorphic relation."""
121+
122+
def assertion(self, source_output, follow_up_output):
123+
"""If there is a causal effect, the outputs should not be the same."""
124+
return source_output != follow_up_output
125+
126+
def test_oracle(self, test_results):
127+
"""A single passing test is sufficient to show presence of a causal effect."""
128+
assert len(test_results["fail"]) < len(
129+
self.tests
130+
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
131+
132+
def __str__(self):
133+
formatted_str = f"{self.treatment_var} --> {self.output_var}"
134+
if self.adjustment_vars:
135+
formatted_str += f" | {self.adjustment_vars}"
136+
return formatted_str
137+
138+
139+
class ShouldNotCause(MetamorphicRelation):
140+
"""Class representing a should cause metamorphic relation."""
141+
142+
def assertion(self, source_output, follow_up_output):
143+
"""If there is a causal effect, the outputs should not be the same."""
144+
return source_output == follow_up_output
145+
146+
def test_oracle(self, test_results):
147+
"""A single passing test is sufficient to show presence of a causal effect."""
148+
assert (
149+
len(test_results["fail"]) == 0
150+
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
151+
152+
def __str__(self):
153+
formatted_str = f"{self.treatment_var} _||_ {self.output_var}"
154+
if self.adjustment_vars:
155+
formatted_str += f" | {self.adjustment_vars}"
156+
return formatted_str
157+
158+
159+
@dataclass(order=True)
160+
class MetamorphicTest:
161+
"""Class representing a metamorphic test case."""
162+
163+
source_inputs: dict
164+
follow_up_inputs: dict
165+
other_inputs: dict
166+
output: str
167+
relation: str
168+
169+
def __str__(self):
170+
return (
171+
f"Source inputs: {self.source_inputs}\n"
172+
f"Follow-up inputs: {self.follow_up_inputs}\n"
173+
f"Other inputs: {self.other_inputs}\n"
174+
f"Output: {self.output}"
175+
f"Metamorphic Relation: {self.relation}"
176+
)
177+
178+
179+
def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:
180+
"""Construct a list of metamorphic relations implied by the Causal DAG.
181+
182+
This list of metamorphic relations contains a ShouldCause relation for every edge, and a ShouldNotCause
183+
relation for every (minimal) conditional independence relation implied by the structure of the DAG.
184+
185+
:param CausalDAG dag: Causal DAG from which the metamorphic relations will be generated.
186+
:return: A list containing ShouldCause and ShouldNotCause metamorphic relations.
187+
"""
188+
metamorphic_relations = []
189+
for node_pair in combinations(dag.graph.nodes, 2):
190+
(u, v) = node_pair
191+
192+
# Create a ShouldNotCause relation for each pair of nodes that are not directly connected
193+
if ((u, v) not in dag.graph.edges) and ((v, u) not in dag.graph.edges):
194+
195+
# Case 1: U --> ... --> V
196+
if u in nx.ancestors(dag.graph, v):
197+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
198+
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
199+
200+
# Case 2: V --> ... --> U
201+
elif v in nx.ancestors(dag.graph, u):
202+
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
203+
metamorphic_relations.append(ShouldNotCause(v, u, adj_set, dag))
204+
205+
# Case 3: V _||_ U (No directed walk from V to U but there may be a back-door path e.g. U <-- Z --> V).
206+
# Only make one MR since V _||_ U == U _||_ V
207+
else:
208+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
209+
metamorphic_relations.append(ShouldNotCause(u, v, adj_set, dag))
210+
211+
# Create a ShouldCause relation for each edge (u, v) or (v, u)
212+
elif (u, v) in dag.graph.edges:
213+
adj_set = list(dag.direct_effect_adjustment_sets([u], [v])[0])
214+
metamorphic_relations.append(ShouldCause(u, v, adj_set, dag))
215+
else:
216+
adj_set = list(dag.direct_effect_adjustment_sets([v], [u])[0])
217+
metamorphic_relations.append(ShouldCause(v, u, adj_set, dag))
218+
219+
return metamorphic_relations

causal_testing/testing/causal_test_outcome.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,29 @@ def __str__(self) -> str:
2222
return type(self).__name__
2323

2424

25-
class ExactValue(CausalTestOutcome):
25+
class SomeEffect(CausalTestOutcome):
26+
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
27+
28+
def apply(self, res: CausalTestResult) -> bool:
29+
if res.test_value.type == "ate":
30+
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
31+
if res.test_value.type == "risk_ratio":
32+
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
33+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
34+
35+
36+
class NoEffect(CausalTestOutcome):
37+
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
38+
39+
def apply(self, res: CausalTestResult) -> bool:
40+
if res.test_value.type == "ate":
41+
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
42+
if res.test_value.type == "risk_ratio":
43+
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
44+
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
45+
46+
47+
class ExactValue(SomeEffect):
2648
"""An extension of TestOutcome representing that the expected causal effect should be a specific value."""
2749

2850
def __init__(self, value: float, tolerance: float = None):
@@ -33,6 +55,8 @@ def __init__(self, value: float, tolerance: float = None):
3355
self.tolerance = tolerance
3456

3557
def apply(self, res: CausalTestResult) -> bool:
58+
if res.ci_valid():
59+
return super().apply(res) and np.isclose(res.test_value.value, self.value, atol=self.tolerance)
3660
return np.isclose(res.test_value.value, self.value, atol=self.tolerance)
3761

3862
def __str__(self):
@@ -43,6 +67,8 @@ class Positive(CausalTestOutcome):
4367
"""An extension of TestOutcome representing that the expected causal effect should be positive."""
4468

4569
def apply(self, res: CausalTestResult) -> bool:
70+
if res.ci_valid() and not super().apply(res):
71+
return False
4672
if res.test_value.type == "ate":
4773
return res.test_value.value > 0
4874
if res.test_value.type == "risk_ratio":
@@ -54,36 +80,10 @@ class Negative(CausalTestOutcome):
5480
"""An extension of TestOutcome representing that the expected causal effect should be negative."""
5581

5682
def apply(self, res: CausalTestResult) -> bool:
83+
if res.ci_valid() and not super().apply(res):
84+
return False
5785
if res.test_value.type == "ate":
5886
return res.test_value.value < 0
5987
if res.test_value.type == "risk_ratio":
6088
return res.test_value.value < 1
6189
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
62-
63-
64-
class SomeEffect(CausalTestOutcome):
65-
"""An extension of TestOutcome representing that the expected causal effect should not be zero."""
66-
67-
def apply(self, res: CausalTestResult) -> bool:
68-
if res.test_value.type == "ate":
69-
return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
70-
if res.test_value.type == "risk_ratio":
71-
return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
72-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
73-
74-
def __str__(self):
75-
return "Changed"
76-
77-
78-
class NoEffect(CausalTestOutcome):
79-
"""An extension of TestOutcome representing that the expected causal effect should be zero."""
80-
81-
def apply(self, res: CausalTestResult) -> bool:
82-
if res.test_value.type == "ate":
83-
return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < 1e-10)
84-
if res.test_value.type == "risk_ratio":
85-
return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=1e-10)
86-
raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
87-
88-
def __str__(self):
89-
return "Unchanged"

causal_testing/testing/causal_test_result.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ def ci_high(self):
8585
return max(self.confidence_intervals)
8686
return None
8787

88+
def ci_valid(self) -> bool:
89+
"""Return whether or not the result has valid confidence invervals"""
90+
return self.ci_low() and self.ci_high()
91+
8892
def summary(self):
8993
"""Summarise the causal test result as an intuitive sentence."""
9094
print(

causal_testing/testing/estimators.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:
355355
:return: The average treatment effect and the 95% Wald confidence intervals.
356356
"""
357357
model = self._run_linear_regression()
358+
self.model = model
358359

359360
# Create an empty individual for the control and treated
360361
individuals = pd.DataFrame(1, index=["control", "treated"], columns=model.params.index)

causal_testing/testing/validation.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""This module contains the CausalValidator class for performing Quantitive Bias Analysis techniques"""
2+
import math
3+
import numpy as np
4+
from scipy.stats import t
5+
from statsmodels.regression.linear_model import RegressionResultsWrapper
6+
7+
8+
class CausalValidator:
9+
"""A suite of validation tools to perform Quantitive Bias Analysis to back up causal claims"""
10+
11+
def estimate_robustness(self, model: RegressionResultsWrapper, q=1, alpha=1):
12+
"""Calculate the robustness of a linear regression model. This allow
13+
the user to identify how large an unidentified confounding variable
14+
would need to be to nullify the causal relationship under test."""
15+
16+
dof = model.df_resid
17+
t_values = model.tvalues
18+
19+
fq = q * abs(t_values / math.sqrt(dof))
20+
f_crit = abs(t.ppf(alpha / 2, dof - 1)) / math.sqrt(dof - 1)
21+
fqa = fq - f_crit
22+
23+
rv = 0.5 * (np.sqrt(fqa**4 + (4 * fqa**2)) - fqa**2)
24+
25+
return rv
26+
27+
def estimate_e_value(self, risk_ratio: float) -> float:
28+
"""Calculate the E value from a risk ratio. This allow
29+
the user to identify how large a risk an unidentified confounding
30+
variable would need to be to nullify the causal relationship
31+
under test."""
32+
33+
if risk_ratio >= 1:
34+
return risk_ratio + math.sqrt(risk_ratio * (risk_ratio - 1))
35+
36+
risk_ratio_prime = 1 / risk_ratio
37+
return risk_ratio_prime + math.sqrt(risk_ratio_prime * (risk_ratio_prime - 1))
38+
39+
def estimate_e_value_using_ci(self, risk_ratio: float, confidence_intervals: tuple[float, float]) -> float:
40+
"""Calculate the E value from a risk ratio and it's confidence intervals.
41+
This allow the user to identify how large a risk an unidentified
42+
confounding variable would need to be to nullify the causal relationship
43+
under test."""
44+
45+
if risk_ratio >= 1:
46+
lower_limit = confidence_intervals[0]
47+
e = 1
48+
if lower_limit > 1:
49+
e = lower_limit + math.sqrt(lower_limit * (lower_limit - 1))
50+
51+
return e
52+
53+
upper_limit = confidence_intervals[1]
54+
e = 1
55+
if upper_limit < 1:
56+
upper_limit_prime = 1 / upper_limit
57+
e = upper_limit_prime + math.sqrt(upper_limit_prime * (upper_limit_prime - 1))
58+
59+
return e

0 commit comments

Comments
 (0)