3
3
from causal_testing .specification .causal_dag import CausalDAG
4
4
from causal_testing .specification .causal_specification import CausalSpecification
5
5
from causal_testing .specification .scenario import Scenario
6
- from causal_testing .specification .variable import Input
6
+ from causal_testing .specification .variable import Input , Output
7
7
from causal_testing .surrogate .causal_surrogate_assisted import SimulationResult , CausalSurrogateAssistedTestCase , Simulator
8
8
from causal_testing .surrogate .surrogate_search_algorithms import GeneticSearchAlgorithm
9
9
from causal_testing .testing .estimators import CubicSplineRegressionEstimator
@@ -58,7 +58,7 @@ def test_surrogate_model_generation(self):
58
58
z = Input ("Z" , int )
59
59
x = Input ("X" , int )
60
60
m = Input ("M" , int )
61
- y = Input ("Y" , int )
61
+ y = Output ("Y" , float )
62
62
scenario = Scenario (variables = {z , x , m , y })
63
63
specification = CausalSpecification (scenario , causal_dag )
64
64
@@ -77,7 +77,7 @@ def test_causal_surrogate_assisted_execution(self):
77
77
z = Input ("Z" , int )
78
78
x = Input ("X" , int )
79
79
m = Input ("M" , int )
80
- y = Input ("Y" , int )
80
+ y = Output ("Y" , float )
81
81
scenario = Scenario (variables = {z , x , m , y }, constraints = {
82
82
z <= 0 , z >= 3 ,
83
83
x <= 0 , x >= 3 ,
@@ -109,7 +109,7 @@ def test_causal_surrogate_assisted_execution_failure(self):
109
109
z = Input ("Z" , int )
110
110
x = Input ("X" , int )
111
111
m = Input ("M" , int )
112
- y = Input ("Y" , int )
112
+ y = Output ("Y" , float )
113
113
scenario = Scenario (variables = {z , x , m , y }, constraints = {
114
114
z <= 0 , z >= 3 ,
115
115
x <= 0 , x >= 3 ,
@@ -141,7 +141,7 @@ def test_causal_surrogate_assisted_execution_custom_aggregator(self):
141
141
z = Input ("Z" , int )
142
142
x = Input ("X" , int )
143
143
m = Input ("M" , int )
144
- y = Input ("Y" , int )
144
+ y = Output ("Y" , float )
145
145
scenario = Scenario (variables = {z , x , m , y }, constraints = {
146
146
z <= 0 , z >= 3 ,
147
147
x <= 0 , x >= 3 ,
@@ -174,7 +174,7 @@ def test_causal_surrogate_assisted_execution_incorrect_search_config(self):
174
174
z = Input ("Z" , int )
175
175
x = Input ("X" , int )
176
176
m = Input ("M" , int )
177
- y = Input ("Y" , int )
177
+ y = Output ("Y" , float )
178
178
scenario = Scenario (variables = {z , x , m , y }, constraints = {
179
179
z <= 0 , z >= 3 ,
180
180
x <= 0 , x >= 3 ,
0 commit comments