|
6 | 6 | import lhsmdu
|
7 | 7 | from pandas import DataFrame
|
8 | 8 | from scipy.stats._distn_infrastructure import rv_generic
|
9 |
| -from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String |
| 9 | +from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String, DatatypeRef |
10 | 10 |
|
11 | 11 | # Declare type variable
|
12 | 12 | # Is there a better way? I'd really like to do Variable[T](ExprRef)
|
@@ -153,19 +153,24 @@ def cast(self, val: Any) -> T:
|
153 | 153 | :rtype: T
|
154 | 154 | """
|
155 | 155 | assert val is not None, f"Invalid value None for variable {self}"
|
| 156 | + if isinstance(val, self.datatype): |
| 157 | + return val |
156 | 158 | if isinstance(val, RatNumRef) and self.datatype == float:
|
157 | 159 | return float(val.numerator().as_long() / val.denominator().as_long())
|
158 | 160 | if hasattr(val, "is_string_value") and val.is_string_value() and self.datatype == str:
|
159 | 161 | return val.as_string()
|
160 | 162 | if (isinstance(val, float) or isinstance(val, int)) and (self.datatype == int or self.datatype == float):
|
161 | 163 | return self.datatype(val)
|
| 164 | + if issubclass(self.datatype, Enum) and isinstance(val, DatatypeRef): |
| 165 | + return self.datatype[str(val)] |
162 | 166 | return self.datatype(str(val))
|
163 | 167 |
|
164 | 168 | def z3_val(self, z3_var, val: Any) -> T:
|
165 | 169 | native_val = self.cast(val)
|
| 170 | + print(val, type(val), native_val, type(native_val)) |
166 | 171 | if isinstance(native_val, Enum):
|
167 | 172 | values = [z3_var.sort().constructor(c)() for c in range(z3_var.sort().num_constructors())]
|
168 |
| - values = [v for v in values if str(v) == str(val)] |
| 173 | + values = [v for v in values if type(val)[str(v)] == val] |
169 | 174 | assert len(values) == 1, f"Expected {values} to be length 1"
|
170 | 175 | return values[0]
|
171 | 176 | return native_val
|
|
0 commit comments