Skip to content

Commit 1261e53

Browse files
agilotvkuncak
authored andcommitted
Add predicates for floating points
1 parent 4f56b0b commit 1261e53

File tree

7 files changed

+120
-2
lines changed

7 files changed

+120
-2
lines changed

src/main/scala/inox/ast/Deconstructors.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,31 @@ trait TreeDeconstructor {
332332
(NoIdentifiers, NoVariables, Seq(rm, e), NoTypes, NoFlags,
333333
(_, _, es, _, _) => t.FPCast(eb, sb, es(0), es(1)))
334334
},
335+
classOf[s.FPIsZero] -> { expr =>
336+
val s.FPIsZero(e) = expr: @unchecked
337+
(NoIdentifiers, NoVariables, Seq(e), NoTypes, NoFlags,
338+
(_, _, es, _, _) => t.FPIsZero(es(0)))
339+
},
340+
classOf[s.FPIsNaN] -> { expr =>
341+
val s.FPIsNaN(e) = expr: @unchecked
342+
(NoIdentifiers, NoVariables, Seq(e), NoTypes, NoFlags,
343+
(_, _, es, _, _) => t.FPIsNaN(es(0)))
344+
},
345+
classOf[s.FPIsPositive] -> { expr =>
346+
val s.FPIsPositive(e) = expr: @unchecked
347+
(NoIdentifiers, NoVariables, Seq(e), NoTypes, NoFlags,
348+
(_, _, es, _, _) => t.FPIsPositive(es(0)))
349+
},
350+
classOf[s.FPIsNegative] -> { expr =>
351+
val s.FPIsNegative(e) = expr: @unchecked
352+
(NoIdentifiers, NoVariables, Seq(e), NoTypes, NoFlags,
353+
(_, _, es, _, _) => t.FPIsNegative(es(0)))
354+
},
355+
classOf[s.FPIsInfinite] -> { expr =>
356+
val s.FPIsInfinite(e) = expr: @unchecked
357+
(NoIdentifiers, NoVariables, Seq(e), NoTypes, NoFlags,
358+
(_, _, es, _, _) => t.FPIsInfinite(es(0)))
359+
},
335360
classOf[s.RoundTowardZero.type] -> { expr =>
336361
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
337362
(_, _, _, _, _) => t.RoundTowardZero)

src/main/scala/inox/ast/Expressions.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,31 @@ trait Expressions { self: Trees =>
679679
else Untyped
680680
}
681681

682+
sealed case class FPIsZero(e: Expr) extends Expr with CachingTyped {
683+
override protected def computeType(using Symbols): Type =
684+
if getFPType(e).isTyped then BooleanType() else Untyped
685+
}
686+
687+
sealed case class FPIsInfinite(e: Expr) extends Expr with CachingTyped {
688+
override protected def computeType(using Symbols): Type =
689+
if getFPType(e).isTyped then BooleanType() else Untyped
690+
}
691+
692+
sealed case class FPIsNaN(e: Expr) extends Expr with CachingTyped {
693+
override protected def computeType(using Symbols): Type =
694+
if getFPType(e).isTyped then BooleanType() else Untyped
695+
}
696+
697+
sealed case class FPIsNegative(e: Expr) extends Expr with CachingTyped {
698+
override protected def computeType(using Symbols): Type =
699+
if getFPType(e).isTyped then BooleanType() else Untyped
700+
}
701+
702+
sealed case class FPIsPositive(e: Expr) extends Expr with CachingTyped {
703+
override protected def computeType(using Symbols): Type =
704+
if getFPType(e).isTyped then BooleanType() else Untyped
705+
}
706+
682707

683708
/* Rounding modes */
684709
object RoundTowardZero extends Expr with CachingTyped {

src/main/scala/inox/ast/Printers.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ trait Printer {
287287
case FPSub(_, e1, e2) => p"$e1 - $e2"
288288
case FPMul(_, e1, e2) => p"$e1 * $e2"
289289
case FPDiv(_, e1, e2) => p"$e1 / $e2"
290+
case FPIsZero(e) => p"$e == 0"
291+
case FPIsNaN(e) => p"$e.isNaN"
292+
case FPIsInfinite(e) => p"$e.isInfinite"
293+
case FPIsNegative(e) => p"$e.isNegative"
294+
case FPIsPositive(e) => p"$e.isPositive"
290295

291296
case RoundTowardZero => p"RTZ"
292297
case RoundTowardPositive => p"RTP"

src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,11 @@ trait SMTLIBParser {
228228
case Seq(t) => (RoundNearestTiesToEven, fromSMT(t, None))
229229
}
230230
FPCast(newExp.toInt, newSig.toInt, rm, arg)
231+
case FloatingPoint.IsNaN(t) => FPIsNaN(fromSMT(t, None))
232+
case FloatingPoint.IsZero(t) => FPIsZero(fromSMT(t, None))
233+
case FloatingPoint.IsPositive(t) => FPIsPositive(fromSMT(t, None))
234+
case FloatingPoint.IsNegative(t) => FPIsNegative(fromSMT(t, None))
235+
case FloatingPoint.IsInfinite(t) => FPIsInfinite(fromSMT(t, None))
231236

232237
case FloatingPoint.RoundTowardZero() => RoundTowardZero
233238
case FloatingPoint.RoundTowardPositive() => RoundTowardPositive

src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,11 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
513513
case FPMul(rm, a, b) => FloatingPoint.Mul(toSMT(rm), toSMT(a), toSMT(b))
514514
case FPDiv(rm, a, b) => FloatingPoint.Div(toSMT(rm), toSMT(a), toSMT(b))
515515
case FPCast(ne, ns, rm, e) => FloatingPoint.ToFP(ne, ns, toSMT(rm), toSMT(e))
516+
case FPIsInfinite(e) => FloatingPoint.IsInfinite(toSMT(e))
517+
case FPIsZero(e) => FloatingPoint.IsZero(toSMT(e))
518+
case FPIsNaN(e) => FloatingPoint.IsNaN(toSMT(e))
519+
case FPIsPositive(e) => FloatingPoint.IsPositive(toSMT(e))
520+
case FPIsNegative(e) => FloatingPoint.IsNegative(toSMT(e))
516521

517522
case RoundTowardZero => FloatingPoint.RoundTowardZero()
518523
case RoundTowardNegative => FloatingPoint.RoundTowardNegative()

src/main/scala/inox/utils/Serialization.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,12 @@ class InoxSerializer(val trees: ast.Trees, serializeProducts: Boolean = false) e
566566
classSerializer[RoundNearestTiesToEven.type] (118),
567567
classSerializer[RoundNearestTiesToAway.type] (119),
568568

569+
classSerializer[FPIsZero] (120),
570+
classSerializer[FPIsInfinite] (121),
571+
classSerializer[FPIsNaN] (122),
572+
classSerializer[FPIsPositive] (123),
573+
classSerializer[FPIsNegative] (124),
574+
569575
// Inox Types
570576
classSerializer[Untyped.type] (75),
571577
classSerializer[BooleanType] (76),

src/test/scala/inox/solvers/SemanticsSuite.scala

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,9 @@ class SemanticsSuite extends AnyFunSuite {
395395
check(s, LessThan(FractionLiteral(4, 2), FractionLiteral(7, 1)), BooleanLiteral(true))
396396
}
397397

398-
val floatValues: Seq[Float] = Seq(0f, -0f, 0.1f, -6.7f, Float.NaN, Float.MinValue, Float.MinValue, Float.PositiveInfinity, Float.NegativeInfinity)
399-
val doubleValues: Seq[Double] = Seq(0d, -0d, 0.1d, -6.7d, Double.NaN, Double.MinValue, Double.MinValue, Double.PositiveInfinity, Double.NegativeInfinity)
398+
import scala.collection.immutable.HashSet
399+
val floatValues: Set[Float] = HashSet(0f, -0f, 0.1f, -6.7f, Float.NaN, Float.MinValue, Float.MinValue, Float.PositiveInfinity, Float.NegativeInfinity)
400+
val doubleValues: Set[Double] = HashSet(0d, -0d, 0.1d, -6.7d, Double.NaN, Double.MinValue, Double.MinValue, Double.PositiveInfinity, Double.NegativeInfinity)
400401

401402

402403
test("Floating point literals", filterSolvers(_, princess = true)) { ctx =>
@@ -453,6 +454,28 @@ class SemanticsSuite extends AnyFunSuite {
453454
check(s, LessThan(Float32Literal(i), Float32Literal(j)), BooleanLiteral(i < j))
454455
}
455456

457+
for (i <- floatValues.excl(0)) {
458+
check(s, FPIsNegative(Float32Literal(i)), BooleanLiteral(i < 0))
459+
}
460+
check(s, FPIsNegative(Float32Literal(-0)), BooleanLiteral(true))
461+
check(s, FPIsNegative(Float32Literal(0)), BooleanLiteral(false))
462+
463+
for (i <- floatValues.excl(0)) {
464+
check(s, FPIsPositive(Float32Literal(i)), BooleanLiteral(i > 0))
465+
}
466+
check(s, FPIsPositive(Float32Literal(0)), BooleanLiteral(true))
467+
check(s, FPIsPositive(Float32Literal(-0)), BooleanLiteral(false))
468+
469+
for (i <- floatValues) {
470+
check(s, FPIsInfinite(Float32Literal(i)), BooleanLiteral(i == Float.PositiveInfinity || i == Float.NegativeInfinity))
471+
}
472+
473+
for (i <- floatValues) {
474+
check(s, FPIsZero(Float32Literal(i)), BooleanLiteral(i == 0))
475+
}
476+
477+
check(s, FPIsNaN(Float32Literal(Float.NaN)), BooleanLiteral(true))
478+
456479
for (i <- doubleValues; j <- doubleValues) {
457480
check(s, FPEquals(Float64Literal(i), Float64Literal(j)), BooleanLiteral(Float64Literal(i).semEquals(Float64Literal(j))))
458481
check(s, FPEquals(Float64Literal(i), Float64Literal(j)), BooleanLiteral(i == j))
@@ -462,6 +485,30 @@ class SemanticsSuite extends AnyFunSuite {
462485
check(s, LessThan(Float64Literal(i), Float64Literal(j)), BooleanLiteral(i < j))
463486
}
464487

488+
for (i <- doubleValues.excl(0)) {
489+
check(s, FPIsNegative(Float64Literal(i)), BooleanLiteral(i < 0))
490+
}
491+
check(s, FPIsNegative(Float64Literal(-0)), BooleanLiteral(true))
492+
check(s, FPIsNegative(Float64Literal(0)), BooleanLiteral(false))
493+
494+
for (i <- doubleValues.excl(0)) {
495+
check(s, FPIsPositive(Float64Literal(i)), BooleanLiteral(i > 0))
496+
}
497+
check(s, FPIsPositive(Float64Literal(0)), BooleanLiteral(true))
498+
check(s, FPIsPositive(Float64Literal(-0)), BooleanLiteral(false))
499+
500+
for (i <- doubleValues) {
501+
check(s, FPIsInfinite(Float64Literal(i)), BooleanLiteral(i == Double.PositiveInfinity || i == Double.NegativeInfinity))
502+
}
503+
504+
for (i <- doubleValues) {
505+
check(s, FPIsZero(Float64Literal(i)), BooleanLiteral(i == 0))
506+
}
507+
508+
check(s, FPIsNaN(Float64Literal(Double.NaN)), BooleanLiteral(true))
509+
510+
511+
465512
}
466513
test("Let") { ctx =>
467514
val s = solver(ctx)

0 commit comments

Comments
 (0)