Skip to content

Commit ffbafa9

Browse files
agilotvkuncak
authored andcommitted
Floating point theory
1 parent 51a4487 commit ffbafa9

File tree

1 file changed

+319
-0
lines changed

1 file changed

+319
-0
lines changed
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
package smtlib
2+
package theories
3+
4+
import trees.Terms._
5+
import common._
6+
7+
import Operations._
8+
9+
/**
10+
* Floating point arithmetic theory as defined in the SMT-LIB standard
11+
* https://smt-lib.org/theories-FloatingPoint.shtml
12+
*/
13+
object FloatingPoint {
14+
15+
//-------
16+
// Sorts
17+
//-------
18+
19+
object RoundingModeSort {
20+
def apply(): Sort = {
21+
Sort(Identifier(SSymbol("RoundingMode")))
22+
}
23+
def unapply(sort: Sort): Boolean = sort match {
24+
case Sort(Identifier(SSymbol("RoundingMode"), Seq()), Seq()) => true
25+
case _ => false
26+
}
27+
}
28+
29+
export Reals.RealSort
30+
31+
// Bit vector sorts, indexed by vector size
32+
export FixedSizeBitVectors.BitVectorSort
33+
34+
35+
/**
36+
* Floating point sort, indexed by the length of the exponent and significand
37+
* components of the number.
38+
*/
39+
object FloatingPointSort {
40+
def apply(ebits: BigInt, sbits: BigInt): Sort = {
41+
require(ebits > 1)
42+
require(sbits > 1)
43+
Sort(Identifier(SSymbol("FloatingPoint"), Seq(SNumeral(ebits), SNumeral(sbits))))
44+
}
45+
def unapply(sort: Sort): Option[(BigInt, BigInt)] = sort match {
46+
case Sort(Identifier(SSymbol("FloatingPoint"), Seq(SNumeral(e), SNumeral(s))), Seq()) if e > 1 && s > 1 => Some(e, s)
47+
case _ => None
48+
}
49+
}
50+
51+
// Short name for common floating point sorts
52+
53+
object Float16 {
54+
def apply(): Sort = {
55+
Sort(Identifier(SSymbol("Float16")))
56+
}
57+
def unapply(sort: Sort): Boolean = sort match {
58+
case Sort(Identifier(SSymbol("Float16"), Seq()), Seq()) => true
59+
case _ => false
60+
}
61+
}
62+
63+
object Float32 {
64+
def apply(): Sort = {
65+
Sort(Identifier(SSymbol("Float32")))
66+
}
67+
def unapply(sort: Sort): Boolean = sort match {
68+
case Sort(Identifier(SSymbol("Float32"), Seq()), Seq()) => true
69+
case _ => false
70+
}
71+
}
72+
73+
object Float64 {
74+
def apply(): Sort = {
75+
Sort(Identifier(SSymbol("Float64")))
76+
}
77+
def unapply(sort: Sort): Boolean = sort match {
78+
case Sort(Identifier(SSymbol("Float64"), Seq()), Seq()) => true
79+
case _ => false
80+
}
81+
}
82+
83+
object Float128 {
84+
def apply(): Sort = {
85+
Sort(Identifier(SSymbol("Float128")))
86+
}
87+
def unapply(sort: Sort): Boolean = sort match {
88+
case Sort(Identifier(SSymbol("Float128"), Seq()), Seq()) => true
89+
case _ => false
90+
}
91+
}
92+
93+
//----------------
94+
// Rounding modes
95+
//----------------
96+
97+
// Constants for rounding modes, and their abbreviated version
98+
99+
object RoundNearestTiesToEven extends Operation0 { override val name = "roundNearestTiesToEven" }
100+
object RoundNearestTiesToAway extends Operation0 { override val name = "roundNearestTiesToAway" }
101+
object RoundTowardPositive extends Operation0 { override val name = "roundTowardPositive" }
102+
object RoundTowardNegative extends Operation0 { override val name = "roundTowardNegative" }
103+
object RoundTowardZero extends Operation0 { override val name = "roundTowardZero" }
104+
105+
lazy val RNE = RoundNearestTiesToEven
106+
lazy val RNA = RoundNearestTiesToAway
107+
lazy val RTP = RoundTowardPositive
108+
lazy val RTN = RoundTowardNegative
109+
lazy val RTZ = RoundTowardZero
110+
111+
//--------------------
112+
// Value constructors
113+
//--------------------
114+
115+
// Bitvector literals
116+
export FixedSizeBitVectors.BitVectorLit
117+
118+
/**
119+
* FP literals as bit string triples, with the leading bit for the significand
120+
* not represented (hidden bit)
121+
*/
122+
object FPLit extends Operation3 { override val name = "fp" }
123+
124+
// Plus and minus infinity
125+
126+
object PlusInfinity {
127+
def apply(ebits: BigInt, sbits: BigInt): Term = {
128+
require(ebits > 1)
129+
require(sbits > 1)
130+
QualifiedIdentifier(Identifier(SSymbol("+oo"), Seq(SNumeral(ebits), SNumeral(sbits))))
131+
}
132+
def unapply(sort: Term): Option[(BigInt, BigInt)] = sort match {
133+
case QualifiedIdentifier(Identifier(SSymbol("+oo"), Seq(SNumeral(e), SNumeral(s))), None) if e > 1 && s > 1 => Some(e, s)
134+
case _ => None
135+
}
136+
}
137+
138+
object MinusInfinity {
139+
def apply(ebits: BigInt, sbits: BigInt): Term = {
140+
require(ebits > 1)
141+
require(sbits > 1)
142+
QualifiedIdentifier(Identifier(SSymbol("-oo"), Seq(SNumeral(ebits), SNumeral(sbits))))
143+
}
144+
def unapply(sort: Term): Option[(BigInt, BigInt)] = sort match {
145+
case QualifiedIdentifier(Identifier(SSymbol("-oo"), Seq(SNumeral(e), SNumeral(s))), None) if e > 1 && s > 1 => Some(e, s)
146+
case _ => None
147+
}
148+
}
149+
150+
// Plus and minus zero
151+
152+
object PlusZero {
153+
def apply(ebits: BigInt, sbits: BigInt): Term = {
154+
require(ebits > 1)
155+
require(sbits > 1)
156+
QualifiedIdentifier(Identifier(SSymbol("+zero"), Seq(SNumeral(ebits), SNumeral(sbits))))
157+
}
158+
def unapply(sort: Term): Option[(BigInt, BigInt)] = sort match {
159+
case QualifiedIdentifier(Identifier(SSymbol("+zero"), Seq(SNumeral(e), SNumeral(s))), None) if e > 1 && s > 1 => Some(e, s)
160+
case _ => None
161+
}
162+
}
163+
164+
object MinusZero {
165+
def apply(ebits: BigInt, sbits: BigInt): Term = {
166+
require(ebits > 1)
167+
require(sbits > 1)
168+
QualifiedIdentifier(Identifier(SSymbol("-zero"), Seq(SNumeral(ebits), SNumeral(sbits))))
169+
}
170+
def unapply(sort: Term): Option[(BigInt, BigInt)] = sort match {
171+
case QualifiedIdentifier(Identifier(SSymbol("-zero"), Seq(SNumeral(e), SNumeral(s))), None) if e > 1 && s > 1 => Some(e, s)
172+
case _ => None
173+
}
174+
}
175+
176+
// Non-numbers
177+
178+
object NaN {
179+
def apply(ebits: BigInt, sbits: BigInt): Term = {
180+
require(ebits > 1)
181+
require(sbits > 1)
182+
QualifiedIdentifier(Identifier(SSymbol("NaN"), Seq(SNumeral(ebits), SNumeral(sbits))))
183+
}
184+
def unapply(sort: Term): Option[(BigInt, BigInt)] = sort match {
185+
case QualifiedIdentifier(Identifier(SSymbol("NaN"), Seq(SNumeral(e), SNumeral(s))), None) if e > 1 && s > 1 => Some(e, s)
186+
case _ => None
187+
}
188+
}
189+
190+
//-----------
191+
// Operators
192+
//-----------
193+
194+
// absolute value
195+
object Abs extends Operation1 { override val name = "fp.abs" }
196+
// negation (no rounding needed)
197+
object Neg extends Operation1 { override val name = "fp.neg" }
198+
// addition
199+
object Add extends Operation3 { override val name = "fp.add" }
200+
// subtraction
201+
object Sub extends Operation3 { override val name = "fp.sub" }
202+
// multiplication
203+
object Mul extends Operation3 { override val name = "fp.mul" }
204+
// division
205+
object Div extends Operation3 { override val name = "fp.div" }
206+
// fused multiplication and addition; (x * y) + z
207+
object FMA extends OperationN { override val name = "fp.fma"; override val numRequired = 4 }
208+
// square root
209+
object Sqrt extends Operation2 { override val name = "fp.sqrt" }
210+
// remainder: x - y * n, where n in Z is nearest to x/y
211+
object Rem extends Operation2 { override val name = "fp.rem" }
212+
// rounding to integral
213+
object RoundToIntegral extends Operation2 { override val name = "fp.roundToIntegral" }
214+
// minimum and maximum
215+
object Min extends Operation2 { override val name = "fp.min" }
216+
object Max extends Operation2 { override val name = "fp.max" }
217+
// comparison operators
218+
// Note that all comparisons evaluate to false if either argument is NaN
219+
object LessEquals extends Operation2 { override val name = "fp.leq" }
220+
object LessThan extends Operation2 { override val name = "fp.lt" }
221+
object GreaterEquals extends Operation2 { override val name = "fp.geq" }
222+
object GreaterThan extends Operation2 { override val name = "fp.gt" }
223+
// IEEE 754-2008 equality (as opposed to SMT-LIB =)
224+
object Eq extends Operation2 { override val name = "fp.eq" }
225+
// Classification of numbers
226+
object IsNormal extends Operation1 { override val name = "fp.isNormal" }
227+
object IsSubnormal extends Operation1 { override val name = "fp.isSubnormal" }
228+
object IsZero extends Operation1 { override val name = "fp.isZero" }
229+
object IsInfinite extends Operation1 { override val name = "fp.isInfinite" }
230+
object IsNaN extends Operation1 { override val name = "fp.isNaN" }
231+
object IsNegative extends Operation1 { override val name = "fp.isNegative" }
232+
object IsPositive extends Operation1 { override val name = "fp.isPositive" }
233+
234+
lazy val + = Add
235+
lazy val - = Sub
236+
lazy val * = Mul
237+
lazy val / = Div
238+
lazy val > = GreaterThan
239+
lazy val >= = GreaterEquals
240+
lazy val < = LessThan
241+
lazy val <= = LessEquals
242+
243+
//------------------------------
244+
// Conversions from other sorts
245+
//------------------------------
246+
247+
object ToFP {
248+
def apply(ebits: BigInt, sbits: BigInt, arg: Term): Term = {
249+
require(ebits > 1)
250+
require(sbits > 1)
251+
FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("to_fp"), Seq(SNumeral(ebits), SNumeral(sbits)))), Seq(arg))
252+
}
253+
254+
def apply(ebits: BigInt, sbits: BigInt, roundingMode: Term, arg: Term): Term = {
255+
require(ebits > 1)
256+
require(sbits > 1)
257+
FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("to_fp"), Seq(SNumeral(ebits), SNumeral(sbits)))), Seq(roundingMode, arg))
258+
}
259+
260+
def unapply(sort: Term): Option[(BigInt, BigInt, Seq[Term])] = sort match {
261+
case FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("to_fp"), Seq(SNumeral(e), SNumeral(s))), None), seq) if e > 1 && s > 1 && 1 <= seq.length && seq.length <= 2 =>
262+
Some((e, s, seq))
263+
case _ => None
264+
}
265+
}
266+
267+
object ToFPUnsigned {
268+
def apply(ebits: BigInt, sbits: BigInt, roundingMode: Term, bitvec: Term): Term = {
269+
require(ebits > 1)
270+
require(sbits > 1)
271+
FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("to_fp_unsigned"), Seq(SNumeral(ebits), SNumeral(sbits)))), Seq(roundingMode, bitvec))
272+
}
273+
274+
def unapply(sort: Term): Option[(BigInt, BigInt, Term, Term)] = sort match {
275+
case FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("to_fp_unsigned"), Seq(SNumeral(e), SNumeral(s))), None), Seq(roundingMode, bitvec)) if e > 1 && s > 1 =>
276+
Some((e, s, roundingMode, bitvec))
277+
case _ => None
278+
}
279+
}
280+
281+
//----------------------------
282+
// Conversions to other sorts
283+
//----------------------------
284+
285+
// to unsigned machine integer, represented as a bit vector
286+
object ToUnsignedBitVector {
287+
def apply(length: BigInt, roundingMode: Term, arg: Term): Term = {
288+
require(length > 0)
289+
FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("fp.to_ubv"), Seq(SNumeral(length))), None), Seq(roundingMode, arg))
290+
}
291+
292+
def unapply(sort: Term): Option[(BigInt, Term, Term)] = sort match {
293+
case FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("fp.to_ubv"), Seq(SNumeral(length))), None), Seq(roundingMode, arg) ) if length > 0 =>
294+
Some((length, roundingMode, arg))
295+
case _ => None
296+
}
297+
}
298+
// to signed machine integer, represented as a 2's complement bit vector
299+
object ToSignedBitVector {
300+
def apply(length: BigInt, roundingMode: Term, arg: Term): Term = {
301+
require(length > 0)
302+
FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("fp.to_sbv"), Seq(SNumeral(length))), None), Seq(roundingMode, arg))
303+
}
304+
305+
def unapply(sort: Term): Option[(BigInt, Term, Term)] = sort match {
306+
case FunctionApplication(QualifiedIdentifier(Identifier(SSymbol("fp.to_sbv"), Seq(SNumeral(length))), None), Seq(roundingMode, arg) ) if length > 0 =>
307+
Some((length, roundingMode, arg))
308+
case _ => None
309+
}
310+
}
311+
// to real
312+
object ToReal extends Operation1 { override val name = "fp.to_real" }
313+
314+
315+
316+
317+
318+
319+
}

0 commit comments

Comments
 (0)