Skip to content

Commit e48ccf7

Browse files
agilotvkuncak
authored andcommitted
Add sqrt and abs tokens for floating points
1 parent 1261e53 commit e48ccf7

File tree

7 files changed

+35
-5
lines changed

7 files changed

+35
-5
lines changed

src/main/scala/inox/ast/Deconstructors.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,16 @@ trait TreeDeconstructor {
327327
(NoIdentifiers, NoVariables, Seq(rm, t1, t2), NoTypes, NoFlags,
328328
(_, _, es, _, _) => t.FPDiv(es(0), es(1), es(2)))
329329
},
330+
classOf[s.FPAbs] -> { expr =>
331+
val s.FPAbs(e) = expr: @unchecked
332+
(NoIdentifiers, NoVariables, Seq(e), NoTypes, NoFlags,
333+
(_, _, es, _, _) => t.FPAbs(es(0)))
334+
},
335+
classOf[s.Sqrt] -> { expr =>
336+
val s.Sqrt(rm, e) = expr: @unchecked
337+
(NoIdentifiers, NoVariables, Seq(rm, e), NoTypes, NoFlags,
338+
(_, _, es, _, _) => t.Sqrt(es(0), es(1)))
339+
},
330340
classOf[s.FPCast] -> { expr =>
331341
val s.FPCast(eb, sb, rm, e) = expr: @unchecked
332342
(NoIdentifiers, NoVariables, Seq(rm, e), NoTypes, NoFlags,

src/main/scala/inox/ast/Expressions.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,15 @@ trait Expressions { self: Trees =>
668668
if getRoundingMode(rm).isTyped then getFPType(lhs, rhs) else Untyped
669669
}
670670

671+
sealed case class FPAbs(e: Expr) extends Expr with CachingTyped {
672+
override protected def computeType(using Symbols): Type = getFPType(e)
673+
}
674+
675+
sealed case class Sqrt(rm: Expr, e: Expr) extends Expr with CachingTyped {
676+
override protected def computeType(using Symbols): Type =
677+
if getRoundingMode(rm).isTyped then getFPType(e) else Untyped
678+
}
679+
671680
sealed case class FPCast(newExponent: Int, newSignificand: Int, rm: Expr, expr: Expr) extends Expr with CachingTyped {
672681
override protected def computeType(using Symbols): Type =
673682
if getRoundingMode(rm).isTyped &&

src/main/scala/inox/ast/Printers.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ trait Printer {
287287
case FPSub(_, e1, e2) => p"$e1 - $e2"
288288
case FPMul(_, e1, e2) => p"$e1 * $e2"
289289
case FPDiv(_, e1, e2) => p"$e1 / $e2"
290+
case FPAbs(e) => p"abs($e)"
291+
case Sqrt(_, e) => p"sqrt($e)"
290292
case FPIsZero(e) => p"$e == 0"
291293
case FPIsNaN(e) => p"$e.isNaN"
292294
case FPIsInfinite(e) => p"$e.isInfinite"

src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,12 @@ trait SMTLIBParser {
218218
case FloatingPoint.Mul(rm, t1, t2) => fromSMTUnifyType(t1, t2, otpe)((e1, e2) => FPMul(fromSMT(rm, RoundingMode), e1, e2))
219219
case FloatingPoint.Div(rm, t1, t2) => fromSMTUnifyType(t1, t2, otpe)((e1, e2) => FPDiv(fromSMT(rm, RoundingMode), e1, e2))
220220
case FloatingPoint.Neg(t) => UMinus(fromSMT(t, otpe))
221-
case FloatingPoint.GreaterThan(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(GreaterThan.apply)
222-
case FloatingPoint.LessThan(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(LessThan.apply)
223-
case FloatingPoint.GreaterEquals(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(GreaterEquals.apply)
224-
case FloatingPoint.LessEquals(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(LessEquals.apply)
221+
case FloatingPoint.Abs(t) => FPAbs(fromSMT(t, otpe))
222+
case FloatingPoint.Sqrt(rm, t) => Sqrt(fromSMT(rm, RoundingMode), fromSMT(t, otpe))
223+
case FloatingPoint.GreaterThan(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(GreaterThan.apply)
224+
case FloatingPoint.LessThan(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(LessThan.apply)
225+
case FloatingPoint.GreaterEquals(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(GreaterEquals.apply)
226+
case FloatingPoint.LessEquals(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(LessEquals.apply)
225227
case FloatingPoint.ToFP(newExp, newSig, seq) =>
226228
val (rm, arg) = seq match {
227229
case Seq(t1, t2) => (fromSMT(t1, Some(RoundingMode)), fromSMT(t2, None))

src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,8 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
512512
case FPSub(rm, a, b) => FloatingPoint.Sub(toSMT(rm), toSMT(a), toSMT(b))
513513
case FPMul(rm, a, b) => FloatingPoint.Mul(toSMT(rm), toSMT(a), toSMT(b))
514514
case FPDiv(rm, a, b) => FloatingPoint.Div(toSMT(rm), toSMT(a), toSMT(b))
515+
case FPAbs(e) => FloatingPoint.Abs(toSMT(e))
516+
case Sqrt(rm, e) => FloatingPoint.Sqrt(toSMT(rm), toSMT(e))
515517
case FPCast(ne, ns, rm, e) => FloatingPoint.ToFP(ne, ns, toSMT(rm), toSMT(e))
516518
case FPIsInfinite(e) => FloatingPoint.IsInfinite(toSMT(e))
517519
case FPIsZero(e) => FloatingPoint.IsZero(toSMT(e))

src/main/scala/inox/utils/Serialization.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ class InoxSerializer(val trees: ast.Trees, serializeProducts: Boolean = false) e
476476
* The `Serializer[_]` identifiers in this mapping range from 10 to 105
477477
* (ignoring special identifiers that are smaller than 10).
478478
*
479-
* NEXT ID: 109
479+
* NEXT ID: 127
480480
*/
481481
protected def classSerializers: Map[Class[?], Serializer[?]] = Map(
482482
// Inox Expressions
@@ -559,6 +559,8 @@ class InoxSerializer(val trees: ast.Trees, serializeProducts: Boolean = false) e
559559
classSerializer[FPMul] (112),
560560
classSerializer[FPDiv] (113),
561561
classSerializer[FPCast] (114),
562+
classSerializer[FPAbs] (125),
563+
classSerializer[Sqrt] (126),
562564

563565
classSerializer[RoundTowardZero.type] (115),
564566
classSerializer[RoundTowardPositive.type] (116),

src/test/scala/inox/solvers/SemanticsSuite.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,7 @@ class SemanticsSuite extends AnyFunSuite {
426426

427427
for (i <- floatValues) {
428428
check(s, UMinus(Float32Literal(i)), Float32Literal(-i))
429+
check(s, FPAbs(Float32Literal(i)), Float32Literal(Math.abs(i)))
429430
}
430431

431432
for (i <- doubleValues; j <- doubleValues) {
@@ -437,6 +438,8 @@ class SemanticsSuite extends AnyFunSuite {
437438

438439
for (i <- doubleValues) {
439440
check(s, UMinus(Float64Literal(i)), Float64Literal(-i))
441+
check(s, Sqrt(RoundNearestTiesToEven, Float64Literal(i)), Float64Literal(Math.sqrt(i)))
442+
check(s, FPAbs(Float64Literal(i)), Float64Literal(Math.abs(i)))
440443
}
441444

442445
}

0 commit comments

Comments
 (0)