Skip to content

Commit b53b8d8

Browse files
committed
Enum variables mostly fixed
1 parent 81a635c commit b53b8d8

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

causal_testing/generation/abstract_causal_test_case.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,14 +93,18 @@ def _generate_concrete_tests(
9393
var = self.scenario.variables[name]
9494
samples[var.name] = lhsmdu.inverseTransformSample(var.distribution, samples[var.name])
9595

96+
print(samples["ego_vehicle"])
9697
for index, row in samples.iterrows():
9798
optimizer = z3.Optimize()
9899
for c in self.scenario.constraints:
99100
optimizer.assert_and_track(c, str(c))
100101
for c in self.intervention_constraints:
101102
optimizer.assert_and_track(c, str(c))
102103

103-
optimizer.add_soft([self.scenario.variables[v].z3 == row[v] for v in run_columns])
104+
for v in run_columns:
105+
optimizer.add_soft(self.scenario.variables[v].z3 == self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v]))
106+
107+
# optimizer.add_soft([optimizer.add_soft(self.scenario.variables[v].z3 == self.scenario.variables[v].z3_val(self.scenario.variables[v].z3, row[v])) for v in run_columns])
104108
if optimizer.check() == z3.unsat:
105109
logger.warning(
106110
"Satisfiability of test case was unsat.\n" "Constraints \n %s \n Unsat core %s",

causal_testing/specification/variable.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import lhsmdu
77
from pandas import DataFrame
88
from scipy.stats._distn_infrastructure import rv_generic
9-
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String
9+
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String, DatatypeRef
1010

1111
# Declare type variable
1212
# Is there a better way? I'd really like to do Variable[T](ExprRef)
@@ -153,19 +153,24 @@ def cast(self, val: Any) -> T:
153153
:rtype: T
154154
"""
155155
assert val is not None, f"Invalid value None for variable {self}"
156+
if isinstance(val, self.datatype):
157+
return val
156158
if isinstance(val, RatNumRef) and self.datatype == float:
157159
return float(val.numerator().as_long() / val.denominator().as_long())
158160
if hasattr(val, "is_string_value") and val.is_string_value() and self.datatype == str:
159161
return val.as_string()
160162
if (isinstance(val, float) or isinstance(val, int)) and (self.datatype == int or self.datatype == float):
161163
return self.datatype(val)
164+
if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef):
165+
return self.datatype[str(val)]
162166
return self.datatype(str(val))
163167

164168
def z3_val(self, z3_var, val: Any) -> T:
165169
native_val = self.cast(val)
170+
print(val, type(val), native_val, type(native_val))
166171
if isinstance(native_val, Enum):
167172
values = [z3_var.sort().constructor(c)() for c in range(z3_var.sort().num_constructors())]
168-
values = [v for v in values if str(v) == str(val)]
173+
values = [v for v in values if type(val)[str(v)] == val]
169174
assert len(values) == 1, f"Expected {values} to be length 1"
170175
return values[0]
171176
return native_val

0 commit comments

Comments
 (0)