Skip to content

Commit b066de4

Browse files
authored
Merge pull request #267 from CITCOM-project/surrogate-fix
Fix bug in categorical gene types
2 parents 5230bfc + 70d089f commit b066de4

File tree

2 files changed

+17
-13
lines changed

2 files changed

+17
-13
lines changed

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,14 @@ def create_gene_types(
9999
rel_split = str(relationship).split(" ")
100100

101101
if rel_split[0] in var_space:
102+
datatype = specification.scenario.variables.get(rel_split[0]).datatype
102103
if rel_split[1] == ">=":
103-
var_space[rel_split[0]]["low"] = int(rel_split[2])
104+
var_space[rel_split[0]]["low"] = datatype(rel_split[2])
104105
elif rel_split[1] == "<=":
105-
var_space[rel_split[0]]["high"] = int(rel_split[2])
106+
if datatype == int:
107+
var_space[rel_split[0]]["high"] = int(rel_split[2]) + 1
108+
else:
109+
var_space[rel_split[0]]["high"] = datatype(rel_split[2])
106110

107111
gene_space = []
108112
gene_space.append(var_space[surrogate_model.treatment])

tests/surrogate_tests/test_causal_surrogate_assisted.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from causal_testing.specification.causal_dag import CausalDAG
44
from causal_testing.specification.causal_specification import CausalSpecification
55
from causal_testing.specification.scenario import Scenario
6-
from causal_testing.specification.variable import Input
6+
from causal_testing.specification.variable import Input, Output
77
from causal_testing.surrogate.causal_surrogate_assisted import SimulationResult, CausalSurrogateAssistedTestCase, Simulator
88
from causal_testing.surrogate.surrogate_search_algorithms import GeneticSearchAlgorithm
99
from causal_testing.testing.estimators import CubicSplineRegressionEstimator
@@ -56,9 +56,9 @@ def test_surrogate_model_generation(self):
5656

5757
causal_dag = CausalDAG(self.dag_dot_path)
5858
z = Input("Z", int)
59-
x = Input("X", int)
59+
x = Input("X", float)
6060
m = Input("M", int)
61-
y = Input("Y", int)
61+
y = Output("Y", float)
6262
scenario = Scenario(variables={z, x, m, y})
6363
specification = CausalSpecification(scenario, causal_dag)
6464

@@ -75,9 +75,9 @@ def test_causal_surrogate_assisted_execution(self):
7575

7676
causal_dag = CausalDAG(self.dag_dot_path)
7777
z = Input("Z", int)
78-
x = Input("X", int)
78+
x = Input("X", float)
7979
m = Input("M", int)
80-
y = Input("Y", int)
80+
y = Output("Y", float)
8181
scenario = Scenario(variables={z, x, m, y}, constraints={
8282
z <= 0, z >= 3,
8383
x <= 0, x >= 3,
@@ -107,9 +107,9 @@ def test_causal_surrogate_assisted_execution_failure(self):
107107

108108
causal_dag = CausalDAG(self.dag_dot_path)
109109
z = Input("Z", int)
110-
x = Input("X", int)
110+
x = Input("X", float)
111111
m = Input("M", int)
112-
y = Input("Y", int)
112+
y = Output("Y", float)
113113
scenario = Scenario(variables={z, x, m, y}, constraints={
114114
z <= 0, z >= 3,
115115
x <= 0, x >= 3,
@@ -139,9 +139,9 @@ def test_causal_surrogate_assisted_execution_custom_aggregator(self):
139139

140140
causal_dag = CausalDAG(self.dag_dot_path)
141141
z = Input("Z", int)
142-
x = Input("X", int)
142+
x = Input("X", float)
143143
m = Input("M", int)
144-
y = Input("Y", int)
144+
y = Output("Y", float)
145145
scenario = Scenario(variables={z, x, m, y}, constraints={
146146
z <= 0, z >= 3,
147147
x <= 0, x >= 3,
@@ -172,9 +172,9 @@ def test_causal_surrogate_assisted_execution_incorrect_search_config(self):
172172

173173
causal_dag = CausalDAG(self.dag_dot_path)
174174
z = Input("Z", int)
175-
x = Input("X", int)
175+
x = Input("X", float)
176176
m = Input("M", int)
177-
y = Input("Y", int)
177+
y = Output("Y", float)
178178
scenario = Scenario(variables={z, x, m, y}, constraints={
179179
z <= 0, z >= 3,
180180
x <= 0, x >= 3,

0 commit comments

Comments
 (0)