1
1
import unittest
2
- from causal_testing .surrogate .causal_surrogate_assisted import SimulationResult , SearchFitnessFunction
3
- from causal_testing .testing .estimators import Estimator , CubicSplineRegressionEstimator
2
+ from causal_testing .data_collection .data_collector import ObservationalDataCollector
3
+ from causal_testing .specification .causal_dag import CausalDAG
4
+ from causal_testing .specification .causal_specification import CausalSpecification
5
+ from causal_testing .specification .scenario import Scenario
6
+ from causal_testing .specification .variable import Input
7
+ from causal_testing .surrogate .causal_surrogate_assisted import SimulationResult , SearchFitnessFunction , CausalSurrogateAssistedTestCase , Simulator
8
+ from causal_testing .surrogate .surrogate_search_algorithms import GeneticSearchAlgorithm
9
+ from causal_testing .testing .estimators import CubicSplineRegressionEstimator
10
+ from tests .test_helpers import create_temp_dir_if_non_existent , remove_temp_dir_if_existent
11
+ import os
12
+ import pandas as pd
13
+ import numpy as np
4
14
5
15
class TestSimulationResult (unittest .TestCase ):
6
16
@@ -28,7 +38,16 @@ def test_inputs(self):
28
38
29
39
class TestSearchFitnessFunction (unittest .TestCase ):
30
40
31
- #TODO: complete tests for causal surrogate
41
+ @classmethod
42
+ def setUpClass (cls ) -> None :
43
+ cls .class_df = load_class_df ()
44
+
45
+ def setUp (self ):
46
+ temp_dir_path = create_temp_dir_if_non_existent ()
47
+ self .dag_dot_path = os .path .join (temp_dir_path , "dag.dot" )
48
+ dag_dot = """digraph DAG { rankdir=LR; Z -> X; X -> M [included=1, expected=positive]; M -> Y [included=1, expected=negative]; Z -> M; }"""
49
+ with open (self .dag_dot_path , "w" ) as f :
50
+ f .write (dag_dot )
32
51
33
52
def test_init_valid_values (self ):
34
53
@@ -39,4 +58,77 @@ def test_init_valid_values(self):
39
58
search_function = SearchFitnessFunction (fitness_function = test_function , surrogate_model = surrogate_model )
40
59
41
60
self .assertTrue (callable (search_function .fitness_function ))
42
- self .assertIsInstance (search_function .surrogate_model , CubicSplineRegressionEstimator )
61
+ self .assertIsInstance (search_function .surrogate_model , CubicSplineRegressionEstimator )
62
+
63
+ def test_surrogate_model_generation (self ):
64
+ c_s_a_test_case = CausalSurrogateAssistedTestCase (None , None , None )
65
+
66
+ df = self .class_df .copy ()
67
+
68
+ causal_dag = CausalDAG (self .dag_dot_path )
69
+ z = Input ("Z" , int )
70
+ x = Input ("X" , int )
71
+ m = Input ("M" , int )
72
+ y = Input ("Y" , int )
73
+ scenario = Scenario (variables = {z , x , m , y })
74
+ specification = CausalSpecification (scenario , causal_dag )
75
+
76
+ surrogate_models = c_s_a_test_case .generate_surrogates (specification , ObservationalDataCollector (scenario , df ))
77
+ self .assertEqual (len (surrogate_models ), 2 )
78
+
79
+ for surrogate in surrogate_models :
80
+ self .assertIsInstance (surrogate , CubicSplineRegressionEstimator )
81
+ self .assertNotEqual (surrogate .treatment , "Z" )
82
+ self .assertNotEqual (surrogate .outcome , "Z" )
83
+
84
+ def test_causal_surrogate_assisted_execution (self ):
85
+ df = self .class_df .copy ()
86
+
87
+ causal_dag = CausalDAG (self .dag_dot_path )
88
+ z = Input ("Z" , int )
89
+ x = Input ("X" , int )
90
+ m = Input ("M" , int )
91
+ y = Input ("Y" , int )
92
+ scenario = Scenario (variables = {z , x , m , y }, constraints = {
93
+ z <= 0 , z >= 3 ,
94
+ x <= 0 , x >= 3 ,
95
+ m <= 0 , m >= 3
96
+ })
97
+ specification = CausalSpecification (scenario , causal_dag )
98
+
99
+ search_algorithm = GeneticSearchAlgorithm (config = {
100
+ "parent_selection_type" : "tournament" ,
101
+ "K_tournament" : 4 ,
102
+ "mutation_type" : "random" ,
103
+ "mutation_percent_genes" : 50 ,
104
+ "mutation_by_replacement" : True ,
105
+ })
106
+ simulator = TestSimulator ()
107
+
108
+ c_s_a_test_case = CausalSurrogateAssistedTestCase (specification , search_algorithm , simulator )
109
+
110
+ result , iterations , result_data = c_s_a_test_case .execute (ObservationalDataCollector (scenario , df ))
111
+
112
+ self .assertIsInstance (result , SimulationResult )
113
+ self .assertEqual (iterations , 1 )
114
+ self .assertEqual (len (result_data ), 17 )
115
+
116
+ def tearDown (self ) -> None :
117
+ remove_temp_dir_if_existent ()
118
+
119
+ def load_class_df ():
120
+ """Get the testing data and put into a dataframe."""
121
+
122
+ class_df = pd .DataFrame ({"Z" : np .arange (16 ), "X" : np .arange (16 ), "M" : np .arange (16 , 32 ), "Y" : np .arange (32 ,16 ,- 1 )})
123
+ return class_df
124
+
125
+ class TestSimulator ():
126
+
127
+ def run_with_config (self , configuration : dict ) -> SimulationResult :
128
+ return SimulationResult ({"Z" : 1 , "X" : 1 , "M" : 1 , "Y" : 1 }, True , None )
129
+
130
+ def startup (self ):
131
+ pass
132
+
133
+ def shutdown (self ):
134
+ pass
0 commit comments