Skip to content

Commit c9f0fc8

Browse files
Remove custom helpers and use tempfile for all tests
1 parent 89dba72 commit c9f0fc8

File tree

8 files changed

+35
-36
lines changed

8 files changed

+35
-36
lines changed

tests/data_collection_tests/test_observational_data_collector.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import unittest
22
import os
3+
import shutil, tempfile
34
import pandas as pd
45
from causal_testing.data_collection.data_collector import ObservationalDataCollector
56
from causal_testing.specification.causal_specification import Scenario
67
from causal_testing.specification.variable import Input, Output, Meta
78
from scipy.stats import uniform, rv_discrete
89
from enum import Enum
910
import random
10-
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
1111

1212

1313
class TestObservationalDataCollector(unittest.TestCase):
@@ -17,9 +17,9 @@ class Color(Enum):
1717
GREEN = "GREEN"
1818
BLUE = "BLUE"
1919

20-
temp_dir_path = create_temp_dir_if_non_existent()
21-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
22-
self.observational_df_path = os.path.join(temp_dir_path, "observational_data.csv")
20+
self.temp_dir_path = tempfile.mkdtemp()
21+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
22+
self.observational_df_path = os.path.join(self.temp_dir_path, "observational_data.csv")
2323
# Y = 3*X1 + X2*X3 + 10
2424
self.observational_df = pd.DataFrame(
2525
{"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40], "Y2": ["RED", "GREEN", "BLUE", "BLUE"]}
@@ -66,7 +66,7 @@ def populate_m(data):
6666
assert all((m == 2 * x1 for x1, m in zip(data["X1"], data["M"])))
6767

6868
def tearDown(self) -> None:
69-
remove_temp_dir_if_existent()
69+
shutil.rmtree(self.temp_dir_path)
7070

7171

7272
if __name__ == "__main__":

tests/generation_tests/test_abstract_test_case.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import unittest
22
import os
3+
import shutil, tempfile
34
import pandas as pd
45
import numpy as np
56
from causal_testing.generation.abstract_causal_test_case import AbstractCausalTestCase
67
from causal_testing.generation.enum_gen import EnumGen
78
from causal_testing.specification.causal_specification import Scenario
89
from causal_testing.specification.variable import Input, Output
910
from scipy.stats import uniform, rv_discrete
10-
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
1111
from causal_testing.testing.causal_test_outcome import Positive
1212
from z3 import And
1313
from enum import Enum
@@ -29,9 +29,9 @@ class TestAbstractTestCase(unittest.TestCase):
2929
"""
3030

3131
def setUp(self) -> None:
32-
temp_dir_path = create_temp_dir_if_non_existent()
33-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
34-
self.observational_df_path = os.path.join(temp_dir_path, "observational_data.csv")
32+
self.temp_dir_path = tempfile.mkdtemp()
33+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
34+
self.observational_df_path = os.path.join(self.temp_dir_path, "observational_data.csv")
3535
# Y = 3*X1 + X2*X3 + 10
3636
self.observational_df = pd.DataFrame({"X1": [1, 2, 3, 4], "X2": [5, 6, 7, 8], "X3": [10, 20, 30, 40]})
3737
self.observational_df["Y"] = self.observational_df.apply(
@@ -192,7 +192,7 @@ def test_feasible_constraints(self):
192192
assert len(concrete_tests) < 1000
193193

194194
def tearDown(self) -> None:
195-
remove_temp_dir_if_existent()
195+
shutil.rmtree(self.temp_dir_path)
196196

197197

198198
if __name__ == "__main__":

tests/json_front_tests/test_json_class.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
from causal_testing.testing.estimators import LinearRegressionEstimator, Estimator
88
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
9-
from tests.test_helpers import remove_temp_dir_if_existent
109
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
1110
from causal_testing.specification.variable import Input, Output, Meta
1211
from causal_testing.specification.scenario import Scenario
@@ -321,7 +320,6 @@ def add_modelling_assumptions(self):
321320
self.json_class.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False)
322321

323322
def tearDown(self) -> None:
324-
remove_temp_dir_if_existent()
325323
if os.path.exists("temp_out.txt"):
326324
os.remove("temp_out.txt")
327325

tests/specification_tests/test_causal_dag.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import unittest
22
import os
3+
import shutil, tempfile
34
import networkx as nx
45
from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep
56
from causal_testing.specification.scenario import Scenario
67
from causal_testing.specification.variable import Input, Output
78
from causal_testing.testing.base_test_case import BaseTestCase
8-
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
9+
10+
911

1012

1113
class TestCausalDAGIssue90(unittest.TestCase):
@@ -14,8 +16,8 @@ class TestCausalDAGIssue90(unittest.TestCase):
1416
"""
1517

1618
def setUp(self) -> None:
17-
temp_dir_path = create_temp_dir_if_non_existent()
18-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
19+
self.temp_dir_path = tempfile.mkdtemp()
20+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
1921
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M; M -> Y; Z -> M; }"""
2022
with open(self.dag_dot_path, "w") as f:
2123
f.write(dag_dot)
@@ -28,7 +30,7 @@ def test_enumerate_minimal_adjustment_sets(self):
2830
self.assertEqual([{"Z"}], adjustment_sets)
2931

3032
def tearDown(self) -> None:
31-
remove_temp_dir_if_existent()
33+
shutil.rmtree(self.temp_dir_path)
3234

3335

3436
class TestIVAssumptions(unittest.TestCase):

tests/specification_tests/test_metamorphic_relations.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import unittest
22
import os
3-
3+
import shutil, tempfile
44
import pandas as pd
55
from itertools import combinations
66

7-
from tests.test_helpers import create_temp_dir_if_non_existent
87
from causal_testing.specification.causal_dag import CausalDAG
98
from causal_testing.specification.causal_specification import Scenario
109
from causal_testing.specification.metamorphic_relation import (
@@ -69,8 +68,8 @@ def run_system_with_input_configuration(self, input_configuration: dict) -> pd.D
6968

7069
class TestMetamorphicRelation(unittest.TestCase):
7170
def setUp(self) -> None:
72-
temp_dir_path = create_temp_dir_if_non_existent()
73-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
71+
self.temp_dir_path = tempfile.mkdtemp()
72+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
7473
dag_dot = """digraph DAG { rankdir=LR; X1 -> Z; Z -> M; M -> Y; X2 -> Z; X3 -> M;}"""
7574
with open(self.dag_dot_path, "w") as f:
7675
f.write(dag_dot)
@@ -88,6 +87,9 @@ def setUp(self) -> None:
8887
self.scenario, self.default_control_input_config, self.default_treatment_input_config
8988
)
9089

90+
def tearDown(self) -> None:
91+
shutil.rmtree(self.temp_dir_path)
92+
9193
def test_should_cause_metamorphic_relations_correct_spec(self):
9294
"""Test if the ShouldCause MR passes all metamorphic tests where the DAG perfectly represents the program."""
9395
causal_dag = CausalDAG(self.dag_dot_path)

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from causal_testing.surrogate.causal_surrogate_assisted import SimulationResult, CausalSurrogateAssistedTestCase, Simulator
88
from causal_testing.surrogate.surrogate_search_algorithms import GeneticSearchAlgorithm
99
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
10-
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
10+
1111
import os
12+
import shutil, tempfile
1213
import pandas as pd
1314
import numpy as np
1415

@@ -43,8 +44,8 @@ def setUpClass(cls) -> None:
4344
cls.class_df = load_class_df()
4445

4546
def setUp(self):
46-
temp_dir_path = create_temp_dir_if_non_existent()
47-
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
47+
self.temp_dir_path = tempfile.mkdtemp()
48+
self.dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
4849
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M [included=1, expected=positive]; M -> Y [included=1, expected=negative]; Z -> M; }"""
4950
with open(self.dag_dot_path, "w") as f:
5051
f.write(dag_dot)
@@ -199,7 +200,7 @@ def test_causal_surrogate_assisted_execution_incorrect_search_config(self):
199200
custom_data_aggregator=data_double_aggregator)
200201

201202
def tearDown(self) -> None:
202-
remove_temp_dir_if_existent()
203+
shutil.rmtree(self.temp_dir_path)
203204

204205
def load_class_df():
205206
"""Get the testing data and put into a dataframe."""

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from causal_testing.testing.causal_test_suite import CausalTestSuite
1111
from causal_testing.testing.causal_test_adequacy import DAGAdequacy
1212
from causal_testing.testing.causal_test_outcome import NoEffect, Positive
13-
from tests.test_helpers import remove_temp_dir_if_existent
1413
from causal_testing.json_front.json_class import JsonUtility, CausalVariables
1514
from causal_testing.specification.variable import Input, Output, Meta
1615
from causal_testing.specification.scenario import Scenario
@@ -255,6 +254,5 @@ def test_dag_adequacy_independent_other_way(self):
255254
)
256255

257256
def tearDown(self) -> None:
258-
remove_temp_dir_if_existent()
259257
if os.path.exists("temp_out.txt"):
260258
os.remove("temp_out.txt")

tests/testing_tests/test_causal_test_case.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import unittest
22
import os
3+
import tempfile
4+
import shutil
35
import pandas as pd
46
import numpy as np
57

6-
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
78
from causal_testing.specification.causal_specification import CausalSpecification, Scenario
89
from causal_testing.specification.variable import Input, Output
910
from causal_testing.specification.causal_dag import CausalDAG
@@ -44,9 +45,6 @@ def test_str(self):
4445
" {Output: C::float}: ExactValue: 4±0.2.",
4546
)
4647

47-
def tearDown(self) -> None:
48-
remove_temp_dir_if_existent()
49-
5048

5149
class TestCausalTestExecution(unittest.TestCase):
5250
"""Test the causal test execution workflow using observational data.
@@ -57,8 +55,8 @@ class TestCausalTestExecution(unittest.TestCase):
5755

5856
def setUp(self) -> None:
5957
# 1. Create Causal DAG
60-
temp_dir_path = create_temp_dir_if_non_existent()
61-
dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
58+
self.temp_dir_path = tempfile.mkdtemp()
59+
dag_dot_path = os.path.join(self.temp_dir_path, "dag.dot")
6260
dag_dot = """digraph G { A -> C; D -> A; D -> C}"""
6361
with open(dag_dot_path, "w") as file:
6462
file.write(dag_dot)
@@ -88,7 +86,7 @@ def setUp(self) -> None:
8886
df = pd.DataFrame({"D": list(np.random.normal(60, 10, 1000))}) # D = exogenous
8987
df["A"] = [1 if d > 50 else 0 for d in df["D"]]
9088
df["C"] = df["D"] + (4 * (df["A"] + 2)) # C = (4*(A+2)) + D
91-
self.observational_data_csv_path = os.path.join(temp_dir_path, "observational_data.csv")
89+
self.observational_data_csv_path = os.path.join(self.temp_dir_path, "observational_data.csv")
9290
df.to_csv(self.observational_data_csv_path, index=False)
9391

9492
# 5. Create observational data collector
@@ -101,6 +99,9 @@ def setUp(self) -> None:
10199
self.treatment_value = 1
102100
self.control_value = 0
103101

102+
def tearDown(self) -> None:
103+
shutil.rmtree(self.temp_dir_path)
104+
104105
def test_check_minimum_adjustment_set(self):
105106
"""Check that the minimum adjustment set is correctly made"""
106107
minimal_adjustment_set = self.causal_dag.identification(self.base_test_case)
@@ -215,6 +216,3 @@ def test_execute_test_observational_linear_regression_estimator_squared_term(sel
215216
)
216217
causal_test_result = self.causal_test_case.execute_test(estimation_model, self.data_collector)
217218
pd.testing.assert_series_equal(causal_test_result.test_value.value, pd.Series(4.0), atol=1)
218-
219-
def tearDown(self) -> None:
220-
remove_temp_dir_if_existent()

0 commit comments

Comments
 (0)