Skip to content

Commit 7752948

Browse files
authored
Merge pull request #147 from jad-hamza/merge-calls
Implement @samarion's fix for mergeCalls code explosion
2 parents 22de8d6 + 016234b commit 7752948

File tree

3 files changed

+90
-190
lines changed

3 files changed

+90
-190
lines changed

build.sbt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ script := {
8585
|
8686
|SCALACLASSPATH=$cp
8787
|
88-
|java -Xmx2G -Xms512M -Xss64M -classpath "$${SCALACLASSPATH}" -Dscala.usejavacp=true inox.Main $$@ 2>&1
88+
|java -Xmx2G -Xms512M -Xss64M -classpath "$${SCALACLASSPATH}" -Dscala.usejavacp=true inox.Main "$$@" 2>&1
8989
|""".stripMargin)
9090
file.setExecutable(true)
9191
} catch {

src/main/scala/inox/ast/SymbolOps.scala

Lines changed: 1 addition & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,188 +1028,6 @@ trait SymbolOps { self: TypeOps =>
10281028
case _ => None
10291029
} (expr)
10301030

1031-
private def mergeFunctions(expr: Expr)(implicit opts: PurityOptions): Expr = {
1032-
type Bindings = Seq[(ValDef, Expr)]
1033-
implicit class BindingsWrapper(bindings: Bindings) {
1034-
def merge(that: Bindings): Bindings = (bindings ++ that).distinct
1035-
def wrap(that: Expr): Expr = bindings.foldRight(that) { case ((vd, e), b) =>
1036-
val freshVd = vd.freshen
1037-
val fb = replaceFromSymbols(Map(vd -> freshVd.toVariable), b)
1038-
let(freshVd, e, fb)
1039-
}
1040-
}
1041-
1042-
def liftCalls(e: Expr): Expr = {
1043-
def rec(e: Expr): (Expr, Bindings) = e match {
1044-
case Let(i, fi @ FunctionInvocation(id, tps, args), body) =>
1045-
val (recArgs, argsBindings) = args.map(rec).unzip
1046-
val (recBody, bodyBindings) = rec(body)
1047-
val recBindings = argsBindings.flatten ++ bodyBindings
1048-
(recBody, ((argsBindings.flatten :+ (i -> FunctionInvocation(id, tps, recArgs).copiedFrom(fi))) ++ recBindings).distinct)
1049-
1050-
case fi @ FunctionInvocation(id, tps, args) =>
1051-
val v = Variable.fresh("call", fi.tfd.getType, true)
1052-
val (recArgs, recBindings) = args.map(rec).unzip
1053-
(v, recBindings.flatten.distinct :+ (v.toVal -> FunctionInvocation(id, tps, recArgs).copiedFrom(fi)))
1054-
1055-
case Let(i, v, b) =>
1056-
val (vr, vBindings) = rec(v)
1057-
val (br, bBindings) = rec(b)
1058-
(br, vBindings merge ((i -> vr) +: bBindings))
1059-
1060-
case Assume(pred, body) =>
1061-
val (recPred, predBindings) = rec(pred)
1062-
val (recBody, bodyBindings) = rec(body)
1063-
(Assume(recPred, bodyBindings wrap recBody), predBindings)
1064-
1065-
case IfExpr(cond, thenn, elze) =>
1066-
val (recCond, condBindings) = rec(cond)
1067-
val (recThen, thenBindings) = rec(thenn)
1068-
val (recElse, elseBindings) = rec(elze)
1069-
(IfExpr(recCond, thenBindings wrap recThen, elseBindings wrap recElse), condBindings)
1070-
1071-
case And(e +: es) =>
1072-
val (recE, eBindings) = rec(e)
1073-
val newEs = es.map { e =>
1074-
val (recE, eBindings) = rec(e)
1075-
eBindings wrap recE
1076-
}
1077-
(And(recE +: newEs), eBindings)
1078-
1079-
case Or(e +: es) =>
1080-
val (recE, eBindings) = rec(e)
1081-
val newEs = es.map { e =>
1082-
val (recE, eBindings) = rec(e)
1083-
eBindings wrap recE
1084-
}
1085-
(Or(recE +: newEs), eBindings)
1086-
1087-
case Implies(lhs, rhs) =>
1088-
val (recLhs, lhsBindings) = rec(lhs)
1089-
val (recRhs, rhsBindings) = rec(rhs)
1090-
(Implies(recLhs, rhsBindings wrap recRhs), lhsBindings)
1091-
1092-
case v: Variable => (v, Seq.empty)
1093-
1094-
case ex @ VariableExtractor(vs) if vs.nonEmpty =>
1095-
val Operator(subs, recons) = ex
1096-
val recSubs = subs.map(rec)
1097-
(recons(recSubs.map { case (e, bindings) => bindings wrap e }), Seq.empty)
1098-
1099-
case Lambda(Seq(), body) =>
1100-
val (recBody, bodyBindings) = rec(body)
1101-
(Lambda(Seq(), bodyBindings wrap recBody), Seq.empty)
1102-
1103-
case Operator(es, recons) =>
1104-
val (recEs, esBindings) = es.map(rec).unzip
1105-
(recons(recEs), esBindings.flatten.distinct)
1106-
}
1107-
1108-
val (newE, bindings) = rec(e)
1109-
bindings wrap newE
1110-
}
1111-
1112-
def mergeCalls(e: Expr): Expr = {
1113-
def evCalls(e: Expr): Map[TypedFunDef, Set[(Path, Seq[Expr])]] = {
1114-
1115-
var inLambda = false
1116-
var pathFis: Seq[(Path, FunctionInvocation)] = Seq.empty
1117-
transformWithPC(e, false /* ignore calls within types */)((e, path, op) => e match {
1118-
case l: Lambda =>
1119-
val old = inLambda
1120-
inLambda = true
1121-
val nl = op.sup(l, path)
1122-
inLambda = old
1123-
nl
1124-
1125-
case fi: FunctionInvocation =>
1126-
def freeVars(elements: Seq[Path.Element]): Set[Variable] = {
1127-
val path = Path(elements)
1128-
path.freeVariables ++ variablesOf(fi) -- path.bound.map(_.toVariable)
1129-
}
1130-
1131-
val elements = path.elements.foldRight(Some(Seq[Path.Element]()): Option[Seq[Path.Element]]) {
1132-
case (_, None) => None
1133-
case (Path.OpenBound(vd), Some(elements)) =>
1134-
if (freeVars(elements) contains vd.toVariable) None
1135-
else Some(elements) // No need to keep open bounds as we'll flatten to a condition
1136-
case (cb @ Path.CloseBound(vd, e), Some(elements)) =>
1137-
if (exists { case fi: FunctionInvocation => true case _ => false }(e)) None
1138-
else if (freeVars(elements) contains vd.toVariable) Some(cb +: elements)
1139-
else Some(elements)
1140-
case (c @ Path.Condition(cond), Some(elements)) =>
1141-
if (exists { case fi: FunctionInvocation => true case _ => false }(cond)) None
1142-
else Some(c +: elements)
1143-
}
1144-
1145-
if ((!inLambda || isPure(fi)) && elements.isDefined) {
1146-
pathFis :+= Path(elements.get) -> fi
1147-
}
1148-
op.sup(fi, path)
1149-
1150-
case _ =>
1151-
op.sup(e, path)
1152-
})
1153-
1154-
pathFis.groupBy(_._2.tfd).mapValues(_.map(p => (p._1, p._2.args)).toSet)
1155-
}
1156-
1157-
def replace(path: Path, oldE: Expr, newE: Expr, body: Expr): Expr =
1158-
transformWithPC(body, false /* ignore calls within types */)((e, env, op) => {
1159-
if ((path.bindings.toSet subsetOf env.bindings.toSet) &&
1160-
(path.bound.toSet subsetOf env.bound.toSet) &&
1161-
(path.conditions == env.conditions) && e == oldE) {
1162-
newE
1163-
} else {
1164-
op.sup(e, env)
1165-
}
1166-
})
1167-
1168-
postMap {
1169-
case IfExpr(cond, thenn, elze) =>
1170-
val condVar = Variable.fresh("cond", BooleanType(), true)
1171-
val condPath = Path(condVar)
1172-
1173-
def rec(bindings: Bindings, thenn: Expr, elze: Expr): (Bindings, Expr, Expr) = {
1174-
val thenCalls = evCalls(thenn)
1175-
val elseCalls = evCalls(elze)
1176-
1177-
(thenCalls.keySet & elseCalls.keySet).headOption match {
1178-
case Some(tfd) =>
1179-
val (pathThen, argsThen) = thenCalls(tfd).toSeq.sortBy(_._1.elements.size).head
1180-
val (pathElse, argsElse) = elseCalls(tfd).toSeq.sortBy(_._1.elements.size).head
1181-
1182-
val v = Variable.fresh("res", tfd.getType, true)
1183-
val condThen = Variable.fresh("condThen", BooleanType(), true)
1184-
1185-
val result = IfExpr(Or(condThen, freshenLocals((condPath.negate merge pathElse).toClause)),
1186-
tfd.applied((argsThen zip argsElse).map { case (argThen, argElse) =>
1187-
ifExpr(condThen, pathThen.bindings wrap argThen, pathElse.bindings wrap argElse)
1188-
}), Choose(Variable.fresh("res", tfd.getType).toVal, BooleanLiteral(true)))
1189-
1190-
val newBindings = bindings ++ Seq(
1191-
condThen.toVal -> freshenLocals((condPath merge pathThen).toClause),
1192-
v.toVal -> result
1193-
)
1194-
1195-
val newThen = replace(pathThen, tfd.applied(argsThen), v, thenn)
1196-
val newElse = replace(pathElse, tfd.applied(argsElse), v, elze)
1197-
rec(newBindings, newThen, newElse)
1198-
1199-
case None => (bindings, thenn, elze)
1200-
}
1201-
}
1202-
1203-
val (bindings, newThen, newElse) = rec(Seq.empty, thenn, elze)
1204-
Some(((condVar.toVal -> cond) +: bindings) wrap ifExpr(condVar, newThen, newElse))
1205-
1206-
case _ => None
1207-
} (e)
1208-
}
1209-
1210-
mergeCalls(liftCalls(expr))
1211-
}
1212-
12131031
private[inox] def simplifyFormula(e: Expr)(implicit ctx: Context, sem: symbols.Semantics): Expr = {
12141032
implicit val simpOpts = SimplificationOptions(ctx)
12151033
implicit val purityOpts = PurityOptions(ctx)
@@ -1219,8 +1037,7 @@ trait SymbolOps { self: TypeOps =>
12191037
((e: Expr) => simplifyGround(e)) compose
12201038
((e: Expr) => simplifyExpr(e)) compose
12211039
((e: Expr) => simplifyForalls(e)) compose
1222-
((e: Expr) => simplifyAssumptions(e)) compose
1223-
((e: Expr) => mergeFunctions(e))
1040+
((e: Expr) => simplifyAssumptions(e))
12241041
simp(e)
12251042
} else {
12261043
e

src/main/scala/inox/solvers/unrolling/TemplateGenerator.scala

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,81 @@ trait TemplateGenerator { self: Templates =>
7575
tmplClauses + (pathVar -> p)
7676
}
7777

78+
def mergeCalls(pathVar: Variable, condVar: Variable, substMap: Map[Variable, Encoded],
79+
thenClauses: TemplateClauses, elseClauses: TemplateClauses): TemplateClauses = {
80+
val builder = new Builder(pathVar, substMap)
81+
builder ++= thenClauses
82+
builder ++= elseClauses
83+
84+
// Clear all guardedExprs in builder since we're going to transform them by merging calls.
85+
// The transformed guardedExprs will be added to builder at the end of the function.
86+
builder.guardedExprs = Map.empty
87+
88+
def collectCalls(expr: Expr): Set[FunctionInvocation] =
89+
exprOps.collect { case fi: FunctionInvocation => Set(fi) case _ => Set.empty[FunctionInvocation] }(expr)
90+
def countCalls(expr: Expr): Int =
91+
exprOps.count { case fi: FunctionInvocation => 1 case _ => 0}(expr)
92+
def replaceCall(call: FunctionInvocation, newExpr: Expr)(e: Expr): Expr =
93+
exprOps.replace(Map(call -> newExpr), e)
94+
95+
def getCalls(guardedExprs: Map[Variable, Seq[Expr]]): Map[TypedFunDef, Seq[(FunctionInvocation, Set[Variable])]] =
96+
(for { (b, es) <- guardedExprs.toSeq; e <- es; fi <- collectCalls(e) } yield (b -> fi))
97+
.groupBy(_._2)
98+
.mapValues(_.map(_._1).toSet)
99+
.toSeq
100+
.groupBy(_._1.tfd)
101+
.mapValues(_.toList.distinct.sortBy(p => countCalls(p._1))) // place inner calls first
102+
.toMap
103+
104+
var thenGuarded = thenClauses._4
105+
var elseGuarded = elseClauses._4
106+
107+
val thenCalls = getCalls(thenGuarded)
108+
val elseCalls = getCalls(elseGuarded)
109+
110+
// We sort common function calls in order to merge nested calls first.
111+
var toMerge: Seq[((FunctionInvocation, Set[Variable]), (FunctionInvocation, Set[Variable]))] =
112+
(thenCalls.keySet & elseCalls.keySet)
113+
.flatMap(tfd => thenCalls(tfd) zip elseCalls(tfd))
114+
.toSeq
115+
.sortBy(p => countCalls(p._1._1) + countCalls(p._2._1))
116+
117+
while (toMerge.nonEmpty) {
118+
val ((thenCall, thenBlockers), (elseCall, elseBlockers)) = toMerge.head
119+
toMerge = toMerge.tail
120+
121+
val newExpr: Variable = Variable.fresh("call", thenCall.tfd.getType, true)
122+
builder.storeExpr(newExpr)
123+
124+
val replaceThen = replaceCall(thenCall, newExpr) _
125+
val replaceElse = replaceCall(elseCall, newExpr) _
126+
127+
thenGuarded = thenGuarded.mapValues(_.map(replaceThen))
128+
elseGuarded = elseGuarded.mapValues(_.map(replaceElse))
129+
toMerge = toMerge.map(p => (
130+
(replaceThen(p._1._1).asInstanceOf[FunctionInvocation], p._1._2),
131+
(replaceElse(p._2._1).asInstanceOf[FunctionInvocation], p._2._2)
132+
))
133+
134+
val newBlocker: Variable = Variable.fresh("bm", BooleanType(), true)
135+
builder.storeConds(thenBlockers ++ elseBlockers, newBlocker)
136+
builder.iff(orJoin((thenBlockers ++ elseBlockers).toSeq), newBlocker)
137+
138+
val newArgs = (thenCall.args zip elseCall.args).map { case (thenArg, elseArg) =>
139+
val (newArg, argClauses) = mkExprClauses(newBlocker, ifExpr(condVar, thenArg, elseArg), builder.localSubst)
140+
builder ++= argClauses
141+
newArg
142+
}
143+
144+
val newCall = thenCall.tfd.applied(newArgs)
145+
builder.storeGuarded(newBlocker, Equals(newExpr, newCall))
146+
}
147+
148+
for ((b, es) <- thenGuarded; e <- es) builder.storeGuarded(b, e)
149+
for ((b, es) <- elseGuarded; e <- es) builder.storeGuarded(b, e)
150+
builder.result
151+
}
152+
78153
protected def mkExprStructure(
79154
pathVar: Variable,
80155
expr: Expr,
@@ -157,6 +232,13 @@ trait TemplateGenerator { self: Templates =>
157232
condTree += pathVar -> (condTree.getOrElse(pathVar, Set.empty) + id)
158233
}
159234

235+
def storeConds(pathVars: Set[Variable], id: Variable): Unit = {
236+
condVars += id -> encodeSymbol(id)
237+
for (pathVar <- pathVars) {
238+
condTree += pathVar -> (condTree.getOrElse(pathVar, Set.empty) + id)
239+
}
240+
}
241+
160242
@inline def encodedCond(id: Variable): Encoded = substMap.getOrElse(id, condVars(id))
161243

162244
var exprVars = Map[Variable, Encoded]()
@@ -367,15 +449,16 @@ trait TemplateGenerator { self: Templates =>
367449
storeExpr(condVar)
368450

369451
val crec = rec(pathVar, cond, None)
370-
val trec = rec(newBool1, thenn, pol)
371-
val erec = rec(newBool2, elze, pol)
372-
373452
storeGuarded(pathVar, Equals(condVar, crec))
374453
iff(and(pathVar, condVar), newBool1)
375454
iff(and(pathVar, not(condVar)), newBool2)
376455

377-
storeGuarded(newBool1, Equals(newExpr, trec))
378-
storeGuarded(newBool2, Equals(newExpr, erec))
456+
val (trec, tClauses) = mkExprClauses(newBool1, thenn, localSubst, pol)
457+
val (erec, eClauses) = mkExprClauses(newBool2, elze, localSubst, pol)
458+
builder ++= mergeCalls(pathVar, condVar, localSubst,
459+
tClauses + (newBool1 -> Equals(newExpr, trec)),
460+
eClauses + (newBool2 -> Equals(newExpr, erec)))
461+
379462
newExpr
380463
}
381464
}

0 commit comments

Comments
 (0)