@@ -35,8 +35,8 @@ def tearDown(self) -> None:
35
35
36
36
class TestIVAssumptions (unittest .TestCase ):
37
37
def setUp (self ) -> None :
38
- temp_dir_path = create_temp_dir_if_non_existent ()
39
- self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
38
+ self . temp_dir_path = tempfile . mkdtemp ()
39
+ self .dag_dot_path = os .path .join (self . temp_dir_path , "dag.dot" )
40
40
dag_dot = """digraph G { I -> X; X -> Y; U -> X; U -> Y;}"""
41
41
f = open (self .dag_dot_path , "w" )
42
42
f .write (dag_dot )
@@ -63,7 +63,9 @@ def test_common_cause(self):
63
63
causal_dag .graph .add_edge ("U" , "I" )
64
64
with self .assertRaises (ValueError ):
65
65
causal_dag .check_iv_assumptions ("X" , "Y" , "I" )
66
-
66
+
67
+ def tearDown (self ) -> None :
68
+ shutil .rmtree (self .temp_dir_path )
67
69
68
70
class TestCausalDAG (unittest .TestCase ):
69
71
"""
@@ -74,8 +76,8 @@ class TestCausalDAG(unittest.TestCase):
74
76
"""
75
77
76
78
def setUp (self ) -> None :
77
- temp_dir_path = create_temp_dir_if_non_existent ()
78
- self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
79
+ self . temp_dir_path = tempfile . mkdtemp ()
80
+ self .dag_dot_path = os .path .join (self . temp_dir_path , "dag.dot" )
79
81
dag_dot = """digraph G { A -> B; B -> C; D -> A; D -> C;}"""
80
82
f = open (self .dag_dot_path , "w" )
81
83
f .write (dag_dot )
@@ -107,7 +109,7 @@ def test_to_dot_string(self):
107
109
self .assertEqual (causal_dag .to_dot_string (), """digraph G {\n A -> B;\n B -> C;\n D -> A;\n D -> C;\n }""" )
108
110
109
111
def tearDown (self ) -> None :
110
- remove_temp_dir_if_existent ( )
112
+ shutil . rmtree ( self . temp_dir_path )
111
113
112
114
113
115
class TestCyclicCausalDAG (unittest .TestCase ):
@@ -116,8 +118,8 @@ class TestCyclicCausalDAG(unittest.TestCase):
116
118
"""
117
119
118
120
def setUp (self ) -> None :
119
- temp_dir_path = create_temp_dir_if_non_existent ()
120
- self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
121
+ self . temp_dir_path = tempfile . mkdtemp ()
122
+ self .dag_dot_path = os .path .join (self . temp_dir_path , "dag.dot" )
121
123
dag_dot = """digraph G { A -> B; B -> C; D -> A; D -> C; C -> A;}"""
122
124
f = open (self .dag_dot_path , "w" )
123
125
f .write (dag_dot )
@@ -127,7 +129,7 @@ def test_invalid_causal_dag(self):
127
129
self .assertRaises (nx .HasACycle , CausalDAG , self .dag_dot_path )
128
130
129
131
def tearDown (self ) -> None :
130
- remove_temp_dir_if_existent ( )
132
+ shutil . rmtree ( self . temp_dir_path )
131
133
132
134
133
135
class TestDAGDirectEffectIdentification (unittest .TestCase ):
@@ -136,8 +138,8 @@ class TestDAGDirectEffectIdentification(unittest.TestCase):
136
138
"""
137
139
138
140
def setUp (self ) -> None :
139
- temp_dir_path = create_temp_dir_if_non_existent ()
140
- self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
141
+ self . temp_dir_path = tempfile . mkdtemp ()
142
+ self .dag_dot_path = os .path .join (self . temp_dir_path , "dag.dot" )
141
143
dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
142
144
f = open (self .dag_dot_path , "w" )
143
145
f .write (dag_dot )
@@ -152,16 +154,18 @@ def test_direct_effect_adjustment_sets_no_adjustment(self):
152
154
causal_dag = CausalDAG (self .dag_dot_path )
153
155
adjustment_sets = causal_dag .direct_effect_adjustment_sets (["X2" ], ["D1" ])
154
156
self .assertEqual (list (adjustment_sets ), [set ()])
155
-
157
+
158
+ def tearDown (self ) -> None :
159
+ shutil .rmtree (self .temp_dir_path )
156
160
157
161
class TestDAGIdentification (unittest .TestCase ):
158
162
"""
159
163
Test the Causal DAG identification algorithms and supporting algorithms.
160
164
"""
161
165
162
166
def setUp (self ) -> None :
163
- temp_dir_path = create_temp_dir_if_non_existent ()
164
- self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
167
+ self . temp_dir_path = tempfile . mkdtemp ()
168
+ self .dag_dot_path = os .path .join (self . temp_dir_path , "dag.dot" )
165
169
dag_dot = """digraph G { X1->X2;X2->V;X2->D1;X2->D2;D1->Y;D1->D2;Y->D3;Z->X2;Z->Y;}"""
166
170
f = open (self .dag_dot_path , "w" )
167
171
f .write (dag_dot )
@@ -339,8 +343,7 @@ def test_dag_with_non_character_nodes(self):
339
343
self .assertEqual (adjustment_sets , [{"aa" }, {"la" }, {"va" }])
340
344
341
345
def tearDown (self ) -> None :
342
- remove_temp_dir_if_existent ()
343
-
346
+ shutil .rmtree (self .temp_dir_path )
344
347
345
348
class TestDependsOnOutputs (unittest .TestCase ):
346
349
"""
@@ -352,8 +355,8 @@ def setUp(self) -> None:
352
355
from causal_testing .specification .variable import Input , Output , Meta
353
356
from causal_testing .specification .scenario import Scenario
354
357
355
- temp_dir_path = create_temp_dir_if_non_existent ()
356
- self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
358
+ self . temp_dir_path = tempfile . mkdtemp ()
359
+ self .dag_dot_path = os .path .join (self . temp_dir_path , "dag.dot" )
357
360
dag_dot = """digraph G { A -> B; B -> C; D -> A; D -> C}"""
358
361
f = open (self .dag_dot_path , "w" )
359
362
f .write (dag_dot )
@@ -391,7 +394,7 @@ def test_depends_on_outputs_input(self):
391
394
self .assertFalse (causal_dag .depends_on_outputs ("D" , self .scenario ))
392
395
393
396
def tearDown (self ) -> None :
394
- remove_temp_dir_if_existent ( )
397
+ shutil . rmtree ( self . temp_dir_path )
395
398
396
399
397
400
class TestUndirectedGraphAlgorithms (unittest .TestCase ):
@@ -431,18 +434,15 @@ def test_list_all_min_sep(self):
431
434
min_separators = set (frozenset (min_separator ) for min_separator in min_separators )
432
435
self .assertEqual ({frozenset ({2 , 3 }), frozenset ({3 , 4 }), frozenset ({4 , 5 })}, min_separators )
433
436
434
- def tearDown (self ) -> None :
435
- remove_temp_dir_if_existent ()
436
-
437
437
438
438
class TestHiddenVariableDAG (unittest .TestCase ):
439
439
"""
440
440
Test the CausalDAG identification for the exclusion of hidden variables.
441
441
"""
442
442
443
443
def setUp (self ) -> None :
444
- temp_dir_path = create_temp_dir_if_non_existent ()
445
- self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
444
+ self . temp_dir_path = tempfile . mkdtemp ()
445
+ self .dag_dot_path = os .path .join (self . temp_dir_path , "dag.dot" )
446
446
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M; M -> Y; Z -> M; }"""
447
447
with open (self .dag_dot_path , "w" ) as f :
448
448
f .write (dag_dot )
@@ -463,4 +463,4 @@ def test_hidden_varaible_adjustment_sets(self):
463
463
self .assertNotEqual (adjustment_sets , adjustment_sets_with_hidden )
464
464
465
465
def tearDown (self ) -> None :
466
- remove_temp_dir_if_existent ( )
466
+ shutil . rmtree ( self . temp_dir_path )
0 commit comments