Skip to content

Commit 355a42d

Browse files
Update DAG tests to use TempFile
1 parent f62011b commit 355a42d

File tree

1 file changed

+25
-25
lines changed

1 file changed

+25
-25
lines changed

tests/specification_tests/test_causal_dag.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def tearDown(self) -> None:
3535

3636
class TestIVAssumptions(unittest.TestCase):
3737
def setUp(self) -> None:
38-
temp_dir_path = create_temp_dir_if_non_existent()
39-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
38+
self.temp_dir_path = tempfile.mkdtemp()
39+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
4040
dag_dot = """digraph G { I -> X; X -> Y; U -> X; U -> Y;}"""
4141
f = open(self.dag_dot_path, "w")
4242
f.write(dag_dot)
@@ -63,7 +63,9 @@ def test_common_cause(self):
6363
causal_dag.graph.add_edge("U", "I")
6464
with self.assertRaises(ValueError):
6565
causal_dag.check_iv_assumptions("X", "Y", "I")
66-
66+
67+
def tearDown(self) -> None:
68+
shutil.rmtree(self.temp_dir_path)
6769

6870
class TestCausalDAG(unittest.TestCase):
6971
"""
@@ -74,8 +76,8 @@ class TestCausalDAG(unittest.TestCase):
7476
"""
7577

7678
def setUp(self) -> None:
77-
temp_dir_path = create_temp_dir_if_non_existent()
78-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
79+
self.temp_dir_path = tempfile.mkdtemp()
80+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
7981
dag_dot = """digraph G { A -> B; B -> C; D -> A; D -> C;}"""
8082
f = open(self.dag_dot_path, "w")
8183
f.write(dag_dot)
@@ -107,7 +109,7 @@ def test_to_dot_string(self):
107109
self.assertEqual(causal_dag.to_dot_string(), """digraph G {\nA -> B;\nB -> C;\nD -> A;\nD -> C;\n}""")
108110

109111
def tearDown(self) -> None:
110-
remove_temp_dir_if_existent()
112+
shutil.rmtree(self.temp_dir_path)
111113

112114

113115
class TestCyclicCausalDAG(unittest.TestCase):
@@ -116,8 +118,8 @@ class TestCyclicCausalDAG(unittest.TestCase):
116118
"""
117119

118120
def setUp(self) -> None:
119-
temp_dir_path = create_temp_dir_if_non_existent()
120-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
121+
self.temp_dir_path = tempfile.mkdtemp()
122+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
121123
dag_dot = """digraph G { A -> B; B -> C; D -> A; D -> C; C -> A;}"""
122124
f = open(self.dag_dot_path, "w")
123125
f.write(dag_dot)
@@ -127,7 +129,7 @@ def test_invalid_causal_dag(self):
127129
self.assertRaises(nx.HasACycle, CausalDAG, self.dag_dot_path)
128130

129131
def tearDown(self) -> None:
130-
remove_temp_dir_if_existent()
132+
shutil.rmtree(self.temp_dir_path)
131133

132134

133135
class TestDAGDirectEffectIdentification(unittest.TestCase):
@@ -136,8 +138,8 @@ class TestDAGDirectEffectIdentification(unittest.TestCase):
136138
"""
137139

138140
def setUp(self) -> None:
139-
temp_dir_path = create_temp_dir_if_non_existent()
140-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
141+
self.temp_dir_path = tempfile.mkdtemp()
142+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
141143
dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
142144
f = open(self.dag_dot_path, "w")
143145
f.write(dag_dot)
@@ -152,16 +154,18 @@ def test_direct_effect_adjustment_sets_no_adjustment(self):
152154
causal_dag = CausalDAG(self.dag_dot_path)
153155
adjustment_sets = causal_dag.direct_effect_adjustment_sets(["X2"], ["D1"])
154156
self.assertEqual(list(adjustment_sets), [set()])
155-
157+
158+
def tearDown(self) -> None:
159+
shutil.rmtree(self.temp_dir_path)
156160

157161
class TestDAGIdentification(unittest.TestCase):
158162
"""
159163
Test the Causal DAG identification algorithms and supporting algorithms.
160164
"""
161165

162166
def setUp(self) -> None:
163-
temp_dir_path = create_temp_dir_if_non_existent()
164-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
167+
self.temp_dir_path = tempfile.mkdtemp()
168+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
165169
dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
166170
f = open(self.dag_dot_path, "w")
167171
f.write(dag_dot)
@@ -339,8 +343,7 @@ def test_dag_with_non_character_nodes(self):
339343
self.assertEqual(adjustment_sets, [{"aa"}, {"la"}, {"va"}])
340344

341345
def tearDown(self) -> None:
342-
remove_temp_dir_if_existent()
343-
346+
shutil.rmtree(self.temp_dir_path)
344347

345348
class TestDependsOnOutputs(unittest.TestCase):
346349
"""
@@ -352,8 +355,8 @@ def setUp(self) -> None:
352355
from causal_testing.specification.variable import Input, Output, Meta
353356
from causal_testing.specification.scenario import Scenario
354357

355-
temp_dir_path = create_temp_dir_if_non_existent()
356-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
358+
self.temp_dir_path = tempfile.mkdtemp()
359+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
357360
dag_dot = """digraph G { A -> B; B -> C; D -> A; D -> C}"""
358361
f = open(self.dag_dot_path, "w")
359362
f.write(dag_dot)
@@ -391,7 +394,7 @@ def test_depends_on_outputs_input(self):
391394
self.assertFalse(causal_dag.depends_on_outputs("D", self.scenario))
392395

393396
def tearDown(self) -> None:
394-
remove_temp_dir_if_existent()
397+
shutil.rmtree(self.temp_dir_path)
395398

396399

397400
class TestUndirectedGraphAlgorithms(unittest.TestCase):
@@ -431,18 +434,15 @@ def test_list_all_min_sep(self):
431434
min_separators = set(frozenset(min_separator) for min_separator in min_separators)
432435
self.assertEqual({frozenset({2, 3}), frozenset({3, 4}), frozenset({4, 5})}, min_separators)
433436

434-
def tearDown(self) -> None:
435-
remove_temp_dir_if_existent()
436-
437437

438438
class TestHiddenVariableDAG(unittest.TestCase):
439439
"""
440440
Test the CausalDAG identification for the exclusion of hidden variables.
441441
"""
442442

443443
def setUp(self) -> None:
444-
temp_dir_path = create_temp_dir_if_non_existent()
445-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
444+
self.temp_dir_path = tempfile.mkdtemp()
445+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
446446
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M; M -> Y; Z -> M; }"""
447447
with open(self.dag_dot_path, "w") as f:
448448
f.write(dag_dot)
@@ -463,4 +463,4 @@ def test_hidden_varaible_adjustment_sets(self):
463463
self.assertNotEqual(adjustment_sets, adjustment_sets_with_hidden)
464464

465465
def tearDown(self) -> None:
466-
remove_temp_dir_if_existent()
466+
shutil.rmtree(self.temp_dir_path)

0 commit comments

Comments
 (0)