Skip to content

Commit 539e794

Browse files
agilotvkuncak
authored andcommitted
Add floating point arithmetic and tests
1 parent 1b952e8 commit 539e794

File tree

7 files changed

+444
-5
lines changed

7 files changed

+444
-5
lines changed

src/main/scala/inox/ast/Deconstructors.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@ trait TreeDeconstructor {
102102
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
103103
(_, _, _, _, _) => t.BVLiteral(signed, bits, size))
104104
},
105+
classOf[s.FPLiteral] -> { expr =>
106+
val s.FPLiteral(exponent, significand, bits) = expr: @unchecked
107+
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
108+
(_, _, _, _, _) => t.FPLiteral(exponent, significand, bits))
109+
},
105110
classOf[s.IntegerLiteral] -> { expr =>
106111
val s.IntegerLiteral(i) = expr: @unchecked
107112
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
@@ -297,6 +302,11 @@ trait TreeDeconstructor {
297302
(NoIdentifiers, NoVariables, Seq(e), NoTypes, NoFlags,
298303
(_, _, es, _, _) => t.BVSignedToUnsigned(es(0)))
299304
},
305+
classOf[s.FPEquals] -> { expr =>
306+
val s.FPEquals(t1, t2) = expr: @unchecked
307+
(NoIdentifiers, NoVariables, Seq(t1, t2), NoTypes, NoFlags,
308+
(_, _, es, _, _) => t.FPEquals(es(0), es(1)))
309+
},
300310
classOf[s.Tuple] -> { expr =>
301311
val s.Tuple(args) = expr: @unchecked
302312
(NoIdentifiers, NoVariables, args, NoTypes, NoFlags,
@@ -458,6 +468,11 @@ trait TreeDeconstructor {
458468
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
459469
(_, _, _, _, _) => t.BVType(signed, size))
460470
},
471+
classOf[s.FPType] -> { tpe =>
472+
val s.FPType(exponent, significand) = tpe: @unchecked
473+
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
474+
(_, _, _, _, _) => t.FPType(exponent, significand))
475+
},
461476

462477
// @nv: can't use `s.Untyped.getClass` as it is not yet created at this point
463478
scala.reflect.classTag[s.Untyped.type].runtimeClass -> { _ =>

src/main/scala/inox/ast/Expressions.scala

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,18 +233,42 @@ trait Expressions { self: Trees =>
233233
/** $encodingof a floating point literal */
234234
sealed case class FPLiteral(exponent: Int, significand: Int, value: BitSet) extends Literal[BitSet] {
235235
override def getType(using Symbols) = FPType(exponent, significand)
236+
def isNegative: Boolean = !isNaN && value(exponent + significand)
237+
def isPositive: Boolean = !isNaN && !isNegative
238+
def isZero: Boolean = !Range(1, significand + exponent).exists(value)
239+
def isNumber: Boolean = !Range(significand, significand + exponent).forall(value)
240+
def isNaN: Boolean = !isNumber && Range(1, significand).exists(value)
241+
def isInfinite: Boolean = !isNumber && !isNaN
236242
def toBV: BVLiteral = BVLiteral(true, value, exponent + significand)
243+
244+
def strictEquals(obj: Any): Boolean = obj match {
245+
case lit @ FPLiteral(e2, s2, v2) => exponent == e2 && significand == s2 && value == v2
246+
case _ => false
247+
}
248+
249+
/** Semantic equality for FP */
250+
def semEquals(obj: Any): Boolean = obj match
251+
case lit @ FPLiteral(e2, s2, v2) =>
252+
!isNaN && !lit.isNaN && ((isZero && lit.isZero) || strictEquals(obj))
253+
case _ => strictEquals(obj)
254+
255+
override def equals(obj: Any): Boolean = strictEquals(obj)
237256
}
238257

239258
object FPLiteral {
240-
def fromBV(exponent: Int, significand: Int, bv: BVLiteral) = FPLiteral(exponent, significand, bv.value)
259+
def fromBV(exponent: Int, significand: Int, bv: BVLiteral): FPLiteral = FPLiteral(exponent, significand, bv.value)
260+
def plusZero(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet.empty)
261+
def minusZero(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet(exponent + significand))
262+
def NaN(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet(Range(significand - 1, exponent + significand)*))
263+
def minusInfinity(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet(Range(significand, exponent + significand + 1)*))
264+
def plusInfinity(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet(Range(significand, exponent + significand)*))
241265
}
242266

243267
object Float32Literal {
244268
def apply(value: Float): FPLiteral = FPLiteral.fromBV(8, 24, Int32Literal(java.lang.Float.floatToIntBits(value)))
245269

246270
def unapply(e: Expr): Option[Float] = e match {
247-
case f @ FPLiteral(8, 24, b) if b.maxOption.getOrElse(-1) < 32 =>
271+
case f @ FPLiteral(8, 24, b) if b.maxOption.getOrElse(-1) <= 32 =>
248272
f.toBV match {
249273
case Int32Literal(i) => Some(java.lang.Float.intBitsToFloat(i))
250274
case _ => None
@@ -257,7 +281,7 @@ trait Expressions { self: Trees =>
257281
def apply(value: Double): FPLiteral = FPLiteral.fromBV(11, 53, Int64Literal(java.lang.Double.doubleToLongBits(value)))
258282

259283
def unapply(e: Expr): Option[Double] = e match {
260-
case f @ FPLiteral(11, 53, b) if b.maxOption.getOrElse(-1) < 64 =>
284+
case f @ FPLiteral(11, 53, b) if b.maxOption.getOrElse(-1) <= 64 =>
261285
f.toBV match {
262286
case Int64Literal(i) => Some(java.lang.Double.longBitsToDouble(i))
263287
case _ => None
@@ -617,6 +641,13 @@ trait Expressions { self: Trees =>
617641
}
618642
}
619643

644+
/* FP operaions */
645+
646+
sealed case class FPEquals(lhs: Expr, rhs: Expr) extends Expr with CachingTyped {
647+
override protected def computeType(using Symbols): Type =
648+
if getFPType(lhs, rhs).isTyped then BooleanType() else Untyped
649+
}
650+
620651

621652
/* Tuple operations */
622653

src/main/scala/inox/ast/Printers.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ trait Printer {
173173
case Int32Literal(v) => p"$v"
174174
case Int64Literal(v) => p"$v"
175175
case BVLiteral(_, bits, size) => p"x${(size to 1 by -1).map(i => if (bits(i)) "1" else "0").mkString("")}"
176+
case Float32Literal(f) => p"$f"
177+
case Float64Literal(f) => p"$f"
178+
case FPLiteral(exponent, significand, bits) =>
179+
p"x${(exponent + significand to 1 by -1).map(i => if (bits(i)) "1" else "0").mkString("")}"
176180
case IntegerLiteral(v) => p"$v"
177181
case FractionLiteral(n, d) =>
178182
if (d == 1) p"$n"
@@ -275,6 +279,8 @@ trait Printer {
275279
case BVUnsignedToSigned(e) => p"$e.toSigned"
276280
case BVSignedToUnsigned(e) => p"$e.toUnsigned"
277281

282+
case FPEquals(l, r) => p"$l === $r"
283+
278284
case fs @ FiniteSet(rs, _) => p"Set(${rs})"
279285
case fs @ FiniteBag(rs, _) => p"Bag(${rs.toSeq})"
280286
case fm @ FiniteMap(rs, dflt, _, _) =>

src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ trait SMTLIBParser {
9797
case _ => BVLiteral(true, n, size.intValue)
9898
}
9999

100+
case FloatingPoint.FPLit(sign, exponent, significand) => otpe match {
101+
case Some(FPType(eb, sb)) =>
102+
(fromSMT(sign, Some(BVType(true, 1))),
103+
fromSMT(exponent, Some(BVType(true, eb))),
104+
fromSMT(significand, Some(BVType(true, sb - 1)))) match {
105+
case (BVLiteral(true, bitset1, 1), BVLiteral(true, bitset2, `eb`), BVLiteral(true, bitset3, rem)) if rem == sb - 1 =>
106+
FPLiteral(eb, sb, bitset1.map(_ + eb + sb - 1) ++ bitset2.map(_ + sb - 1) ++ bitset3)
107+
case _ => throw new MissformedSMTException(term, "FP Lit has inconsistent components")
108+
}
109+
case _ => throw new MissformedSMTException(term, "FP Lit is not of type Float")
110+
}
111+
case FloatingPoint.PlusZero(exponent, significand) => FPLiteral.plusZero(exponent.toInt, significand.toInt)
112+
case FloatingPoint.MinusZero(exponent, significand) => FPLiteral.minusZero(exponent.toInt, significand.toInt)
113+
case FloatingPoint.NaN(exponent, significand) => FPLiteral.NaN(exponent.toInt, significand.toInt)
114+
case FloatingPoint.PlusInfinity(exponent, significand) => FPLiteral.plusInfinity(exponent.toInt, significand.toInt)
115+
case FloatingPoint.MinusInfinity(exponent, significand) => FPLiteral.minusInfinity(exponent.toInt, significand.toInt)
116+
117+
100118
case SDecimal(value) =>
101119
exprOps.normalizeFraction(FractionLiteral(
102120
value.bigDecimal.movePointRight(value.scale).toBigInteger,
@@ -196,6 +214,9 @@ trait SMTLIBParser {
196214
case Some(BVType(signed, _)) => signed
197215
case _ => true
198216
}, (i + 1).bigInteger.intValueExact))
217+
218+
case FloatingPoint.Eq(e1, e2) => fromSMTUnifyType(e1, e2, None)(FPEquals.apply)
219+
199220

200221
case ArraysEx.Select(e1, e2) => otpe match {
201222
case Some(tpe) =>

src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
151151
case IntegerType() => Ints.IntSort()
152152
case RealType() => Reals.RealSort()
153153
case BVType(_,l) => FixedSizeBitVectors.BitVectorSort(l)
154+
case FPType(e, s) => FloatingPoint.FloatingPointSort(e, s)
154155
case CharType() => FixedSizeBitVectors.BitVectorSort(16)
155156
case StringType() => Strings.StringSort()
156157

@@ -277,6 +278,14 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
277278

278279
case IntegerLiteral(i) => intToTerm(i)
279280
case BVLiteral(_, bits, size) => FixedSizeBitVectors.BitVectorLit(List.range(1, size + 1).map(i => bits(size + 1 - i)))
281+
case FPLiteral(e, s, bits) =>
282+
val size = e + s
283+
FloatingPoint.FPLit(
284+
FixedSizeBitVectors.BitVectorLit(List.range(1, 2).map(i => bits(size + 1 - i))),
285+
FixedSizeBitVectors.BitVectorLit(List.range(2, e + 2).map(i => bits(size + 1 - i))),
286+
FixedSizeBitVectors.BitVectorLit(List.range(e + 2, size + 1).map(i => bits(size + 1 - i)))
287+
)
288+
280289
case FractionLiteral(n, d) => Reals.Div(realToTerm(n), realToTerm(d))
281290
case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromShort(c.toShort))
282291
case BooleanLiteral(v) => Core.BoolConst(v)
@@ -372,35 +381,40 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
372381

373382
case UMinus(u) => u.getType match {
374383
case BVType(_,_) => FixedSizeBitVectors.Neg(toSMT(u))
384+
case FPType(_, _) => FloatingPoint.Neg(toSMT(u))
375385
case IntegerType() => Ints.Neg(toSMT(u))
376386
case RealType() => Reals.Neg(toSMT(u))
377387
}
378388

379389
case Equals(a, b) => Core.Equals(toSMT(a), toSMT(b))
380390
case Implies(a, b) => Core.Implies(toSMT(a), toSMT(b))
381-
case pl @ Plus(a, _) =>
391+
case pl @ Plus(a, b) =>
382392
val rec = flattenPlus(pl).map(toSMT)
383393
a.getType match {
384394
case BVType(_,_) => FixedSizeBitVectors.Add(rec)
395+
case FPType(_, _) => FloatingPoint.Add(FloatingPoint.RNE(), toSMT(a), toSMT(b))
385396
case IntegerType() => Ints.Add(rec)
386397
case RealType() => Reals.Add(rec)
387398
}
388399
case Minus(a, b) => a.getType match {
389400
case BVType(_,_) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b))
401+
case FPType(_,_) => FloatingPoint.Sub(FloatingPoint.RNE(), toSMT(a), toSMT(b))
390402
case IntegerType() => Ints.Sub(toSMT(a), toSMT(b))
391403
case RealType() => Reals.Sub(toSMT(a), toSMT(b))
392404
}
393-
case tms @ Times(a, _) =>
405+
case tms @ Times(a, b) =>
394406
val rec = flattenTimes(tms).map(toSMT)
395407
a.getType match {
396408
case BVType(_,_) => FixedSizeBitVectors.Mul(rec)
409+
case FPType(_,_) => FloatingPoint.Mul(FloatingPoint.RNE(), toSMT(a), toSMT(b))
397410
case IntegerType() => Ints.Mul(rec)
398411
case RealType() => Reals.Mul(rec)
399412
}
400413

401414
case Division(a, b) => a.getType match {
402415
case BVType(true, _) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b))
403416
case BVType(false, _) => FixedSizeBitVectors.UDiv(toSMT(a), toSMT(b))
417+
case FPType(_,_) => FloatingPoint.Div(FloatingPoint.RNE(), toSMT(a), toSMT(b))
404418
case IntegerType() =>
405419
val ar = toSMT(a)
406420
val br = toSMT(b)
@@ -415,6 +429,7 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
415429
case Remainder(a, b) => a.getType match {
416430
case BVType(true, _) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b))
417431
case BVType(false, _) => FixedSizeBitVectors.URem(toSMT(a), toSMT(b))
432+
case FPType(_, _) => FloatingPoint.Rem(toSMT(a), toSMT(b))
418433
case IntegerType() =>
419434
val q = toSMT(Division(a, b))
420435
Ints.Sub(toSMT(a), Ints.Mul(toSMT(b), q))
@@ -440,27 +455,31 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
440455
case LessThan(a, b) => a.getType match {
441456
case BVType(true, _) => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b))
442457
case BVType(false, _) => FixedSizeBitVectors.ULessThan(toSMT(a), toSMT(b))
458+
case FPType(_,_) => FloatingPoint.LessThan(toSMT(a), toSMT(b))
443459
case IntegerType() => Ints.LessThan(toSMT(a), toSMT(b))
444460
case RealType() => Reals.LessThan(toSMT(a), toSMT(b))
445461
case CharType() => FixedSizeBitVectors.ULessThan(toSMT(a), toSMT(b))
446462
}
447463
case LessEquals(a, b) => a.getType match {
448464
case BVType(true, _) => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b))
449465
case BVType(false, _) => FixedSizeBitVectors.ULessEquals(toSMT(a), toSMT(b))
466+
case FPType(_,_) => FloatingPoint.LessEquals(toSMT(a), toSMT(b))
450467
case IntegerType() => Ints.LessEquals(toSMT(a), toSMT(b))
451468
case RealType() => Reals.LessEquals(toSMT(a), toSMT(b))
452469
case CharType() => FixedSizeBitVectors.ULessEquals(toSMT(a), toSMT(b))
453470
}
454471
case GreaterThan(a, b) => a.getType match {
455472
case BVType(true, _) => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b))
456473
case BVType(false, _) => FixedSizeBitVectors.UGreaterThan(toSMT(a), toSMT(b))
474+
case FPType(_,_) => FloatingPoint.GreaterThan(toSMT(a), toSMT(b))
457475
case IntegerType() => Ints.GreaterThan(toSMT(a), toSMT(b))
458476
case RealType() => Reals.GreaterThan(toSMT(a), toSMT(b))
459477
case CharType() => FixedSizeBitVectors.UGreaterThan(toSMT(a), toSMT(b))
460478
}
461479
case GreaterEquals(a, b) => a.getType match {
462480
case BVType(true, _) => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b))
463481
case BVType(false, _) => FixedSizeBitVectors.UGreaterEquals(toSMT(a), toSMT(b))
482+
case FPType(_,_) => FloatingPoint.GreaterEquals(toSMT(a), toSMT(b))
464483
case IntegerType() => Ints.GreaterEquals(toSMT(a), toSMT(b))
465484
case RealType() => Reals.GreaterEquals(toSMT(a), toSMT(b))
466485
case CharType() => FixedSizeBitVectors.UGreaterEquals(toSMT(a), toSMT(b))
@@ -487,6 +506,8 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
487506
case BVUnsignedToSigned(e) => toSMT(e)
488507
case BVSignedToUnsigned(e) => toSMT(e)
489508

509+
case FPEquals(a, b) => FloatingPoint.Eq(toSMT(a), toSMT(b))
510+
490511
case And(sub) => SmtLibConstructors.and(sub.map(toSMT))
491512
case Or(sub) => SmtLibConstructors.or(sub.map(toSMT))
492513
case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze))

0 commit comments

Comments
 (0)