Skip to content

Commit e559295

Browse files
committed
Removed Z3
1 parent 7c0f0f2 commit e559295

File tree

8 files changed

+23
-365
lines changed

8 files changed

+23
-365
lines changed

causal_testing/specification/scenario.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections.abc import Iterable, Mapping
44

55
from tabulate import tabulate
6-
from z3 import ExprRef, substitute
76

87
from .variable import Input, Meta, Output, Variable
98

@@ -20,15 +19,15 @@ class Scenario:
2019
accordingly.
2120
2221
:param {Variable} variables: The set of endogenous variables.
23-
:param {ExprRef} constraints: The set of constraints relating the endogenous variables.
22+
:param {str} constraints: The set of constraints relating the endogenous variables.
2423
:attr variables:
2524
:attr constraints:
2625
"""
2726

2827
variables: Mapping[str, Variable]
29-
constraints: set[ExprRef]
28+
constraints: set[str]
3029

31-
def __init__(self, variables: Iterable[Variable] = None, constraints: set[ExprRef] = None):
30+
def __init__(self, variables: Iterable[Variable] = None, constraints: set[str] = None):
3231
if variables is not None:
3332
self.variables = {v.name: v for v in variables}
3433
else:
@@ -106,10 +105,6 @@ def setup_treatment_variables(self) -> None:
106105
self.prime[k] = v_prime.name
107106
self.unprime[v_prime.name] = k
108107

109-
substitutions = {(self.variables[n].z3, self.treatment_variables[n].z3) for n in self.variables}
110-
treatment_constraints = {substitute(c, *substitutions) for c in self.constraints}
111-
self.constraints = self.constraints.union(treatment_constraints)
112-
113108
def variables_of_type(self, t: type) -> set[Variable]:
114109
"""Get the set of scenario variables of a particular type, e.g. Inputs.
115110

causal_testing/specification/variable.py

Lines changed: 1 addition & 162 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,16 @@
1-
"""This module contains the Variable abstract class, as well as its concrete extensions: Input, Output and Meta. The
2-
function z3_types and the private function _coerce are also in this module."""
1+
"""This module contains the Variable abstract class, as well as its concrete extensions: Input, Output and Meta."""
32

43
from __future__ import annotations
54
from abc import ABC
65
from collections.abc import Callable
7-
from enum import Enum
86
from typing import Any, TypeVar
97

108
import lhsmdu
119
from pandas import DataFrame
1210
from scipy.stats._distn_infrastructure import rv_generic
13-
from z3 import Bool, BoolRef, Const, EnumSort, Int, RatNumRef, Real, String
1411

1512
# Declare type variable
1613
T = TypeVar("T")
17-
z3 = TypeVar("Z3")
18-
19-
20-
def z3_types(datatype: T) -> z3:
21-
"""Cast datatype to Z3 datatype
22-
:param datatype: python datatype to be cast
23-
:return: Type name compatible with Z3 library
24-
"""
25-
types = {int: Int, str: String, float: Real, bool: Bool}
26-
if datatype in types:
27-
return types[datatype]
28-
if issubclass(datatype, Enum):
29-
dtype, _ = EnumSort(datatype.__name__, [str(x.value) for x in datatype])
30-
return lambda x: Const(x, dtype)
31-
if hasattr(datatype, "to_z3"):
32-
return datatype.to_z3()
33-
raise ValueError(
34-
f"Cannot convert type {datatype} to Z3."
35-
+ " Please use a native type, an Enum, or implement a conversion manually."
36-
)
37-
38-
39-
def _coerce(val: Any) -> Any:
40-
"""Coerce Variables to their Z3 equivalents if appropriate to do so,
41-
otherwise assume literal constants.
42-
43-
:param any val: A value, possibly a Variable.
44-
:return: Either a Z3 ExprRef representing the variable or the original value.
45-
:rtype: Any
46-
47-
"""
48-
if isinstance(val, Variable):
49-
return val.z3
50-
return val
5114

5215

5316
class Variable(ABC):
@@ -56,7 +19,6 @@ class Variable(ABC):
5619
:param str name: The name of the variable.
5720
:param T datatype: The datatype of the variable.
5821
:param rv_generic distribution: The expected distribution of the variable values.
59-
:attr type z3: The Z3 mirror of the variable.
6022
:attr name:
6123
:attr datatype:
6224
:attr distribution:
@@ -70,125 +32,12 @@ class Variable(ABC):
7032
def __init__(self, name: str, datatype: T, distribution: rv_generic = None, hidden: bool = False):
7133
self.name = name
7234
self.datatype = datatype
73-
self.z3 = z3_types(datatype)(name)
7435
self.distribution = distribution
7536
self.hidden = hidden
7637

7738
def __repr__(self):
7839
return f"{self.typestring()}: {self.name}::{self.datatype.__name__}"
7940

80-
# Thin wrapper for Z1 functions
81-
82-
def __add__(self, other: Any) -> BoolRef:
83-
"""Create the Z3 expression `self + other`.
84-
85-
:param any other: The object to compare against.
86-
:return: The Z3 expression `self + other`.
87-
:rtype: BoolRef
88-
"""
89-
return self.z3.__add__(_coerce(other))
90-
91-
def __ge__(self, other: Any) -> BoolRef:
92-
"""Create the Z3 expression `self >= other`.
93-
94-
:param any other: The object to compare against.
95-
:return: The Z3 expression `self >= other`.
96-
:rtype: BoolRef
97-
"""
98-
return self.z3.__ge__(_coerce(other))
99-
100-
def __gt__(self, other: Any) -> BoolRef:
101-
"""Create the Z3 expression `self > other`.
102-
103-
:param any other: The object to compare against.
104-
:return: The Z3 expression `self > other`.
105-
:rtype: BoolRef
106-
"""
107-
return self.z3.__gt__(_coerce(other))
108-
109-
def __le__(self, other: Any) -> BoolRef:
110-
"""Create the Z3 expression `self <= other`.
111-
112-
:param any other: The object to compare against.
113-
:return: The Z3 expression `self <= other`.
114-
:rtype: BoolRef
115-
"""
116-
return self.z3.__le__(_coerce(other))
117-
118-
def __lt__(self, other: Any) -> BoolRef:
119-
"""Create the Z3 expression `self < other`.
120-
121-
:param any other: The object to compare against.
122-
:return: The Z3 expression `self < other`.
123-
:rtype: BoolRef
124-
"""
125-
return self.z3.__lt__(_coerce(other))
126-
127-
def __mod__(self, other: Any) -> BoolRef:
128-
"""Create the Z3 expression `self % other`.
129-
130-
:param any other: The object to compare against.
131-
:return: The Z3 expression `self % other`.
132-
:rtype: BoolRef
133-
"""
134-
return self.z3.__mod__(_coerce(other))
135-
136-
def __mul__(self, other: Any) -> BoolRef:
137-
"""Create the Z3 expression `self * other`.
138-
139-
:param any other: The object to compare against.
140-
:return: The Z3 expression `self * other`.
141-
:rtype: BoolRef
142-
"""
143-
return self.z3.__mul__(_coerce(other))
144-
145-
def __ne__(self, other: Any) -> BoolRef:
146-
"""Create the Z3 expression `self != other`.
147-
148-
:param any other: The object to compare against.
149-
:return: The Z3 expression `self != other`.
150-
:rtype: BoolRef
151-
"""
152-
return self.z3.__ne__(_coerce(other))
153-
154-
def __neg__(self) -> BoolRef:
155-
"""Create the Z3 expression `-self`.
156-
157-
:param any other: The object to compare against.
158-
:return: The Z3 expression `-self`.
159-
:rtype: BoolRef
160-
"""
161-
return self.z3.__neg__()
162-
163-
def __pow__(self, other: Any) -> BoolRef:
164-
"""Create the Z3 expression `self ^ other`.
165-
166-
:param any other: The object to compare against.
167-
:return: The Z3 expression `self ^ other`.
168-
:rtype: BoolRef
169-
"""
170-
return self.z3.__pow__(_coerce(other))
171-
172-
def __sub__(self, other: Any) -> BoolRef:
173-
"""Create the Z3 expression `self - other`.
174-
175-
:param any other: The object to compare against.
176-
:return: The Z3 expression `self - other`.
177-
:rtype: BoolRef
178-
"""
179-
return self.z3.__sub__(_coerce(other))
180-
181-
def __truediv__(self, other: Any) -> BoolRef:
182-
"""Create the Z3 expression `self / other`.
183-
184-
:param any other: The object to compare against.
185-
:return: The Z3 expression `self / other`.
186-
:rtype: BoolRef
187-
"""
188-
return self.z3.__truediv__(_coerce(other))
189-
190-
# End thin wrapper
191-
19241
def cast(self, val: Any) -> T:
19342
"""Cast the supplied value to the datatype T of the variable.
19443
@@ -209,16 +58,6 @@ def cast(self, val: Any) -> T:
20958
return self.datatype(val)
21059
return self.datatype(str(val))
21160

212-
def z3_val(self, z3_var, val: Any) -> T:
213-
"""Cast value to Z3 value"""
214-
native_val = self.cast(val)
215-
if isinstance(native_val, Enum):
216-
values = [z3_var.sort().constructor(c)() for c in range(z3_var.sort().num_constructors())]
217-
values = [v for v in values if val.__class__(str(v)) == val]
218-
assert len(values) == 1, f"Expected {values} to be length 1"
219-
return values[0]
220-
return native_val
221-
22261
def sample(self, n_samples: int) -> [T]:
22362
"""Generate a Latin Hypercube Sample of size n_samples according to the
22463
Variable's distribution.

causal_testing/surrogate/surrogate_search_algorithms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def create_gene_types(
9898
var_space[adj] = {}
9999

100100
for relationship in list(specification.scenario.constraints):
101+
print(relationship)
101102
rel_split = str(relationship).split(" ")
102103

103104
if rel_split[0] in var_space:
@@ -109,7 +110,6 @@ def create_gene_types(
109110
var_space[rel_split[0]]["high"] = int(rel_split[2]) + 1
110111
else:
111112
var_space[rel_split[0]]["high"] = datatype(rel_split[2])
112-
113113
gene_space = []
114114
gene_space.append(var_space[surrogate_model.treatment])
115115
for adj in surrogate_model.adjustment_set:

dafni/main_dafni.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,6 @@ def validate_variables(data_dict: dict) -> tuple:
126126

127127
constraints = set()
128128

129-
for variable, input_var in zip(variables, inputs):
130-
if "constraint" in variable:
131-
constraints.add(input_var.z3 == variable["constraint"])
132129
else:
133130
raise ValidationError("Cannot find the variables defined by the causal tests.")
134131

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ requires-python = ">=3.10"
1515
license = { text = "MIT" }
1616
keywords = ["causal inference", "verification"]
1717
dependencies = [
18-
"z3_solver~=4.11.2", # z3_solver does not follow semantic versioning and tying to 4.11 introduces problems
1918
"fitter~=1.7",
2019
"lifelines~=0.29.0",
2120
"lhsmdu~=1.1",

tests/json_front_tests/test_json_class.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,9 @@ def test_f_flag(self):
9999
}
100100
self.json_class.test_plan = example_test
101101
effects = {"NoEffect": NoEffect()}
102-
mutates = {
103-
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
104-
> self.json_class.scenario.variables[x].z3
105-
}
106102
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
107103
with self.assertRaises(StatisticsError):
108-
self.json_class.run_json_tests(effects, estimators, True, mutates)
104+
self.json_class.run_json_tests(effects, estimators, True)
109105

110106
def test_generate_coefficient_tests_from_json(self):
111107
example_test = {
@@ -149,15 +145,9 @@ def test_run_json_tests_from_json(self):
149145
}
150146
self.json_class.test_plan = example_test
151147
effects = {"NoEffect": NoEffect()}
152-
mutates = {
153-
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
154-
> self.json_class.scenario.variables[x].z3
155-
}
156148
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
157149

158-
test_results = self.json_class.run_json_tests(
159-
effects=effects, estimators=estimators, f_flag=False, mutates=mutates
160-
)
150+
test_results = self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False)
161151
self.assertTrue(test_results[0]["failed"])
162152

163153
def test_generate_tests_from_json_no_dist(self):
@@ -176,13 +166,9 @@ def test_generate_tests_from_json_no_dist(self):
176166
}
177167
self.json_class.test_plan = example_test
178168
effects = {"NoEffect": NoEffect()}
179-
mutates = {
180-
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
181-
> self.json_class.scenario.variables[x].z3
182-
}
183169
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
184170

185-
self.json_class.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False)
171+
self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False)
186172

187173
# Test that the final log message prints that failed tests are printed, which is expected behaviour for this scenario
188174
with open("temp_out.txt", "r") as reader:
@@ -206,13 +192,9 @@ def test_formula_in_json_test(self):
206192
}
207193
self.json_class.test_plan = example_test
208194
effects = {"Positive": Positive()}
209-
mutates = {
210-
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
211-
> self.json_class.scenario.variables[x].z3
212-
}
213195
estimators = {"LinearRegressionEstimator": LinearRegressionEstimator}
214196

215-
self.json_class.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False)
197+
self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False)
216198
with open("temp_out.txt", "r") as reader:
217199
temp_out = reader.readlines()
218200
self.assertIn("test_output ~ test_input", "".join(temp_out))
@@ -282,13 +264,9 @@ def add_modelling_assumptions(self):
282264
}
283265
self.json_class.test_plan = example_test
284266
effects = {"Positive": Positive()}
285-
mutates = {
286-
"Increase": lambda x: self.json_class.scenario.treatment_variables[x].z3
287-
> self.json_class.scenario.variables[x].z3
288-
}
289267
estimators = {"ExampleEstimator": ExampleEstimator}
290268
with self.assertRaises(TypeError):
291-
self.json_class.run_json_tests(effects=effects, mutates=mutates, estimators=estimators, f_flag=False)
269+
self.json_class.run_json_tests(effects=effects, estimators=estimators, f_flag=False)
292270

293271
def tearDown(self) -> None:
294272
if os.path.exists("temp_out.txt"):

0 commit comments

Comments
 (0)