Skip to content

Commit c888ccf

Browse files
agilotvkuncak
authored andcommitted
Add FP arithmetic rounding modes
1 parent a0d8ec8 commit c888ccf

File tree

6 files changed

+168
-14
lines changed

6 files changed

+168
-14
lines changed

src/main/scala/inox/ast/Deconstructors.scala

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,51 @@ trait TreeDeconstructor {
307307
(NoIdentifiers, NoVariables, Seq(t1, t2), NoTypes, NoFlags,
308308
(_, _, es, _, _) => t.FPEquals(es(0), es(1)))
309309
},
310+
classOf[s.FPAdd] -> { expr =>
311+
val s.FPAdd(rm, t1, t2) = expr: @unchecked
312+
(NoIdentifiers, NoVariables, Seq(rm, t1, t2), NoTypes, NoFlags,
313+
(_, _, es, _, _) => t.FPAdd(es(0), es(1), es(2)))
314+
},
315+
classOf[s.FPSub] -> { expr =>
316+
val s.FPSub(rm, t1, t2) = expr: @unchecked
317+
(NoIdentifiers, NoVariables, Seq(rm, t1, t2), NoTypes, NoFlags,
318+
(_, _, es, _, _) => t.FPSub(es(0), es(1), es(2)))
319+
},
320+
classOf[s.FPMul] -> { expr =>
321+
val s.FPMul(rm, t1, t2) = expr: @unchecked
322+
(NoIdentifiers, NoVariables, Seq(rm, t1, t2), NoTypes, NoFlags,
323+
(_, _, es, _, _) => t.FPMul(es(0), es(1), es(2)))
324+
},
325+
classOf[s.FPDiv] -> { expr =>
326+
val s.FPDiv(rm, t1, t2) = expr: @unchecked
327+
(NoIdentifiers, NoVariables, Seq(rm, t1, t2), NoTypes, NoFlags,
328+
(_, _, es, _, _) => t.FPDiv(es(0), es(1), es(2)))
329+
},
330+
classOf[s.FPCast] -> { expr =>
331+
val s.FPCast(eb, sb, rm, e) = expr: @unchecked
332+
(NoIdentifiers, NoVariables, Seq(rm, e), NoTypes, NoFlags,
333+
(_, _, es, _, _) => t.FPCast(eb, sb, es(0), es(1)))
334+
},
335+
classOf[s.RoundTowardZero.type] -> { expr =>
336+
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
337+
(_, _, _, _, _) => t.RoundTowardZero)
338+
},
339+
classOf[s.RoundTowardNegative.type] -> { expr =>
340+
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
341+
(_, _, _, _, _) => t.RoundTowardNegative)
342+
},
343+
classOf[s.RoundTowardPositive.type] -> { expr =>
344+
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
345+
(_, _, _, _, _) => t.RoundTowardPositive)
346+
},
347+
classOf[s.RoundNearestTiesToAway.type] -> { expr =>
348+
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
349+
(_, _, _, _, _) => t.RoundNearestTiesToAway)
350+
},
351+
classOf[s.RoundNearestTiesToEven.type] -> { expr =>
352+
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
353+
(_, _, _, _, _) => t.RoundNearestTiesToEven)
354+
},
310355
classOf[s.Tuple] -> { expr =>
311356
val s.Tuple(args) = expr: @unchecked
312357
(NoIdentifiers, NoVariables, args, NoTypes, NoFlags,
@@ -473,7 +518,10 @@ trait TreeDeconstructor {
473518
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
474519
(_, _, _, _, _) => t.FPType(exponent, significand))
475520
},
476-
521+
classOf[s.RoundingMode.type] -> { tpe =>
522+
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
523+
(_, _, _, _, _) => t.RoundingMode)
524+
},
477525
// @nv: can't use `s.Untyped.getClass` as it is not yet created at this point
478526
scala.reflect.classTag[s.Untyped.type].runtimeClass -> { _ =>
479527
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags, (_, _, _, _, _) => t.Untyped)

src/main/scala/inox/ast/Expressions.scala

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -641,13 +641,66 @@ trait Expressions { self: Trees =>
641641
}
642642
}
643643

644-
/* FP operaions */
644+
/* FP operations */
645645

646646
sealed case class FPEquals(lhs: Expr, rhs: Expr) extends Expr with CachingTyped {
647647
override protected def computeType(using Symbols): Type =
648648
if getFPType(lhs, rhs).isTyped then BooleanType() else Untyped
649649
}
650650

651+
sealed case class FPAdd(rm: Expr, lhs: Expr, rhs: Expr) extends Expr with CachingTyped {
652+
override protected def computeType(using Symbols): Type =
653+
if getRoundingMode(rm).isTyped then getFPType(lhs, rhs) else Untyped
654+
}
655+
656+
sealed case class FPSub(rm: Expr, lhs: Expr, rhs: Expr) extends Expr with CachingTyped {
657+
override protected def computeType(using Symbols): Type =
658+
if getRoundingMode(rm).isTyped then getFPType(lhs, rhs) else Untyped
659+
}
660+
661+
sealed case class FPMul(rm: Expr, lhs: Expr, rhs: Expr) extends Expr with CachingTyped {
662+
override protected def computeType(using Symbols): Type =
663+
if getRoundingMode(rm).isTyped then getFPType(lhs, rhs) else Untyped
664+
}
665+
666+
sealed case class FPDiv(rm: Expr, lhs: Expr, rhs: Expr) extends Expr with CachingTyped {
667+
override protected def computeType(using Symbols): Type =
668+
if getRoundingMode(rm).isTyped then getFPType(lhs, rhs) else Untyped
669+
}
670+
671+
sealed case class FPCast(newExponent: Int, newSignificand: Int, rm: Expr, expr: Expr) extends Expr with CachingTyped {
672+
override protected def computeType(using Symbols): Type =
673+
if getRoundingMode(rm).isTyped &&
674+
(getFPType(expr).isTyped ||
675+
getBVType(expr).isTyped ||
676+
getRealType(expr).isTyped)
677+
then
678+
FPType(newExponent, newSignificand)
679+
else Untyped
680+
}
681+
682+
683+
/* Rounding modes */
684+
object RoundTowardZero extends Expr with CachingTyped {
685+
override protected def computeType(using Symbols): Type = RoundingMode
686+
}
687+
688+
object RoundTowardPositive extends Expr with CachingTyped {
689+
override protected def computeType(using Symbols): Type = RoundingMode
690+
}
691+
692+
object RoundTowardNegative extends Expr with CachingTyped {
693+
override protected def computeType(using Symbols): Type = RoundingMode
694+
}
695+
696+
object RoundNearestTiesToEven extends Expr with CachingTyped {
697+
override protected def computeType(using Symbols): Type = RoundingMode
698+
}
699+
700+
object RoundNearestTiesToAway extends Expr with CachingTyped {
701+
override protected def computeType(using Symbols): Type = RoundingMode
702+
}
703+
651704

652705
/* Tuple operations */
653706

src/main/scala/inox/ast/Printers.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,19 @@ trait Printer {
280280
case BVSignedToUnsigned(e) => p"$e.toUnsigned"
281281

282282
case FPEquals(l, r) => p"$l === $r"
283+
case FPCast(8, 24, _, e) => p"$e.toFloat"
284+
case FPCast(11, 53, _, e) => p"$e.toDouble"
285+
case FPCast(eb, sb, _, e) => p"$e.toBV($eb, $sb)"
286+
case FPAdd(_, e1, e2) => p"$e1 + $e2"
287+
case FPSub(_, e1, e2) => p"$e1 - $e2"
288+
case FPMul(_, e1, e2) => p"$e1 * $e2"
289+
case FPDiv(_, e1, e2) => p"$e1 / $e2"
290+
291+
case RoundTowardZero => p"RTZ"
292+
case RoundTowardPositive => p"RTP"
293+
case RoundTowardNegative => p"RTN"
294+
case RoundNearestTiesToEven => p"RNE"
295+
case RoundNearestTiesToAway => p"RNA"
283296

284297
case fs @ FiniteSet(rs, _) => p"Set(${rs})"
285298
case fs @ FiniteBag(rs, _) => p"Bag(${rs.toSeq})"
@@ -357,6 +370,10 @@ trait Printer {
357370
case CharType() => p"Char"
358371
case BooleanType() => p"Boolean"
359372
case StringType() => p"String"
373+
case Float32Type() => p"Float"
374+
case Float64Type() => p"Double"
375+
case FPType(eb, sb) => p"Float${eb + sb}"
376+
case RoundingMode => p"RoundingMode"
360377
case SetType(bt) => p"Set[$bt]"
361378
case BagType(bt) => p"Bag[$bt]"
362379
case MapType(ft, tt) => p"Map[$ft, $tt]"

src/main/scala/inox/ast/Types.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ trait Types { self: Trees =>
8787
object Float32Type extends FPTypeExtractor(8, 24)
8888
object Float64Type extends FPTypeExtractor(11, 53)
8989

90+
object RoundingMode extends Type
91+
9092
sealed case class TypeParameter(id: Identifier, flags: Seq[Flag]) extends Type {
9193
def freshen = TypeParameter(id.freshen, flags).copiedFrom(this)
9294

@@ -247,6 +249,9 @@ trait Types { self: Trees =>
247249
case _ => Untyped
248250
}
249251

252+
protected def getRoundingMode(tpe: Typed, tpes: Typed*)(using Symbols): Type =
253+
checkAllTypes(tpe +: tpes, RoundingMode, RoundingMode)
254+
250255
protected final def getCharType(tpe: Typed, tpes: Typed*)(using Symbols): Type =
251256
checkAllTypes(tpe +: tpes, CharType(), CharType())
252257

src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -97,17 +97,14 @@ trait SMTLIBParser {
9797
case _ => BVLiteral(true, n, size.intValue)
9898
}
9999

100-
case FloatingPoint.FPLit(sign, exponent, significand) => otpe match {
101-
case Some(FPType(eb, sb)) =>
102-
(fromSMT(sign, Some(BVType(true, 1))),
103-
fromSMT(exponent, Some(BVType(true, eb))),
104-
fromSMT(significand, Some(BVType(true, sb - 1)))) match {
105-
case (BVLiteral(true, bitset1, 1), BVLiteral(true, bitset2, `eb`), BVLiteral(true, bitset3, rem)) if rem == sb - 1 =>
106-
FPLiteral(eb, sb, bitset1.map(_ + eb + sb - 1) ++ bitset2.map(_ + sb - 1) ++ bitset3)
107-
case _ => throw new MissformedSMTException(term, "FP Lit has inconsistent components")
108-
}
109-
case _ => throw new MissformedSMTException(term, "FP Lit is not of type Float")
110-
}
100+
case FloatingPoint.FPLit(sign, exponent, significand) =>
101+
(fromSMT(sign, Some(BVType(true, 1))),
102+
fromSMT(exponent, None),
103+
fromSMT(significand, None)) match {
104+
case (BVLiteral(true, bitset1, 1), BVLiteral(true, bitset2, eb), BVLiteral(true, bitset3, sbminus1)) =>
105+
FPLiteral(eb, sbminus1 + 1, bitset1.map(_ + eb + sbminus1) ++ bitset2.map(_ + sbminus1) ++ bitset3)
106+
case _ => throw new MissformedSMTException(term, "FP Lit has inconsistent components")
107+
}
111108
case FloatingPoint.PlusZero(exponent, significand) => FPLiteral.plusZero(exponent.toInt, significand.toInt)
112109
case FloatingPoint.MinusZero(exponent, significand) => FPLiteral.minusZero(exponent.toInt, significand.toInt)
113110
case FloatingPoint.NaN(exponent, significand) => FPLiteral.NaN(exponent.toInt, significand.toInt)
@@ -215,8 +212,28 @@ trait SMTLIBParser {
215212
case _ => true
216213
}, (i + 1).bigInteger.intValueExact))
217214

218-
case FloatingPoint.Eq(e1, e2) => fromSMTUnifyType(e1, e2, None)(FPEquals.apply)
215+
case FloatingPoint.Eq(t1, t2) => fromSMTUnifyType(t1, t2, None)(FPEquals.apply)
216+
case FloatingPoint.Add(rm, t1, t2) => fromSMTUnifyType(t1, t2, otpe)((e1, e2) => FPAdd(fromSMT(rm, RoundingMode), e1, e2))
217+
case FloatingPoint.Sub(rm, t1, t2) => fromSMTUnifyType(t1, t2, otpe)((e1, e2) => FPSub(fromSMT(rm, RoundingMode), e1, e2))
218+
case FloatingPoint.Mul(rm, t1, t2) => fromSMTUnifyType(t1, t2, otpe)((e1, e2) => FPMul(fromSMT(rm, RoundingMode), e1, e2))
219+
case FloatingPoint.Div(rm, t1, t2) => fromSMTUnifyType(t1, t2, otpe)((e1, e2) => FPDiv(fromSMT(rm, RoundingMode), e1, e2))
220+
case FloatingPoint.Neg(t) => UMinus(fromSMT(t, otpe))
221+
case FloatingPoint.GreaterThan(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(GreaterThan.apply)
222+
case FloatingPoint.LessThan(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(LessThan.apply)
223+
case FloatingPoint.GreaterEquals(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(GreaterEquals.apply)
224+
case FloatingPoint.LessEquals(t1, t2) => fromSMTUnifyType(t1, t2, Some(BooleanType()))(LessEquals.apply)
225+
case FloatingPoint.ToFP(newExp, newSig, seq) =>
226+
val (rm, arg) = seq match {
227+
case Seq(t1, t2) => (fromSMT(t1, Some(RoundingMode)), fromSMT(t2, None))
228+
case Seq(t) => (RoundNearestTiesToEven, fromSMT(t, None))
229+
}
230+
FPCast(newExp.toInt, newSig.toInt, rm, arg)
219231

232+
case FloatingPoint.RoundTowardZero() => RoundTowardZero
233+
case FloatingPoint.RoundTowardPositive() => RoundTowardPositive
234+
case FloatingPoint.RoundTowardNegative() => RoundTowardNegative
235+
case FloatingPoint.RoundNearestTiesToEven() => RoundNearestTiesToEven
236+
case FloatingPoint.RoundNearestTiesToAway() => RoundNearestTiesToAway
220237

221238
case ArraysEx.Select(e1, e2) => otpe match {
222239
case Some(tpe) =>
@@ -260,6 +277,8 @@ trait SMTLIBParser {
260277
case Sort(SimpleIdentifier(SSymbol("Real")), Seq()) => RealType()
261278
case Sort(SimpleIdentifier(SSymbol("String")), Seq()) => StringType()
262279
case Sort(SimpleIdentifier(SSymbol("Array")), Seq(from, to)) => MapType(fromSMT(from), fromSMT(to))
280+
case FloatingPoint.FloatingPointSort(eb, sb) => FPType(eb.toInt, sb.toInt)
281+
case FloatingPoint.RoundingModeSort() => RoundingMode
263282
case _ => throw new MissformedSMTException(sort, "unexpected sort: " + sort)
264283
}
265284
}

src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
152152
case RealType() => Reals.RealSort()
153153
case BVType(_,l) => FixedSizeBitVectors.BitVectorSort(l)
154154
case FPType(e, s) => FloatingPoint.FloatingPointSort(e, s)
155+
case RoundingMode => FloatingPoint.RoundingModeSort()
155156
case CharType() => FixedSizeBitVectors.BitVectorSort(16)
156157
case StringType() => Strings.StringSort()
157158

@@ -507,6 +508,17 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
507508
case BVSignedToUnsigned(e) => toSMT(e)
508509

509510
case FPEquals(a, b) => FloatingPoint.Eq(toSMT(a), toSMT(b))
511+
case FPAdd(rm, a, b) => FloatingPoint.Add(toSMT(rm), toSMT(a), toSMT(b))
512+
case FPSub(rm, a, b) => FloatingPoint.Sub(toSMT(rm), toSMT(a), toSMT(b))
513+
case FPMul(rm, a, b) => FloatingPoint.Mul(toSMT(rm), toSMT(a), toSMT(b))
514+
case FPDiv(rm, a, b) => FloatingPoint.Div(toSMT(rm), toSMT(a), toSMT(b))
515+
case FPCast(ne, ns, rm, e) => FloatingPoint.ToFP(ne, ns, toSMT(rm), toSMT(e))
516+
517+
case RoundTowardZero => FloatingPoint.RoundTowardZero()
518+
case RoundTowardNegative => FloatingPoint.RoundTowardNegative()
519+
case RoundTowardPositive => FloatingPoint.RoundTowardPositive()
520+
case RoundNearestTiesToAway => FloatingPoint.RoundNearestTiesToAway()
521+
case RoundNearestTiesToEven => FloatingPoint.RoundNearestTiesToEven()
510522

511523
case And(sub) => SmtLibConstructors.and(sub.map(toSMT))
512524
case Or(sub) => SmtLibConstructors.or(sub.map(toSMT))

0 commit comments

Comments
 (0)