Skip to content

Commit 2b5376d

Browse files
authored
Merge pull request #168 from CITCOM-project/z3_operations
Implemented rest of Z3 operators. Closes 147.
2 parents a187fec + d042d4f commit 2b5376d

File tree

2 files changed

+102
-33
lines changed

2 files changed

+102
-33
lines changed

causal_testing/specification/variable.py

Lines changed: 64 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,78 +77,118 @@ def __init__(self, name: str, datatype: T, distribution: rv_generic = None, hidd
7777
def __repr__(self):
7878
return f"{self.typestring()}: {self.name}::{self.datatype.__name__}"
7979

80-
def __ge__(self, other: Any) -> BoolRef:
81-
"""Create the Z3 expression `other >= self`.
80+
# Thin wrapper for Z1 functions
81+
82+
def __add__(self, other: Any) -> BoolRef:
83+
"""Create the Z3 expression `self + other`.
8284
8385
:param any other: The object to compare against.
84-
:return: The Z3 expression `other >= self`.
86+
:return: The Z3 expression `self + other`.
8587
:rtype: BoolRef
8688
"""
87-
return self.z3.__ge__(_coerce(other))
89+
return self.z3.__add__(_coerce(other))
8890

89-
def __le__(self, other: Any) -> BoolRef:
90-
"""Create the Z3 expression `other <= self`.
91+
def __ge__(self, other: Any) -> BoolRef:
92+
"""Create the Z3 expression `self >= other`.
9193
9294
:param any other: The object to compare against.
93-
:return: The Z3 expression `other >= self`.
95+
:return: The Z3 expression `self >= other`.
9496
:rtype: BoolRef
9597
"""
96-
return self.z3.__le__(_coerce(other))
98+
return self.z3.__ge__(_coerce(other))
9799

98100
def __gt__(self, other: Any) -> BoolRef:
99-
"""Create the Z3 expression `other > self`.
101+
"""Create the Z3 expression `self > other`.
100102
101103
:param any other: The object to compare against.
102-
:return: The Z3 expression `other >= self`.
104+
:return: The Z3 expression `self > other`.
103105
:rtype: BoolRef
104106
"""
105107
return self.z3.__gt__(_coerce(other))
106108

109+
def __le__(self, other: Any) -> BoolRef:
110+
"""Create the Z3 expression `self <= other`.
111+
112+
:param any other: The object to compare against.
113+
:return: The Z3 expression `self <= other`.
114+
:rtype: BoolRef
115+
"""
116+
return self.z3.__le__(_coerce(other))
117+
107118
def __lt__(self, other: Any) -> BoolRef:
108-
"""Create the Z3 expression `other < self`.
119+
"""Create the Z3 expression `self < other`.
109120
110121
:param any other: The object to compare against.
111-
:return: The Z3 expression `other >= self`.
122+
:return: The Z3 expression `self < other`.
112123
:rtype: BoolRef
113124
"""
114125
return self.z3.__lt__(_coerce(other))
115126

127+
def __mod__(self, other: Any) -> BoolRef:
128+
"""Create the Z3 expression `self % other`.
129+
130+
:param any other: The object to compare against.
131+
:return: The Z3 expression `self % other`.
132+
:rtype: BoolRef
133+
"""
134+
return self.z3.__mod__(_coerce(other))
135+
116136
def __mul__(self, other: Any) -> BoolRef:
117-
"""Create the Z3 expression `other * self`.
137+
"""Create the Z3 expression `self * other`.
118138
119139
:param any other: The object to compare against.
120-
:return: The Z3 expression `other >= self`.
140+
:return: The Z3 expression `self * other`.
121141
:rtype: BoolRef
122142
"""
123143
return self.z3.__mul__(_coerce(other))
124144

125-
def __sub__(self, other: Any) -> BoolRef:
126-
"""Create the Z3 expression `other * self`.
145+
def __ne__(self, other: Any) -> BoolRef:
146+
"""Create the Z3 expression `self != other`.
127147
128148
:param any other: The object to compare against.
129-
:return: The Z3 expression `other >= self`.
149+
:return: The Z3 expression `self != other`.
130150
:rtype: BoolRef
131151
"""
132-
return self.z3.__sub__(_coerce(other))
152+
return self.z3.__ne__(_coerce(other))
133153

134-
def __add__(self, other: Any) -> BoolRef:
135-
"""Create the Z3 expression `other * self`.
154+
def __neg__(self) -> BoolRef:
155+
"""Create the Z3 expression `-self`.
136156
137157
:param any other: The object to compare against.
138-
:return: The Z3 expression `other >= self`.
158+
:return: The Z3 expression `-self`.
139159
:rtype: BoolRef
140160
"""
141-
return self.z3.__add__(_coerce(other))
161+
return self.z3.__neg__()
162+
163+
def __pow__(self, other: Any) -> BoolRef:
164+
"""Create the Z3 expression `self ^ other`.
165+
166+
:param any other: The object to compare against.
167+
:return: The Z3 expression `self ^ other`.
168+
:rtype: BoolRef
169+
"""
170+
return self.z3.__pow__(_coerce(other))
171+
172+
def __sub__(self, other: Any) -> BoolRef:
173+
"""Create the Z3 expression `self - other`.
174+
175+
:param any other: The object to compare against.
176+
:return: The Z3 expression `self - other`.
177+
:rtype: BoolRef
178+
"""
179+
return self.z3.__sub__(_coerce(other))
142180

143181
def __truediv__(self, other: Any) -> BoolRef:
144-
"""Create the Z3 expression `other * self`.
182+
"""Create the Z3 expression `self / other`.
145183
146184
:param any other: The object to compare against.
147-
:return: The Z3 expression `other >= self`.
185+
:return: The Z3 expression `self / other`.
148186
:rtype: BoolRef
149187
"""
150188
return self.z3.__truediv__(_coerce(other))
151189

190+
# End thin wrapper
191+
152192
def cast(self, val: Any) -> T:
153193
"""Cast the supplied value to the datatype T of the variable.
154194

tests/specification_tests/test_variable.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,14 @@ class Var(Variable):
133133
var = Var("v", int)
134134
self.assertEqual(var.typestring(), "Var")
135135

136+
136137
def test_copy(self):
137138
ip = Input("ip", float, norm)
138-
self.assertNotEqual(ip.copy(), ip)
139+
self.assertTrue(ip.copy() is not ip)
139140
self.assertEqual(ip.copy().name, ip.name)
140141
self.assertEqual(ip.copy().datatype, ip.datatype)
141142
self.assertEqual(ip.copy().distribution, ip.distribution)
143+
self.assertEqual(repr(ip), repr(ip.copy()))
142144

143145

144146
class TestZ3Methods(unittest.TestCase):
@@ -152,14 +154,41 @@ class TestZ3Methods(unittest.TestCase):
152154
def setUp(self) -> None:
153155
self.i1 = Input("i1", int)
154156

155-
def test_ge_add(self):
156-
self.assertEqual(str(self.i1 + 1 >= 5), "i1 + 1 >= 5")
157+
def test_ge_self(self):
158+
self.assertEqual(str(self.i1 >= self.i1), "i1 >= i1")
159+
160+
def test_add(self):
161+
self.assertEqual(str(self.i1 + 1), "i1 + 1")
162+
163+
def test_ge(self):
164+
self.assertEqual(str(self.i1 >= 5), "i1 >= 5")
165+
166+
def test_mod(self):
167+
self.assertEqual(str(self.i1 % 2), "i1%2")
168+
169+
def test_ne(self):
170+
self.assertEqual(str(self.i1 != 5), "i1 != 5")
171+
172+
def test_neg(self):
173+
self.assertEqual(str(-self.i1), "-i1")
174+
175+
def test_pow(self):
176+
self.assertEqual(str(self.i1 ** 5), "i1**5")
177+
178+
def test_le(self):
179+
self.assertEqual(str(self.i1 <= 5), "i1 <= 5")
180+
181+
def test_mul(self):
182+
self.assertEqual(str(self.i1 * 2), "i1*2")
183+
184+
def test_gt(self):
185+
self.assertEqual(str(self.i1 > 5), "i1 > 5")
157186

158-
def test_le_mul(self):
159-
self.assertEqual(str(self.i1 * 2 <= 5), "i1*2 <= 5")
187+
def test_truediv(self):
188+
self.assertEqual(str(self.i1 / 3), "i1/3")
160189

161-
def test_gt_truediv(self):
162-
self.assertEqual(str(self.i1 / 3 > 5), "i1/3 > 5")
190+
def test_sub(self):
191+
self.assertEqual(str(self.i1 - 4), "i1 - 4")
163192

164-
def test_lt_sub(self):
165-
self.assertEqual(str(self.i1 - 4 < 5), "i1 - 4 < 5")
193+
def test_lt(self):
194+
self.assertEqual(str(self.i1 < 5), "i1 < 5")

0 commit comments

Comments
 (0)