Skip to content

Commit 33d3448

Browse files
committed
Implemented rest of Z3 operators. Closes 147.
1 parent 6764fb3 commit 33d3448

File tree

1 file changed

+65
-24
lines changed

1 file changed

+65
-24
lines changed

causal_testing/specification/variable.py

Lines changed: 65 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -77,78 +77,119 @@ def __init__(self, name: str, datatype: T, distribution: rv_generic = None, hidd
7777
def __repr__(self):
7878
return f"{self.typestring()}: {self.name}::{self.datatype.__name__}"
7979

80-
def __ge__(self, other: Any) -> BoolRef:
81-
"""Create the Z3 expression `other >= self`.
80+
# Thin wrapper for Z1 functions
81+
82+
def __add__(self, other: Any) -> BoolRef:
83+
"""Create the Z3 expression `self + other`.
8284
8385
:param any other: The object to compare against.
84-
:return: The Z3 expression `other >= self`.
86+
:return: The Z3 expression `self + other`.
8587
:rtype: BoolRef
8688
"""
87-
return self.z3.__ge__(_coerce(other))
89+
return self.z3.__add__(_coerce(other))
8890

89-
def __le__(self, other: Any) -> BoolRef:
90-
"""Create the Z3 expression `other <= self`.
91+
def __ge__(self, other: Any) -> BoolRef:
92+
"""Create the Z3 expression `self >= other`.
9193
9294
:param any other: The object to compare against.
93-
:return: The Z3 expression `other >= self`.
95+
:return: The Z3 expression `self >= other`.
9496
:rtype: BoolRef
9597
"""
96-
return self.z3.__le__(_coerce(other))
98+
return self.z3.__ge__(_coerce(other))
9799

98100
def __gt__(self, other: Any) -> BoolRef:
99-
"""Create the Z3 expression `other > self`.
101+
"""Create the Z3 expression `self > other`.
100102
101103
:param any other: The object to compare against.
102-
:return: The Z3 expression `other >= self`.
104+
:return: The Z3 expression `self > other`.
103105
:rtype: BoolRef
104106
"""
105107
return self.z3.__gt__(_coerce(other))
106108

109+
def __le__(self, other: Any) -> BoolRef:
110+
"""Create the Z3 expression `self <= other`.
111+
112+
:param any other: The object to compare against.
113+
:return: The Z3 expression `self <= other`.
114+
:rtype: BoolRef
115+
"""
116+
return self.z3.__le__(_coerce(other))
117+
107118
def __lt__(self, other: Any) -> BoolRef:
108-
"""Create the Z3 expression `other < self`.
119+
"""Create the Z3 expression `self < other`.
109120
110121
:param any other: The object to compare against.
111-
:return: The Z3 expression `other >= self`.
122+
:return: The Z3 expression `self < other`.
112123
:rtype: BoolRef
113124
"""
114125
return self.z3.__lt__(_coerce(other))
115126

127+
128+
def __mod__(self, other: Any) -> BoolRef:
129+
"""Create the Z3 expression `self % other`.
130+
131+
:param any other: The object to compare against.
132+
:return: The Z3 expression `self % other`.
133+
:rtype: BoolRef
134+
"""
135+
return self.z3.__mod__(_coerce(other))
136+
116137
def __mul__(self, other: Any) -> BoolRef:
117-
"""Create the Z3 expression `other * self`.
138+
"""Create the Z3 expression `self * other`.
118139
119140
:param any other: The object to compare against.
120-
:return: The Z3 expression `other >= self`.
141+
:return: The Z3 expression `self * other`.
121142
:rtype: BoolRef
122143
"""
123144
return self.z3.__mul__(_coerce(other))
124145

125-
def __sub__(self, other: Any) -> BoolRef:
126-
"""Create the Z3 expression `other * self`.
146+
def __ne__(self, other: Any) -> BoolRef:
147+
"""Create the Z3 expression `self != other`.
127148
128149
:param any other: The object to compare against.
129-
:return: The Z3 expression `other >= self`.
150+
:return: The Z3 expression `self != other`.
130151
:rtype: BoolRef
131152
"""
132-
return self.z3.__sub__(_coerce(other))
153+
return self.z3.__ne__(_coerce(other))
133154

134-
def __add__(self, other: Any) -> BoolRef:
135-
"""Create the Z3 expression `other * self`.
155+
def __neg__(self) -> BoolRef:
156+
"""Create the Z3 expression `-self`.
136157
137158
:param any other: The object to compare against.
138-
:return: The Z3 expression `other >= self`.
159+
:return: The Z3 expression `-self`.
139160
:rtype: BoolRef
140161
"""
141-
return self.z3.__add__(_coerce(other))
162+
return self.z3.__neg__()
163+
164+
def __pow__(self, other: Any) -> BoolRef:
165+
"""Create the Z3 expression `self ^ other`.
166+
167+
:param any other: The object to compare against.
168+
:return: The Z3 expression `self ^ other`.
169+
:rtype: BoolRef
170+
"""
171+
return self.z3.__pow__(_coerce(other))
172+
173+
def __sub__(self, other: Any) -> BoolRef:
174+
"""Create the Z3 expression `self - other`.
175+
176+
:param any other: The object to compare against.
177+
:return: The Z3 expression `self - other`.
178+
:rtype: BoolRef
179+
"""
180+
return self.z3.__sub__(_coerce(other))
142181

143182
def __truediv__(self, other: Any) -> BoolRef:
144-
"""Create the Z3 expression `other * self`.
183+
"""Create the Z3 expression `self / other`.
145184
146185
:param any other: The object to compare against.
147-
:return: The Z3 expression `other >= self`.
186+
:return: The Z3 expression `self / other`.
148187
:rtype: BoolRef
149188
"""
150189
return self.z3.__truediv__(_coerce(other))
151190

191+
# End thin wrapper
192+
152193
def cast(self, val: Any) -> T:
153194
"""Cast the supplied value to the datatype T of the variable.
154195

0 commit comments

Comments
 (0)