Skip to content

Commit 393f6d9

Browse files
author
AndrewC19
committed
Formatted with black
1 parent 044fe26 commit 393f6d9

File tree

5 files changed

+248
-263
lines changed

5 files changed

+248
-263
lines changed

causal_testing/data_collection/data_collector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ def collect_data(self, **kwargs) -> pd.DataFrame:
106106
executions.
107107
"""
108108
control_results_df = self.run_system_with_input_configuration(self.control_input_configuration)
109-
control_results_df.rename('control_{}'.format, inplace=True)
109+
control_results_df.rename("control_{}".format, inplace=True)
110110
treatment_results_df = self.run_system_with_input_configuration(self.treatment_input_configuration)
111-
treatment_results_df.rename('treatment_{}'.format, inplace=True)
111+
treatment_results_df.rename("treatment_{}".format, inplace=True)
112112
results_df = pd.concat([control_results_df, treatment_results_df], ignore_index=False)
113113
return results_df
114114

causal_testing/specification/metamorphic_relation.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -13,50 +13,42 @@
1313
@dataclass(order=True)
1414
class MetamorphicRelation:
1515
"""Class representing a metamorphic relation."""
16+
1617
treatment_var: Node
1718
output_var: Node
1819
adjustment_vars: Iterable[Node]
1920
dag: CausalDAG
2021
tests: Iterable = None
2122

22-
def generate_follow_up(self,
23-
n_tests: int,
24-
min_val: float,
25-
max_val: float,
26-
seed: int = 0):
23+
def generate_follow_up(self, n_tests: int, min_val: float, max_val: float, seed: int = 0):
2724
"""Generate numerical follow-up input configurations."""
2825
np.random.seed(seed)
2926

3027
# Get set of variables to change, excluding the treatment itself
31-
variables_to_change = set([node for node in self.dag.graph.nodes if
32-
self.dag.graph.in_degree(node) == 0])
28+
variables_to_change = set([node for node in self.dag.graph.nodes if self.dag.graph.in_degree(node) == 0])
3329
if self.adjustment_vars:
3430
variables_to_change |= set(self.adjustment_vars)
3531
if self.treatment_var in variables_to_change:
3632
variables_to_change.remove(self.treatment_var)
3733

3834
# Assign random numerical values to the variables to change
3935
test_inputs = pd.DataFrame(
40-
np.random.randint(min_val, max_val,
41-
size=(n_tests, len(variables_to_change))
42-
),
43-
columns=sorted(variables_to_change)
36+
np.random.randint(min_val, max_val, size=(n_tests, len(variables_to_change))),
37+
columns=sorted(variables_to_change),
4438
)
4539

4640
# Enumerate the possible source, follow-up pairs for the treatment
47-
candidate_source_follow_up_pairs = np.array(
48-
list(combinations(range(int(min_val), int(max_val+1)), 2))
49-
)
41+
candidate_source_follow_up_pairs = np.array(list(combinations(range(int(min_val), int(max_val + 1)), 2)))
5042

5143
# Sample without replacement from the possible source, follow-up pairs
5244
sampled_source_follow_up_indices = np.random.choice(
5345
candidate_source_follow_up_pairs.shape[0], n_tests, replace=False
5446
)
5547

56-
follow_up_input = f"{self.treatment_var}\'"
48+
follow_up_input = f"{self.treatment_var}'"
5749
source_follow_up_test_inputs = pd.DataFrame(
5850
candidate_source_follow_up_pairs[sampled_source_follow_up_indices],
59-
columns=sorted([self.treatment_var] + [follow_up_input])
51+
columns=sorted([self.treatment_var] + [follow_up_input]),
6052
)
6153
source_test_inputs = source_follow_up_test_inputs[[self.treatment_var]]
6254
follow_up_test_inputs = source_follow_up_test_inputs[[follow_up_input]]
@@ -69,12 +61,13 @@ def generate_follow_up(self,
6961
other_test_inputs_record = [{}] * len(source_test_inputs)
7062
metamorphic_tests = []
7163
for i in range(len(source_test_inputs_record)):
72-
metamorphic_test = MetamorphicTest(source_test_inputs_record[i],
73-
follow_up_test_inputs_record[i],
74-
other_test_inputs_record[i],
75-
self.output_var,
76-
str(self)
77-
)
64+
metamorphic_test = MetamorphicTest(
65+
source_test_inputs_record[i],
66+
follow_up_test_inputs_record[i],
67+
other_test_inputs_record[i],
68+
self.output_var,
69+
str(self),
70+
)
7871
metamorphic_tests.append(metamorphic_test)
7972
self.tests = metamorphic_tests
8073

@@ -124,8 +117,9 @@ def assertion(self, source_output, follow_up_output):
124117

125118
def test_oracle(self, test_results):
126119
"""A single passing test is sufficient to show presence of a causal effect."""
127-
assert len(test_results["fail"]) < len(self.tests),\
128-
f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
120+
assert len(test_results["fail"]) < len(
121+
self.tests
122+
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
129123

130124
def __str__(self):
131125
formatted_str = f"{self.treatment_var} --> {self.output_var}"
@@ -143,8 +137,9 @@ def assertion(self, source_output, follow_up_output):
143137

144138
def test_oracle(self, test_results):
145139
"""A single passing test is sufficient to show presence of a causal effect."""
146-
assert len(test_results["fail"]) == 0,\
147-
f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
140+
assert (
141+
len(test_results["fail"]) == 0
142+
), f"{str(self)}: {len(test_results['fail'])}/{len(self.tests)} tests failed."
148143

149144
def __str__(self):
150145
formatted_str = f"{self.treatment_var} _||_ {self.output_var}"
@@ -156,18 +151,21 @@ def __str__(self):
156151
@dataclass(order=True)
157152
class MetamorphicTest:
158153
"""Class representing a metamorphic test case."""
154+
159155
source_inputs: dict
160156
follow_up_inputs: dict
161157
other_inputs: dict
162158
output: str
163159
relation: str
164160

165161
def __str__(self):
166-
return f"Source inputs: {self.source_inputs}\n" \
167-
f"Follow-up inputs: {self.follow_up_inputs}\n" \
168-
f"Other inputs: {self.other_inputs}\n" \
169-
f"Output: {self.output}" \
170-
f"Metamorphic Relation: {self.relation}"
162+
return (
163+
f"Source inputs: {self.source_inputs}\n"
164+
f"Follow-up inputs: {self.follow_up_inputs}\n"
165+
f"Other inputs: {self.other_inputs}\n"
166+
f"Output: {self.output}"
167+
f"Metamorphic Relation: {self.relation}"
168+
)
171169

172170

173171
def generate_metamorphic_relations(dag: CausalDAG) -> list[MetamorphicRelation]:

tests/specification_tests/test_causal_dag.py

Lines changed: 29 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@ def test_valid_causal_dag(self):
5050
"""Test whether the Causal DAG is valid."""
5151
causal_dag = CausalDAG(self.dag_dot_path)
5252
print(causal_dag)
53-
assert list(causal_dag.graph.nodes) == ["A", "B", "C", "D"] and list(
54-
causal_dag.graph.edges
55-
) == [("A", "B"), ("B", "C"), ("D", "A"), ("D", "C")]
53+
assert list(causal_dag.graph.nodes) == ["A", "B", "C", "D"] and list(causal_dag.graph.edges) == [
54+
("A", "B"),
55+
("B", "C"),
56+
("D", "A"),
57+
("D", "C"),
58+
]
5659

5760
def test_invalid_causal_dag(self):
5861
"""Test whether a cycle-containing directed graph is an invalid causal DAG."""
@@ -96,9 +99,7 @@ class TestDAGDirectEffectIdentification(unittest.TestCase):
9699
def setUp(self) -> None:
97100
temp_dir_path = create_temp_dir_if_non_existent()
98101
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
99-
dag_dot = (
100-
"""digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
101-
)
102+
dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
102103
f = open(self.dag_dot_path, "w")
103104
f.write(dag_dot)
104105
f.close()
@@ -122,9 +123,7 @@ class TestDAGIdentification(unittest.TestCase):
122123
def setUp(self) -> None:
123124
temp_dir_path = create_temp_dir_if_non_existent()
124125
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
125-
dag_dot = (
126-
"""digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
127-
)
126+
dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
128127
f = open(self.dag_dot_path, "w")
129128
f.write(dag_dot)
130129
f.close()
@@ -137,13 +136,10 @@ def test_get_indirect_graph(self):
137136
self.assertEqual(list(indirect_graph.graph.edges), original_edges)
138137
self.assertEqual(indirect_graph.graph.nodes, causal_dag.graph.nodes)
139138

140-
141139
def test_proper_backdoor_graph(self):
142140
"""Test whether converting a Causal DAG to a proper back-door graph works correctly."""
143141
causal_dag = CausalDAG(self.dag_dot_path)
144-
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(
145-
["X1", "X2"], ["Y"]
146-
)
142+
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(["X1", "X2"], ["Y"])
147143
self.assertEqual(
148144
list(proper_backdoor_graph.graph.edges),
149145
[
@@ -163,11 +159,7 @@ def test_constructive_backdoor_criterion_should_hold(self):
163159
causal_dag = CausalDAG(self.dag_dot_path)
164160
xs, ys, zs = ["X1", "X2"], ["Y"], ["Z"]
165161
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
166-
self.assertTrue(
167-
causal_dag.constructive_backdoor_criterion(
168-
proper_backdoor_graph, xs, ys, zs
169-
)
170-
)
162+
self.assertTrue(causal_dag.constructive_backdoor_criterion(proper_backdoor_graph, xs, ys, zs))
171163

172164
def test_constructive_backdoor_criterion_should_not_hold_not_d_separator_in_proper_backdoor_graph(
173165
self,
@@ -176,11 +168,7 @@ def test_constructive_backdoor_criterion_should_not_hold_not_d_separator_in_prop
176168
causal_dag = CausalDAG(self.dag_dot_path)
177169
xs, ys, zs = ["X1", "X2"], ["Y"], ["V"]
178170
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
179-
self.assertFalse(
180-
causal_dag.constructive_backdoor_criterion(
181-
proper_backdoor_graph, xs, ys, zs
182-
)
183-
)
171+
self.assertFalse(causal_dag.constructive_backdoor_criterion(proper_backdoor_graph, xs, ys, zs))
184172

185173
def test_constructive_backdoor_criterion_should_not_hold_descendent_of_proper_causal_path(
186174
self,
@@ -190,11 +178,7 @@ def test_constructive_backdoor_criterion_should_not_hold_descendent_of_proper_ca
190178
causal_dag = CausalDAG(self.dag_dot_path)
191179
xs, ys, zs = ["X1", "X2"], ["Y"], ["D1"]
192180
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
193-
self.assertFalse(
194-
causal_dag.constructive_backdoor_criterion(
195-
proper_backdoor_graph, xs, ys, zs
196-
)
197-
)
181+
self.assertFalse(causal_dag.constructive_backdoor_criterion(proper_backdoor_graph, xs, ys, zs))
198182

199183
def test_is_min_adjustment_for_min_adjustment(self):
200184
"""Test whether is_min_adjustment can correctly test whether the minimum adjustment set is minimal."""
@@ -262,9 +246,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
262246
)
263247
xs, ys = ["X1", "X2"], ["Y"]
264248
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
265-
set_of_adjustment_sets = set(
266-
frozenset(min_separator) for min_separator in adjustment_sets
267-
)
249+
set_of_adjustment_sets = set(frozenset(min_separator) for min_separator in adjustment_sets)
268250
self.assertEqual(
269251
{frozenset({"Z1"}), frozenset({"Z2"}), frozenset({"Z3"})},
270252
set_of_adjustment_sets,
@@ -291,9 +273,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
291273
)
292274
xs, ys = ["X1", "X2"], ["Y"]
293275
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
294-
set_of_adjustment_sets = set(
295-
frozenset(min_separator) for min_separator in adjustment_sets
296-
)
276+
set_of_adjustment_sets = set(frozenset(min_separator) for min_separator in adjustment_sets)
297277
self.assertEqual(
298278
{frozenset({"Z1", "Z4"}), frozenset({"Z2", "Z4"}), frozenset({"Z3", "Z4"})},
299279
set_of_adjustment_sets,
@@ -304,20 +284,20 @@ def test_dag_with_non_character_nodes(self):
304284
causal_dag = CausalDAG()
305285
causal_dag.graph.add_edges_from(
306286
[
307-
('va', 'ba'),
308-
('ba', 'ia'),
309-
('ba', 'da'),
310-
('ba', 'ra'),
311-
('la', 'va'),
312-
('la', 'aa'),
313-
('aa', 'ia'),
314-
('aa', 'da'),
315-
('aa', 'ra'),
287+
("va", "ba"),
288+
("ba", "ia"),
289+
("ba", "da"),
290+
("ba", "ra"),
291+
("la", "va"),
292+
("la", "aa"),
293+
("aa", "ia"),
294+
("aa", "da"),
295+
("aa", "ra"),
316296
]
317297
)
318-
xs, ys = ['ba'], ['da']
298+
xs, ys = ["ba"], ["da"]
319299
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
320-
self.assertEqual(adjustment_sets, [{'aa'}, {'la'}, {'va'}])
300+
self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}])
321301

322302
def tearDown(self) -> None:
323303
remove_temp_dir_if_existent()
@@ -385,9 +365,7 @@ class TestUndirectedGraphAlgorithms(unittest.TestCase):
385365

386366
def setUp(self) -> None:
387367
self.graph = nx.Graph()
388-
self.graph.add_edges_from(
389-
[("a", 2), ("a", 3), (2, 4), (3, 5), (3, 4), (4, "b"), (5, "b")]
390-
)
368+
self.graph.add_edges_from([("a", 2), ("a", 3), (2, 4), (3, 5), (3, 4), (4, "b"), (5, "b")])
391369
self.treatment_node = "a"
392370
self.outcome_node = "b"
393371
self.treatment_node_set = {"a"}
@@ -396,9 +374,7 @@ def setUp(self) -> None:
396374

397375
def test_close_separator(self):
398376
"""Test whether close_separator correctly identifies the close separator of {2,3} in the undirected graph."""
399-
result = close_separator(
400-
self.graph, self.treatment_node, self.outcome_node, self.treatment_node_set
401-
)
377+
result = close_separator(self.graph, self.treatment_node, self.outcome_node, self.treatment_node_set)
402378
self.assertEqual({2, 3}, result)
403379

404380
def test_list_all_min_sep(self):
@@ -414,12 +390,8 @@ def test_list_all_min_sep(self):
414390
)
415391

416392
# Convert list of sets to set of frozen sets for comparison
417-
min_separators = set(
418-
frozenset(min_separator) for min_separator in min_separators
419-
)
420-
self.assertEqual(
421-
{frozenset({2, 3}), frozenset({3, 4}), frozenset({4, 5})}, min_separators
422-
)
393+
min_separators = set(frozenset(min_separator) for min_separator in min_separators)
394+
self.assertEqual({frozenset({2, 3}), frozenset({3, 4}), frozenset({4, 5})}, min_separators)
423395

424396
def tearDown(self) -> None:
425397
remove_temp_dir_if_existent()

0 commit comments

Comments
 (0)