Skip to content

Commit 8c37082

Browse files
committed
Adding tests
1 parent 27c737a commit 8c37082

File tree

4 files changed

+138
-8
lines changed

4 files changed

+138
-8
lines changed

causal_testing/specification/causal_dag.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,9 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
531531
if scenario is not None:
532532
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)
533533

534+
if len(minimal_adjustment_sets) == 0:
535+
return set()
536+
534537
minimal_adjustment_set = min(minimal_adjustment_sets, key=len)
535538
return minimal_adjustment_set
536539

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,11 @@ def create_gene_types(
105105
for relationship in list(specification.scenario.constraints):
106106
rel_split = str(relationship).split(" ")
107107

108-
if rel_split[1] == ">=":
109-
var_space[rel_split[0]]["low"] = int(rel_split[2])
110-
elif rel_split[1] == "<=":
111-
var_space[rel_split[0]]["high"] = int(rel_split[2])
108+
if rel_split[0] in var_space.keys():
109+
if rel_split[1] == ">=":
110+
var_space[rel_split[0]]["low"] = int(rel_split[2])
111+
elif rel_split[1] == "<=":
112+
var_space[rel_split[0]]["high"] = int(rel_split[2])
112113

113114
gene_space = []
114115
gene_space.append(var_space[fitness_function.surrogate_model.treatment])

tests/specification_tests/test_causal_dag.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import os
33
import networkx as nx
44
from causal_testing.specification.causal_dag import CausalDAG, close_separator, list_all_min_sep
5+
from causal_testing.specification.scenario import Scenario
6+
from causal_testing.specification.variable import Input, Output
7+
from causal_testing.testing.base_test_case import BaseTestCase
58
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
69

710

@@ -428,3 +431,34 @@ def test_list_all_min_sep(self):
428431

429432
def tearDown(self) -> None:
430433
remove_temp_dir_if_existent()
434+
435+
436+
class TestHiddenVariableDAG(unittest.TestCase):
437+
"""
438+
Test the CausalDAG identification for the exclusion of hidden variables.
439+
"""
440+
441+
def setUp(self) -> None:
442+
temp_dir_path = create_temp_dir_if_non_existent()
443+
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
444+
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M; M -> Y; Z -> M; }"""
445+
with open(self.dag_dot_path, "w") as f:
446+
f.write(dag_dot)
447+
448+
def test_hidden_varaible_adjustment_sets(self):
449+
"""Test whether identification produces different adjustment sets depending on if a variable is hidden."""
450+
causal_dag = CausalDAG(self.dag_dot_path)
451+
z = Input("Z", int)
452+
x = Input("X", int)
453+
m = Input("M", int)
454+
455+
scenario = Scenario(variables={z, x, m})
456+
adjustment_sets = causal_dag.identification(BaseTestCase(x, m), scenario)
457+
458+
z.hidden = True
459+
adjustment_sets_with_hidden = causal_dag.identification(BaseTestCase(x, m), scenario)
460+
461+
self.assertNotEqual(adjustment_sets, adjustment_sets_with_hidden)
462+
463+
def tearDown(self) -> None:
464+
remove_temp_dir_if_existent()

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
import unittest
2-
from causal_testing.surrogate.causal_surrogate_assisted import SimulationResult, SearchFitnessFunction
3-
from causal_testing.testing.estimators import Estimator, CubicSplineRegressionEstimator
2+
from causal_testing.data_collection.data_collector import ObservationalDataCollector
3+
from causal_testing.specification.causal_dag import CausalDAG
4+
from causal_testing.specification.causal_specification import CausalSpecification
5+
from causal_testing.specification.scenario import Scenario
6+
from causal_testing.specification.variable import Input
7+
from causal_testing.surrogate.causal_surrogate_assisted import SimulationResult, SearchFitnessFunction, CausalSurrogateAssistedTestCase, Simulator
8+
from causal_testing.surrogate.surrogate_search_algorithms import GeneticSearchAlgorithm
9+
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
10+
from tests.test_helpers import create_temp_dir_if_non_existent, remove_temp_dir_if_existent
11+
import os
12+
import pandas as pd
13+
import numpy as np
414

515
class TestSimulationResult(unittest.TestCase):
616

@@ -28,7 +38,16 @@ def test_inputs(self):
2838

2939
class TestSearchFitnessFunction(unittest.TestCase):
3040

31-
#TODO: complete tests for causal surrogate
41+
@classmethod
42+
def setUpClass(cls) -> None:
43+
cls.class_df = load_class_df()
44+
45+
def setUp(self):
46+
temp_dir_path = create_temp_dir_if_non_existent()
47+
self.dag_dot_path = os.path.join(temp_dir_path, "dag.dot")
48+
dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M [included=1, expected=positive]; M -> Y [included=1, expected=negative]; Z -> M; }"""
49+
with open(self.dag_dot_path, "w") as f:
50+
f.write(dag_dot)
3251

3352
def test_init_valid_values(self):
3453

@@ -39,4 +58,77 @@ def test_init_valid_values(self):
3958
search_function = SearchFitnessFunction(fitness_function=test_function, surrogate_model=surrogate_model)
4059

4160
self.assertTrue(callable(search_function.fitness_function))
42-
self.assertIsInstance(search_function.surrogate_model, CubicSplineRegressionEstimator)
61+
self.assertIsInstance(search_function.surrogate_model, CubicSplineRegressionEstimator)
62+
63+
def test_surrogate_model_generation(self):
64+
c_s_a_test_case = CausalSurrogateAssistedTestCase(None, None, None)
65+
66+
df = self.class_df.copy()
67+
68+
causal_dag = CausalDAG(self.dag_dot_path)
69+
z = Input("Z", int)
70+
x = Input("X", int)
71+
m = Input("M", int)
72+
y = Input("Y", int)
73+
scenario = Scenario(variables={z, x, m, y})
74+
specification = CausalSpecification(scenario, causal_dag)
75+
76+
surrogate_models = c_s_a_test_case.generate_surrogates(specification, ObservationalDataCollector(scenario, df))
77+
self.assertEqual(len(surrogate_models), 2)
78+
79+
for surrogate in surrogate_models:
80+
self.assertIsInstance(surrogate, CubicSplineRegressionEstimator)
81+
self.assertNotEqual(surrogate.treatment, "Z")
82+
self.assertNotEqual(surrogate.outcome, "Z")
83+
84+
def test_causal_surrogate_assisted_execution(self):
85+
df = self.class_df.copy()
86+
87+
causal_dag = CausalDAG(self.dag_dot_path)
88+
z = Input("Z", int)
89+
x = Input("X", int)
90+
m = Input("M", int)
91+
y = Input("Y", int)
92+
scenario = Scenario(variables={z, x, m, y}, constraints={
93+
z <= 0, z >= 3,
94+
x <= 0, x >= 3,
95+
m <= 0, m >= 3
96+
})
97+
specification = CausalSpecification(scenario, causal_dag)
98+
99+
search_algorithm = GeneticSearchAlgorithm(config= {
100+
"parent_selection_type": "tournament",
101+
"K_tournament": 4,
102+
"mutation_type": "random",
103+
"mutation_percent_genes": 50,
104+
"mutation_by_replacement": True,
105+
})
106+
simulator = TestSimulator()
107+
108+
c_s_a_test_case = CausalSurrogateAssistedTestCase(specification, search_algorithm, simulator)
109+
110+
result, iterations, result_data = c_s_a_test_case.execute(ObservationalDataCollector(scenario, df))
111+
112+
self.assertIsInstance(result, SimulationResult)
113+
self.assertEqual(iterations, 1)
114+
self.assertEqual(len(result_data), 17)
115+
116+
def tearDown(self) -> None:
117+
remove_temp_dir_if_existent()
118+
119+
def load_class_df():
120+
"""Get the testing data and put into a dataframe."""
121+
122+
class_df = pd.DataFrame({"Z": np.arange(16), "X": np.arange(16), "M": np.arange(16, 32), "Y": np.arange(32,16,-1)})
123+
return class_df
124+
125+
class TestSimulator():
126+
127+
def run_with_config(self, configuration: dict) -> SimulationResult:
128+
return SimulationResult({"Z": 1, "X": 1, "M": 1, "Y": 1}, True, None)
129+
130+
def startup(self):
131+
pass
132+
133+
def shutdown(self):
134+
pass

0 commit comments

Comments
 (0)