Skip to content

Commit 359835c

Browse files
committed
Blackened tests
1 parent ec4de83 commit 359835c

File tree

10 files changed

+418
-336
lines changed

10 files changed

+418
-336
lines changed

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,15 @@
99

1010

1111
class TestObservationalDataCollector(unittest.TestCase):
12-
1312
def setUp(self) -> None:
1413
temp_dir_path = create_temp_dir_if_non_existent()
1514
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
1615
self.observational_df_path = os.path.join(temp_dir_path, "observational_data.csv")
1716
# Y = 3*X1 + X2*X3 + 10
1817
self.observational_df = pd.DataFrame({"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40]})
1918
self.observational_df["Y"] = self.observational_df.apply(
20-
lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10, axis=1)
19+
lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10, axis=1
20+
)
2121
self.observational_df.to_csv(self.observational_df_path)
2222
self.X1 = Input("X1", int, uniform(1, 4))
2323
self.X2 = Input("X2", int, rv_discrete(values=([7], [1])))
@@ -45,12 +45,13 @@ def test_data_constraints(self):
4545

4646
def test_meta_population(self):
4747
def populate_m(data):
48-
data['M'] = data['X1'] * 2
48+
data["M"] = data["X1"] * 2
49+
4950
meta = Meta("M", int, populate_m)
5051
scenario = Scenario({self.X1, meta})
5152
observational_data_collector = ObservationalDataCollector(scenario, self.observational_df_path)
5253
data = observational_data_collector.collect_data()
53-
assert all((m == 2*x1 for x1, m in zip(data['X1'], data['M'])))
54+
assert all((m == 2 * x1 for x1, m in zip(data["X1"], data["M"])))
5455

5556
def tearDown(self) -> None:
5657
remove_temp_dir_if_existent()

tests/generation_tests/test_abstract_test_case.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,21 @@
88
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
99
from causal_testing.testing.causal_test_outcome import Positive
1010

11+
1112
class TestAbstractTestCase(unittest.TestCase):
1213
"""
13-
Class to test abstract test cases.
14+
Class to test abstract test cases.
1415
"""
16+
1517
def setUp(self) -> None:
1618
temp_dir_path = create_temp_dir_if_non_existent()
1719
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
1820
self.observational_df_path = os.path.join(temp_dir_path, "observational_data.csv")
1921
# Y = 3*X1 + X2*X3 + 10
2022
self.observational_df = pd.DataFrame({"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40]})
21-
self.observational_df["Y"] = self.observational_df.apply(lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10,
22-
axis=1)
23+
self.observational_df["Y"] = self.observational_df.apply(
24+
lambda row: (3 * row.X1) + (row.X2 * row.X3) + 10, axis=1
25+
)
2326
self.observational_df.to_csv(self.observational_df_path)
2427
self.X1 = Input("X1", float, uniform(1, 4))
2528
self.X2 = Input("X2", int, rv_discrete(values=([7], [1])))
@@ -51,9 +54,9 @@ def test_str(self):
5154
expected_causal_effect={self.Y: Positive()},
5255
effect_modifiers=None,
5356
)
54-
assert str(abstract) == \
55-
"When we apply intervention {X1' > X1}, the effect on Output: Y::int should be Positive", \
56-
f"Unexpected string {str(abstract)}"
57+
assert (
58+
str(abstract) == "When we apply intervention {X1' > X1}, the effect on Output: Y::int should be Positive"
59+
), f"Unexpected string {str(abstract)}"
5760

5861
def test_datapath(self):
5962
scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
@@ -109,7 +112,6 @@ def test_generate_concrete_test_cases_rct(self):
109112
assert len(concrete_tests) == 2, "Expected 2 concrete tests"
110113
assert len(runs) == 4, "Expected 4 runs"
111114

112-
113115
def test_infeasible_constraints(self):
114116
scenario = Scenario({self.X1, self.X2, self.X3, self.X4}, [self.X1.z3 > 2])
115117
scenario.setup_treatment_variables()
@@ -125,10 +127,9 @@ def test_infeasible_constraints(self):
125127

126128
with self.assertWarns(Warning):
127129
concrete_tests, runs = abstract.generate_concrete_tests(4, rct=True, target_ks_score=0.1, hard_max=HARD_MAX)
128-
self.assertTrue(all((x > 2 for x in runs['X1'])))
130+
self.assertTrue(all((x > 2 for x in runs["X1"])))
129131
self.assertEqual(len(concrete_tests), HARD_MAX * NUM_STRATA)
130132

131-
132133
def test_feasible_constraints(self):
133134
scenario = Scenario({self.X1, self.X2, self.X3, self.X4})
134135
scenario.setup_treatment_variables()
@@ -142,7 +143,6 @@ def test_feasible_constraints(self):
142143
concrete_tests, _ = abstract.generate_concrete_tests(4, rct=True, target_ks_score=0.1, hard_max=1000)
143144
assert len(concrete_tests) < 1000
144145

145-
146146
def tearDown(self) -> None:
147147
remove_temp_dir_if_existent()
148148

tests/json_front_tests/test_json_class.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,10 @@ def test_generate_tests_from_json(self):
9393
self.json_class.test_plan = example_test
9494
effects = {"NoEffect": NoEffect()}
9595
mutates = {
96-
"Increase": lambda x: self.json_class.modelling_scenario.treatment_variables[x].z3 >
97-
self.json_class.modelling_scenario.variables[x].z3
98-
}
99-
estimators = {
100-
"LinearRegressionEstimator": LinearRegressionEstimator
96+
"Increase": lambda x: self.json_class.modelling_scenario.treatment_variables[x].z3
97+
> self.json_class.modelling_scenario.variables[x].z3
10198
}
99+
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
102100

103101
with self.assertLogs() as captured:
104102
self.json_class.generate_tests(effects, mutates, estimators, False)
@@ -108,9 +106,8 @@ def test_generate_tests_from_json(self):
108106

109107
def tearDown(self) -> None:
110108
pass
111-
#remove_temp_dir_if_existent()
109+
# remove_temp_dir_if_existent()
112110

113111

114112
def populate_example(*args, **kwargs):
115113
pass
116-

tests/specification_tests/test_causal_dag.py

Lines changed: 29 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,12 @@ def test_valid_causal_dag(self):
5050
"""Test whether the Causal DAG is valid."""
5151
causal_dag = CausalDAG(self.dag_dot_path)
5252
print(causal_dag)
53-
assert list(causal_dag.graph.nodes) == ["A", "B", "C", "D"] and list(
54-
causal_dag.graph.edges
55-
) == [("A", "B"), ("B", "C"), ("D", "A"), ("D", "C")]
53+
assert list(causal_dag.graph.nodes) == ["A", "B", "C", "D"] and list(causal_dag.graph.edges) == [
54+
("A", "B"),
55+
("B", "C"),
56+
("D", "A"),
57+
("D", "C"),
58+
]
5659

5760
def test_invalid_causal_dag(self):
5861
"""Test whether a cycle-containing directed graph is an invalid causal DAG."""
@@ -96,9 +99,7 @@ class TestDAGDirectEffectIdentification(unittest.TestCase):
9699
def setUp(self) -> None:
97100
temp_dir_path = create_temp_dir_if_non_existent()
98101
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
99-
dag_dot = (
100-
"""digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
101-
)
102+
dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
102103
f = open(self.dag_dot_path, "w")
103104
f.write(dag_dot)
104105
f.close()
@@ -122,9 +123,7 @@ class TestDAGIdentification(unittest.TestCase):
122123
def setUp(self) -> None:
123124
temp_dir_path = create_temp_dir_if_non_existent()
124125
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
125-
dag_dot = (
126-
"""digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
127-
)
126+
dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
128127
f = open(self.dag_dot_path, "w")
129128
f.write(dag_dot)
130129
f.close()
@@ -137,13 +136,10 @@ def test_get_indirect_graph(self):
137136
self.assertEqual(list(indirect_graph.graph.edges), original_edges)
138137
self.assertEqual(indirect_graph.graph.nodes, causal_dag.graph.nodes)
139138

140-
141139
def test_proper_backdoor_graph(self):
142140
"""Test whether converting a Causal DAG to a proper back-door graph works correctly."""
143141
causal_dag = CausalDAG(self.dag_dot_path)
144-
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(
145-
["X1", "X2"], ["Y"]
146-
)
142+
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(["X1", "X2"], ["Y"])
147143
self.assertEqual(
148144
list(proper_backdoor_graph.graph.edges),
149145
[
@@ -163,11 +159,7 @@ def test_constructive_backdoor_criterion_should_hold(self):
163159
causal_dag = CausalDAG(self.dag_dot_path)
164160
xs, ys, zs = ["X1", "X2"], ["Y"], ["Z"]
165161
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
166-
self.assertTrue(
167-
causal_dag.constructive_backdoor_criterion(
168-
proper_backdoor_graph, xs, ys, zs
169-
)
170-
)
162+
self.assertTrue(causal_dag.constructive_backdoor_criterion(proper_backdoor_graph, xs, ys, zs))
171163

172164
def test_constructive_backdoor_criterion_should_not_hold_not_d_separator_in_proper_backdoor_graph(
173165
self,
@@ -176,11 +168,7 @@ def test_constructive_backdoor_criterion_should_not_hold_not_d_separator_in_prop
176168
causal_dag = CausalDAG(self.dag_dot_path)
177169
xs, ys, zs = ["X1", "X2"], ["Y"], ["V"]
178170
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
179-
self.assertFalse(
180-
causal_dag.constructive_backdoor_criterion(
181-
proper_backdoor_graph, xs, ys, zs
182-
)
183-
)
171+
self.assertFalse(causal_dag.constructive_backdoor_criterion(proper_backdoor_graph, xs, ys, zs))
184172

185173
def test_constructive_backdoor_criterion_should_not_hold_descendent_of_proper_causal_path(
186174
self,
@@ -190,11 +178,7 @@ def test_constructive_backdoor_criterion_should_not_hold_descendent_of_proper_ca
190178
causal_dag = CausalDAG(self.dag_dot_path)
191179
xs, ys, zs = ["X1", "X2"], ["Y"], ["D1"]
192180
proper_backdoor_graph = causal_dag.get_proper_backdoor_graph(xs, ys)
193-
self.assertFalse(
194-
causal_dag.constructive_backdoor_criterion(
195-
proper_backdoor_graph, xs, ys, zs
196-
)
197-
)
181+
self.assertFalse(causal_dag.constructive_backdoor_criterion(proper_backdoor_graph, xs, ys, zs))
198182

199183
def test_is_min_adjustment_for_min_adjustment(self):
200184
"""Test whether is_min_adjustment can correctly test whether the minimum adjustment set is minimal."""
@@ -262,9 +246,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
262246
)
263247
xs, ys = ["X1", "X2"], ["Y"]
264248
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
265-
set_of_adjustment_sets = set(
266-
frozenset(min_separator) for min_separator in adjustment_sets
267-
)
249+
set_of_adjustment_sets = set(frozenset(min_separator) for min_separator in adjustment_sets)
268250
self.assertEqual(
269251
{frozenset({"Z1"}), frozenset({"Z2"}), frozenset({"Z3"})},
270252
set_of_adjustment_sets,
@@ -291,9 +273,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
291273
)
292274
xs, ys = ["X1", "X2"], ["Y"]
293275
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
294-
set_of_adjustment_sets = set(
295-
frozenset(min_separator) for min_separator in adjustment_sets
296-
)
276+
set_of_adjustment_sets = set(frozenset(min_separator) for min_separator in adjustment_sets)
297277
self.assertEqual(
298278
{frozenset({"Z1", "Z4"}), frozenset({"Z2", "Z4"}), frozenset({"Z3", "Z4"})},
299279
set_of_adjustment_sets,
@@ -304,20 +284,20 @@ def test_dag_with_non_character_nodes(self):
304284
causal_dag = CausalDAG()
305285
causal_dag.graph.add_edges_from(
306286
[
307-
('va', 'ba'),
308-
('ba', 'ia'),
309-
('ba', 'da'),
310-
('ba', 'ra'),
311-
('la', 'va'),
312-
('la', 'aa'),
313-
('aa', 'ia'),
314-
('aa', 'da'),
315-
('aa', 'ra'),
287+
("va", "ba"),
288+
("ba", "ia"),
289+
("ba", "da"),
290+
("ba", "ra"),
291+
("la", "va"),
292+
("la", "aa"),
293+
("aa", "ia"),
294+
("aa", "da"),
295+
("aa", "ra"),
316296
]
317297
)
318-
xs, ys = ['ba'], ['da']
298+
xs, ys = ["ba"], ["da"]
319299
adjustment_sets = causal_dag.enumerate_minimal_adjustment_sets(xs, ys)
320-
self.assertEqual(adjustment_sets, [{'aa'}, {'la'}, {'va'}])
300+
self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}])
321301

322302
def tearDown(self) -> None:
323303
remove_temp_dir_if_existent()
@@ -385,9 +365,7 @@ class TestUndirectedGraphAlgorithms(unittest.TestCase):
385365

386366
def setUp(self) -> None:
387367
self.graph = nx.Graph()
388-
self.graph.add_edges_from(
389-
[("a", 2), ("a", 3), (2, 4), (3, 5), (3, 4), (4, "b"), (5, "b")]
390-
)
368+
self.graph.add_edges_from([("a", 2), ("a", 3), (2, 4), (3, 5), (3, 4), (4, "b"), (5, "b")])
391369
self.treatment_node = "a"
392370
self.outcome_node = "b"
393371
self.treatment_node_set = {"a"}
@@ -396,9 +374,7 @@ def setUp(self) -> None:
396374

397375
def test_close_separator(self):
398376
"""Test whether close_separator correctly identifies the close separator of {2,3} in the undirected graph."""
399-
result = close_separator(
400-
self.graph, self.treatment_node, self.outcome_node, self.treatment_node_set
401-
)
377+
result = close_separator(self.graph, self.treatment_node, self.outcome_node, self.treatment_node_set)
402378
self.assertEqual({2, 3}, result)
403379

404380
def test_list_all_min_sep(self):
@@ -414,12 +390,8 @@ def test_list_all_min_sep(self):
414390
)
415391

416392
# Convert list of sets to set of frozen sets for comparison
417-
min_separators = set(
418-
frozenset(min_separator) for min_separator in min_separators
419-
)
420-
self.assertEqual(
421-
{frozenset({2, 3}), frozenset({3, 4}), frozenset({4, 5})}, min_separators
422-
)
393+
min_separators = set(frozenset(min_separator) for min_separator in min_separators)
394+
self.assertEqual({frozenset({2, 3}), frozenset({3, 4}), frozenset({4, 5})}, min_separators)
423395

424396
def tearDown(self) -> None:
425397
remove_temp_dir_if_existent()

0 commit comments

Comments
 (0)