Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/it/scala/inox/solvers/MinimizerSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2009-2018 EPFL, Lausanne */

package inox
package solvers
package unrolling

class MinimizerSuite extends SolvingTestSuite with DatastructureUtils {
import trees._
import dsl._
import SolverResponses._

override def configurations = Seq(Seq(optSelectedSolvers(Set("smt-z3-min"))))

override def optionsString(options: Options): String = {
"solvr=" + options.findOptionOrDefault(optSelectedSolvers).head
}

implicit val symbols: inox.trees.Symbols = NoSymbols

val program = inox.Program(inox.trees)(symbols)

test("automated minimization of n times n") { implicit ctx =>
val x = Variable.fresh("x", Int32Type())
val prop = GreaterEquals(Times(x, x), Int32Literal(10))

val factory = SolverFactory.optimizer(program, ctx)
val optimizer = factory.getNewSolver()
try {
optimizer.assertCnstr(Not(prop))
optimizer.check(Model) match {
case SatWithModel(model) =>
model.vars.get(x.toVal).get match {
case Int32Literal(c) => assert(c == 0)
}
case _ =>
fail("Expected sat-with-model")
}
} finally {
factory.reclaim(optimizer)
}
}
}
34 changes: 33 additions & 1 deletion src/main/scala/inox/solvers/SolverFactory.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ object SolverFactory {
"smt-cvc4" -> "CVC4 through SMT-LIB",
"smt-z3" -> "Z3 through SMT-LIB",
"smt-z3-opt" -> "Z3 optimizer through SMT-LIB",
"smt-z3-min" -> "Z3 minimizer through SMT-LIB",
"smt-z3:<exec>" -> "Z3 through SMT-LIB with custom executable name",
"princess" -> "Princess with inox unrolling"
)
Expand All @@ -82,6 +83,7 @@ object SolverFactory {
"smt-cvc4" -> (() => hasCVC4, Seq("nativez3", "smt-z3", "princess"), "'cvc4' binary"),
"smt-z3" -> (() => hasZ3, Seq("nativez3", "smt-cvc4", "princess"), "'z3' binary"),
"smt-z3-opt" -> (() => hasZ3, Seq("nativez3-opt"), "'z3' binary"),
"smt-z3-min" -> (() => hasZ3, Seq("nativez3-opt"), "'z3' binary"),
"princess" -> (() => true, Seq(), "Princess solver")
)

Expand Down Expand Up @@ -280,6 +282,36 @@ object SolverFactory {
}
})

case "smt-z3-min" => create(p)(finalName, {
val chooseEnc = ChooseEncoder(p)(enc)
val fullEnc = enc andThen chooseEnc
val theoryEnc = theories.Z3(fullEnc.targetProgram)
val progEnc = fullEnc andThen theoryEnc
val targetProg = progEnc.targetProgram
val targetSem = targetProg.getSemantics

() => new {
val program: p.type = p
val context = ctx
val encoder: enc.type = enc
} with UnrollingOptimizer with TimeoutSolver {
override protected val semantics = sem
override protected val chooses: chooseEnc.type = chooseEnc
override protected val theories: theoryEnc.type = theoryEnc
override protected lazy val fullEncoder = fullEnc
override protected lazy val programEncoder = progEnc
override protected lazy val targetProgram: targetProg.type = targetProg
override protected val targetSemantics = targetSem

protected val underlying = new {
val program: progEnc.targetProgram.type = progEnc.targetProgram
val context = ctx
} with smtlib.optimization.Z3Minimizer {
val semantics: program.Semantics = targetSem
}
}
})

case _ if finalName == "smt-z3" || finalName.startsWith("smt-z3:") => create(p)(finalName, {
val chooseEnc = ChooseEncoder(p)(enc)
val fullEnc = enc andThen chooseEnc
Expand Down Expand Up @@ -410,7 +442,7 @@ object SolverFactory {
(solversOpt getOrElse optSelectedSolvers.default).toSeq match {
case Seq() => throw FatalError("No selected solver")
case Seq(single) =>
val name = if (single.endsWith("-opt")) single else single + "-opt"
val name = if (single.endsWith("-opt") || single.endsWith("-min")) single else single + "-opt"
getFromName(name, force = solversOpt.isDefined)(p, ctx)(ProgramEncoder.empty(p))(p.getSemantics).asInstanceOf[SolverFactory {
val program: p.type
type S <: Optimizer with TimeoutSolver { val program: p.type }
Expand Down
95 changes: 95 additions & 0 deletions src/main/scala/inox/solvers/smtlib/optimization/Z3Minimizer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/* Copyright 2009-2018 EPFL, Lausanne */

package inox
package solvers
package smtlib
package optimization

trait Z3Minimizer extends Z3Optimizer {
import program._
import program.trees._
import program.symbols._
import exprOps.variablesOf

/**
* Gets the 'zero' literals for the given field,
* given the types of "parents" of the expr to avoid infinite recursion.
*/
private def getZeroLiterals(field: ValDef, parentTypes: Set[Type]): Seq[Expr] = field.tpe match {
case BVType(signed, size) => Seq(BVLiteral(signed, 0, size))
case adt @ ADTType(_, _) =>
if (parentTypes.contains(adt)) Seq()
else adt.lookupSort match {
case Some(tsort) => tsort.constructors.flatMap(_.fields).distinct.flatMap(f => getZeroLiterals(f, parentTypes + adt))
case None => Seq()
}
case IntegerType() => Seq(IntegerLiteral(0))
case _ => Seq()
}

/**
* Gets the "size" of the given type.
* We arbitrarily decide that a BV is "smaller" than an int which is "smaller" than an ADT;
* one could choose a policy to decide the exact size of a BV, but this then gets into questions such as
* "at which point does a BV become bigger than an int" which have no clear answer anyway.
*/
private def sizeOf(t: Type): Int = t match {
case BVType(_, size) => 1
case IntegerType() => 5
case adt @ ADTType(_, _) => 10
case _ => 0
}

/**
* Gets the expressions that should be minimized for the given expr,
* given the types of "parents" of the expr to avoid an infinite recursion.
*/
private def sizersOf(e: Expr, parentTypes: Set[Type]): Seq[Expr] = {
/**
* Gets the expressions that should be minimized for the given ADT constructors.
* This is not as simple as it first appears because of the split between the bitvector and integer worlds;
* one cannot merely return a big "sum of all fields" expression, since adding a bitvector and an integer makes no sense.
* Furthermore, ADT constructors may themselves include ADTs as fields.
* For instance, for ADT X with ctors X1(a: BV32) and X2(a: BV32, b: Int) and ADT Y with ctors Y1(a: BV64) and Y2(b: X),
* the minimizers of y: Y are ["if y is Y1 then y.a else BV64(0)", "if y is Y1 then BV32(0) else y.b.a", "if y is Y1 then Int(0) else if y.b is X1 then Int(0) else y.b.b"],
* plus one representing the overall size of the ADT to favor "smaller" constructors such as X1 over X2.
* (The concept of "smaller" is not well-defined either, e.g., one could reasonably argue that P(a: BV16, b: BV16) is either "smaller" or "larger" than Q(c: BV64)
* due to the size in bits vs number of fields; we go with the number of fields)
*/
def adtSizers(ctors: Seq[TypedADTConstructor]): Seq[Expr] = {
// zeroes must be of the same length as what sizer returns
def rec(ctors: Seq[TypedADTConstructor], zeroes: Seq[Expr], sizer: TypedADTConstructor => Seq[Expr]): Seq[Expr] = ctors match {
case Seq() => zeroes
// not strictly necessary but nice to not have an if whose else branch is unsatisfiable
case Seq(ctor) => sizer(ctor)
case Seq(ctor, tl @ _*) => sizer(ctor).zip(rec(tl, zeroes, sizer)).map{case (a,b) => IfExpr(IsConstructor(e, ctor.id), a, b)}
}
def fieldSizer(field: ValDef)(ctor: TypedADTConstructor): Seq[Expr] = {
if (ctor.fields.contains(field)) sizersOf(ADTSelector(e, field.id), parentTypes + field.tpe) else getZeroLiterals(field, parentTypes)
}
// BV64 for the size, let's not force the use of integers if it's not necessary
val adtSize = rec(ctors, Seq(BVLiteral(false, 0, 32)), c => Seq(BVLiteral(false, c.fields.map(f => sizeOf(f.tpe)).sum, 64)))
val fieldSizes = ctors.flatMap(_.fields).distinct.filter(f => !parentTypes.contains(f.tpe)).flatMap(f => rec(ctors, getZeroLiterals(f, parentTypes), fieldSizer(f)))
adtSize ++ fieldSizes
}

e.getType match {
case BVType(signed, size) => Seq(if (signed) IfExpr(GreaterEquals(e, BVLiteral(signed, 0, size)), e, UMinus(e)) else e)
case IntegerType() => Seq(IfExpr(GreaterEquals(e, IntegerLiteral(0)), e, UMinus(e)))
case adt @ ADTType(_, _) =>
if (parentTypes.contains(adt)) Seq()
else adt.lookupSort match {
case Some(tsort) => adtSizers(tsort.constructors)
case None => Seq()
}
case _ => Seq()
}
}

override def assertCnstr(expr: Expr): Unit = {
for (freeVar <- variablesOf(expr) ; toMinimize <- sizersOf(freeVar, Set())) {
minimize(toMinimize)
}
super.assertCnstr(expr)
}
}