Skip to content

Commit 7ff4e16

Browse files
authored
Fix polymorphism boxing for the case of a scrutinee of type Nothing (#1108)
1 parent 6a5a659 commit 7ff4e16

File tree

8 files changed

+128
-38
lines changed

8 files changed

+128
-38
lines changed

effekt/jvm/src/test/scala/effekt/core/OptimizerTests.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class OptimizerTests extends CoreTests {
3636
def normalize(input: String, expected: String)(using munit.Location) =
3737
assertTransformsTo(input, expected) { tree =>
3838
val anfed = BindSubexpressions.transform(tree)
39-
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, 50, false)
39+
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, 50)
4040
Deadcode.remove(mainSymbol, normalized)
4141
}
4242

effekt/shared/src/main/scala/effekt/core/PolymorphismBoxing.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,10 @@ object PolymorphismBoxing extends Phase[CoreTransformed, CoreTransformed] {
232232
Stmt.If(transform(cond), transform(thn), transform(els))
233233
case Stmt.Match(scrutinee, clauses, default) =>
234234
scrutinee.tpe match {
235-
case ValueType.Data(symbol, targs) =>
236-
val Declaration.Data(tpeId, tparams, constructors) = DeclarationContext.getData(symbol)
235+
// if the scrutinee has type Nothing, then there shouldn't be any clauses...
236+
case Type.TBottom => Stmt.Match(transform(scrutinee), Nil, None)
237+
case ValueType.Data(tpeId, targs) =>
238+
val Declaration.Data(_, tparams, constructors) = DeclarationContext.getData(tpeId)
237239
Stmt.Match(transform(scrutinee), clauses.map {
238240
case (id, clause: Block.BlockLit) =>
239241
val constructor = constructors.find(_.id == id).get

effekt/shared/src/main/scala/effekt/core/optimizer/Normalizer.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ object Normalizer { normal =>
3636
decls: DeclarationContext, // for field selection
3737
usage: mutable.Map[Id, Usage], // mutable in order to add new information after renaming
3838
maxInlineSize: Int, // to control inlining and avoid code bloat
39-
preserveBoxing: Boolean // for LLVM, prevents some optimizations
4039
) {
4140
def bind(id: Id, expr: Expr): Context = copy(exprs = exprs + (id -> expr))
4241
def bind(id: Id, block: Block): Context = copy(blocks = blocks + (id -> block))
@@ -69,14 +68,14 @@ object Normalizer { normal =>
6968
private def isUnused(id: Id)(using ctx: Context): Boolean =
7069
ctx.usage.get(id).forall { u => u == Usage.Never }
7170

72-
def normalize(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int, preserveBoxing: Boolean): ModuleDecl = {
71+
def normalize(entrypoints: Set[Id], m: ModuleDecl, maxInlineSize: Int): ModuleDecl = {
7372
// usage information is used to detect recursive functions (and not inline them)
7473
val usage = Reachable(entrypoints, m)
7574

7675
val defs = m.definitions.collect {
7776
case Toplevel.Def(id, block) => id -> block
7877
}.toMap
79-
val context = Context(defs, Map.empty, DeclarationContext(m.declarations, m.externs), mutable.Map.from(usage), maxInlineSize, preserveBoxing)
78+
val context = Context(defs, Map.empty, DeclarationContext(m.declarations, m.externs), mutable.Map.from(usage), maxInlineSize)
8079

8180
val (normalizedDefs, _) = normalizeToplevel(m.definitions)(using context)
8281
m.copy(definitions = normalizedDefs)
@@ -175,7 +174,7 @@ object Normalizer { normal =>
175174
case Stmt.Let(id, tpe, expr, body) =>
176175
active(expr) match {
177176
// [[ val x = ABORT; body ]] = ABORT
178-
case abort if !C.preserveBoxing && abort.tpe == Type.TBottom =>
177+
case abort if abort.tpe == Type.TBottom =>
179178
Stmt.Let(id, tpe, abort, Return(ValueVar(id, tpe)))
180179

181180
case normalized =>
@@ -213,6 +212,9 @@ object Normalizer { normal =>
213212
Stmt.Invoke(normalized.shared, method, methodTpe, targs, vargs.map(normalize), bargs.map(normalize))
214213
}
215214

215+
case Stmt.Match(scrutinee, clauses, default) if scrutinee.tpe == Type.TBottom =>
216+
Stmt.Return(normalize(scrutinee))
217+
216218
case Stmt.Match(scrutinee, clauses, default) => active(scrutinee) match {
217219
case Pure.Make(data, tag, targs, vargs) if clauses.exists { case (id, _) => id == tag } =>
218220
val clause: BlockLit = clauses.collectFirst { case (id, cl) if id == tag => cl }.get
@@ -238,12 +240,17 @@ object Normalizer { normal =>
238240
def normalizeVal(id: Id, tpe: ValueType, binding: Stmt, body: Stmt): Stmt = normalize(binding) match {
239241

240242
// [[ val x = ABORT; body ]] = ABORT
241-
case abort if !C.preserveBoxing && abort.tpe == Type.TBottom =>
243+
case abort if abort.tpe == Type.TBottom =>
242244
abort
243245

244-
case abort @ Stmt.Shift(p, BlockLit(tparams, cparams, vparams, List(k), body))
245-
if !C.preserveBoxing && !Variables.free(body).containsBlock(k.id) =>
246-
abort
246+
// [[ val x: A = shift(p) { {k: A => R} => body2 }; body: B ]] = shift(p) { {k: >>>B<<< => R} => body2 }
247+
case abort @ Stmt.Shift(p, BlockLit(tparams, cparams, vparams,
248+
BlockParam(k, BlockType.Interface(Type.ResumeSymbol, List(tpeA, answer)), captures) :: Nil, body2))
249+
if !Variables.free(body2).containsBlock(k) =>
250+
val tpeB = body.tpe
251+
Stmt.Shift(p,
252+
BlockLit(tparams, cparams, vparams, BlockParam(k, BlockType.Interface(Type.ResumeSymbol, List(tpeB, answer)), captures) :: Nil,
253+
normalize(body2)))
247254

248255
// [[ val x = sc match { case id(ps) => body2 }; body ]] = sc match { case id(ps) => val x = body2; body }
249256
case Stmt.Match(sc, List((id2, BlockLit(tparams2, cparams2, vparams2, bparams2, body2))), None) =>

effekt/shared/src/main/scala/effekt/core/optimizer/Optimizer.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ object Optimizer extends Phase[CoreTransformed, CoreTransformed] {
2121

2222
def optimize(source: Source, mainSymbol: symbols.Symbol, core: ModuleDecl)(using Context): ModuleDecl =
2323

24-
val isLLVM = Context.config.backend().name == "llvm"
25-
2624
var tree = core
2725

2826
// (1) first thing we do is simply remove unused definitions (this speeds up all following analysis and rewrites)
@@ -39,7 +37,7 @@ object Optimizer extends Phase[CoreTransformed, CoreTransformed] {
3937

4038
def normalize(m: ModuleDecl) = {
4139
val anfed = BindSubexpressions.transform(m)
42-
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, Context.config.maxInlineSize().toInt, isLLVM)
40+
val normalized = Normalizer.normalize(Set(mainSymbol), anfed, Context.config.maxInlineSize().toInt)
4341
val live = Deadcode.remove(mainSymbol, normalized)
4442
val tailRemoved = RemoveTailResumptions(live)
4543
val contified = DirectStyle.rewrite(tailRemoved)

effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class LLVM extends Compiler[String] {
3939
// The Compilation Pipeline
4040
// ------------------------
4141
// Source => Core => Machine => LLVM
42-
lazy val Compile = allToCore(Core) andThen Aggregate andThen optimizer.Optimizer andThen core.PolymorphismBoxing andThen Machine map {
42+
lazy val Compile = steps.afterCore andThen Machine map {
4343
case (mod, main, prog) => (mod, llvm.Transformer.transform(prog))
4444
}
4545

examples/llvm/failtooption.effekt

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,27 @@
11
effect Fail(): Unit
22

33
type OptionInt {
4-
None(); Some(i: Int)
5-
}
6-
def println(oi: OptionInt): Unit = {
7-
oi match {
8-
case None() => println("None")
9-
case Some(i) => println(i)
10-
}
4+
None()
5+
Some(i: Int)
116
}
127

13-
def runFail{ f : => Int / Fail }: OptionInt = {
14-
try {
15-
Some(f())
16-
} with Fail {
17-
None()
18-
}
19-
}
8+
def println(oi: OptionInt): Unit =
9+
oi match {
10+
case None() => println("None")
11+
case Some(i) => println(i)
12+
}
2013

21-
def safeDiv(x: Int, y: Int): Int / Fail = {
22-
if(y == 0){
23-
do Fail(); 0
24-
} else {
25-
x/y
26-
}
27-
}
14+
def runFail { f : => Int / Fail }: OptionInt =
15+
try Some(f()) with Fail { None() }
16+
17+
def safeDiv(x: Int, y: Int): Int / Fail =
18+
if (y == 0) {
19+
do Fail(); 0
20+
} else {
21+
x / y
22+
}
2823

2924
def main() = {
30-
println(runFail { safeDiv(6,3) })
31-
println(runFail { safeDiv(2,0) })
25+
println(runFail { safeDiv(6, 3) })
26+
println(runFail { safeDiv(2, 0) })
3227
}

examples/llvm/matchnothing.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
11

examples/llvm/matchnothing.effekt

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import args
2+
3+
type Either[A, B] {
4+
Left(a: A)
5+
Right(b: B)
6+
}
7+
8+
def enum_from_then_to(from: Int, then: Int, to: Int): List[Int] =
9+
if (from <= to) {
10+
Cons(from, enum_from_then_to(then, (2 * then) - from, to))
11+
} else {
12+
Nil()
13+
}
14+
15+
def apply_op_inner(ls: List[Int], a: Int) { op: (Int, Int) => Either[Int, Bool] }: List[Either[Int, Bool]] =
16+
ls match {
17+
case Nil() => Nil()
18+
case Cons(b, t2) => Cons(op(a, b), apply_op_inner(t2, a) { op })
19+
}
20+
21+
def apply_op(ls: List[Int], astart: Int, astep: Int, alim: Int) { op: (Int, Int) => Either[Int, Bool] }: List[Either[Int, Bool]] =
22+
ls match {
23+
case Nil() => Nil()
24+
case Cons(a, t1) =>
25+
append(
26+
apply_op_inner(enum_from_then_to(astart, astart + astep, alim), a) { op },
27+
apply_op(t1, astart, astep, alim) { op }
28+
)
29+
}
30+
31+
def integerbench(astart: Int, astep: Int, alim: Int) { op: (Int, Int) => Either[Int, Bool] }: List[Either[Int, Bool]] =
32+
apply_op(enum_from_then_to(astart, astart + astep, alim), astart, astep, alim) { op }
33+
34+
def runbench(astart: Int, astep: Int, alim: Int) { jop: (Int, Int) => Either[Int, Bool] }: List[Either[Int, Bool]] = {
35+
val res1 = integerbench(astart, astep, alim) { jop }
36+
integerbench(astart, astep, alim) { jop }
37+
}
38+
39+
def runalltests(astart: Int, astep: Int, alim: Int): List[Either[Int, Bool]] = {
40+
def z_add(a: Int, b: Int): Either[Int, Bool] = { Left(a + b) }
41+
def z_sub(a: Int, b: Int): Either[Int, Bool] = { Left(a - b) }
42+
def z_mul(a: Int, b: Int): Either[Int, Bool] = { Left(a * b) }
43+
def z_div(a: Int, b: Int): Either[Int, Bool] = { Left(a / b) }
44+
def z_mod(a: Int, b: Int): Either[Int, Bool] = { Left(mod(a, b)) }
45+
def z_equal(a: Int, b: Int): Either[Int, Bool] = { Right(a == b) }
46+
def z_lt(a: Int, b: Int): Either[Int, Bool] = { Right(a < b) }
47+
def z_leq(a: Int, b: Int): Either[Int, Bool] = { Right(a <= b) }
48+
def z_gt(a: Int, b: Int): Either[Int, Bool] = { Right(a > b) }
49+
def z_geq(a: Int, b: Int): Either[Int, Bool] = { Right(a >= b) }
50+
51+
val add = runbench(astart, astep, alim) { z_add }
52+
val sub = runbench(astart, astep, alim) { z_sub }
53+
val mul = runbench(astart, astep, alim) { z_mul }
54+
val div = runbench(astart, astep, alim) { z_div }
55+
val mod = runbench(astart, astep, alim) { z_mod }
56+
val equal = runbench(astart, astep, alim) { z_equal }
57+
val lt = runbench(astart, astep, alim) { z_lt }
58+
val leq = runbench(astart, astep, alim) { z_leq }
59+
val gt = runbench(astart, astep, alim) { z_gt }
60+
val geq = runbench(astart, astep, alim) { z_geq }
61+
runbench(astart, astep, alim) { z_geq }
62+
}
63+
64+
def test_integer_nofib(n: Int): List[Either[Int, Bool]] =
65+
runalltests(-2100000000, n, 2100000000)
66+
67+
def print_either(e: Either[Int, Bool]): Unit =
68+
e match {
69+
case Left(i) => println(i)
70+
case Right(b) =>
71+
if (b) {
72+
println("11")
73+
} else {
74+
println("00")
75+
}
76+
}
77+
78+
def main_loop(iters: Int, n: Int): Unit = {
79+
val res = test_integer_nofib(n)
80+
if (iters == 1) {
81+
print_either(option::getOrElse(list::headOption(res)) { panic("Empty List") })
82+
} else {
83+
main_loop(iters - 1, n)
84+
}
85+
}
86+
87+
def main() = main_loop(1, 700000001)

0 commit comments

Comments
 (0)