@@ -6,6 +6,7 @@ import effekt.context.Context
6
6
import effekt .context .assertions .*
7
7
import effekt .cps .*
8
8
import effekt .core .{ DeclarationContext , Id }
9
+ import effekt .cps .Variables .{ all , free }
9
10
10
11
import scala .collection .mutable
11
12
@@ -21,45 +22,31 @@ object TransformerCps extends Transformer {
21
22
val DEALLOC = Variable (JSName (" DEALLOC" ))
22
23
val TRAMPOLINE = Variable (JSName (" TRAMPOLINE" ))
23
24
24
- class Used (var used : Boolean )
25
- case class DefInfo (id : Id , vparams : List [Id ], bparams : List [Id ], ks : Id , k : Id , used : Used )
26
-
27
- object DefInfo {
28
- def unapply (b : cps.Block )(using C : TransformerContext ): Option [(Id , List [Id ], List [Id ], Id , Id , Used )] = b match {
29
- case cps.Block .BlockVar (id) => C .definitions.get(id) match {
30
- case Some (DefInfo (id, vparams, bparams, ks, k, used)) => Some ((id, vparams, bparams, ks, k, used))
31
- case None => None
32
- }
33
- case _ => None
34
- }
35
- }
25
+ class RecursiveUsage (var jumped : Boolean )
26
+ case class RecursiveDefInfo (id : Id , vparams : List [Id ], bparams : List [Id ], ks : Id , k : Id , used : RecursiveUsage )
27
+ case class ContinuationInfo (k : Id , param : Id , ks : Id )
36
28
37
29
case class TransformerContext (
38
30
requiresThunk : Boolean ,
31
+ // known definitions of expressions (used to inline into externs)
39
32
bindings : Map [Id , js.Expr ],
40
33
// definitions of externs (used to inline them)
41
34
externs : Map [Id , cps.Extern .Def ],
42
- // currently, lexically enclosing functions and their parameters (used to determine whether a call is recursive, to rewrite into a loop)
43
- definitions : Map [Id , DefInfo ],
35
+ // the innermost (in direct style) enclosing functions (used to rewrite a definition to a loop)
36
+ recursive : Option [RecursiveDefInfo ],
37
+ // the direct-style continuation, if available (used in case cps.Stmt.LetCont)
38
+ directStyle : Option [ContinuationInfo ],
39
+ // the current direct-style metacontinuation
40
+ metacont : Option [Id ],
41
+ // substitutions for renaming of metaconts (to avoid rebinding them)
42
+ metaconts : Map [Id , Id ],
44
43
// the original declaration context (used to compile pattern matching)
45
44
declarations : DeclarationContext ,
46
45
// the usual compiler context
47
46
errors : Context
48
47
)
49
48
implicit def autoContext (using C : TransformerContext ): Context = C .errors
50
49
51
- def lookup (id : Id )(using C : TransformerContext ): js.Expr = C .bindings.getOrElse(id, nameRef(id))
52
-
53
- def enterDefinition (id : Id , used : Used , block : cps.Block )(using C : TransformerContext ): TransformerContext = block match {
54
- case cps.BlockLit (vparams, bparams, ks, k, body) =>
55
- C .copy(definitions = Map (id -> DefInfo (id, vparams, bparams, ks, k, used)))
56
- case _ => C
57
- }
58
-
59
- def clearDefinitions (using C : TransformerContext ): TransformerContext = C .copy(definitions = Map .empty)
60
-
61
- def bindingAll [R ](bs : List [(Id , js.Expr )])(body : TransformerContext ?=> R )(using C : TransformerContext ): R =
62
- body(using C .copy(bindings = C .bindings ++ bs))
63
50
64
51
/**
65
52
* Entrypoint used by the compiler to compile whole programs
@@ -78,6 +65,9 @@ object TransformerCps extends Transformer {
78
65
false ,
79
66
Map .empty,
80
67
externs.collect { case d : Extern .Def => (d.id, d) }.toMap,
68
+ None ,
69
+ None ,
70
+ None ,
81
71
Map .empty,
82
72
D , C )
83
73
@@ -100,6 +90,9 @@ object TransformerCps extends Transformer {
100
90
false ,
101
91
Map .empty,
102
92
input.externs.collect { case d : Extern .Def => (d.id, d) }.toMap,
93
+ None ,
94
+ None ,
95
+ None ,
103
96
Map .empty,
104
97
D , C )
105
98
@@ -158,18 +151,18 @@ object TransformerCps extends Transformer {
158
151
159
152
def toJS (id : Id , b : cps.Block )(using TransformerContext ): js.Expr = b match {
160
153
case cps.Block .BlockLit (vparams, bparams, ks, k, body) =>
161
- val used = new Used (false )
154
+ val used = new RecursiveUsage (false )
162
155
163
- val translatedBody = toJS(body)(using enterDefinition (id, used, b)).stmts
156
+ val translatedBody = toJS(body)(using recursive (id, used, b)).stmts
164
157
165
- if used.used then
158
+ if used.jumped then
166
159
js.Lambda (vparams.map(nameDef) ++ bparams.map(nameDef) ++ List (nameDef(ks), nameDef(k)),
167
160
List (js.While (RawExpr (" true" ), translatedBody, Some (uniqueName(id)))))
168
161
else
169
162
js.Lambda (vparams.map(nameDef) ++ bparams.map(nameDef) ++ List (nameDef(ks), nameDef(k)),
170
163
translatedBody)
171
164
172
- case other => toJS(other)( using clearDefinitions)
165
+ case other => toJS(other)
173
166
}
174
167
175
168
def toJS (b : cps.Block )(using TransformerContext ): js.Expr = b match {
@@ -185,17 +178,18 @@ object TransformerCps extends Transformer {
185
178
case cps.Implementation (interface, operations) =>
186
179
js.Object (operations.map {
187
180
case cps.Operation (id, vps, bps, ks, k, body) =>
188
- nameDef(id) -> js.Lambda (vps.map(nameDef) ++ bps.map(nameDef) ++ List (nameDef(ks), nameDef(k)), toJS(body)(using clearDefinitions ).stmts)
181
+ nameDef(id) -> js.Lambda (vps.map(nameDef) ++ bps.map(nameDef) ++ List (nameDef(ks), nameDef(k)), toJS(body)(using nonrecursive(ks) ).stmts)
189
182
})
190
183
}
191
184
192
- def toJS (ks : cps.MetaCont ): js.Expr = nameRef(ks.id)
185
+ def toJS (ks : cps.MetaCont )(using T : TransformerContext ): js.Expr =
186
+ nameRef(T .metaconts.getOrElse(ks.id, ks.id))
193
187
194
188
def toJS (k : cps.Cont )(using T : TransformerContext ): js.Expr = k match {
195
189
case Cont .ContVar (id) =>
196
190
nameRef(id)
197
191
case Cont .ContLam (result, ks, body) =>
198
- js.Lambda (List (nameDef(result), nameDef(ks)), toJS(body)(using clearDefinitions ).stmts)
192
+ js.Lambda (List (nameDef(result), nameDef(ks)), toJS(body)(using nonrecursive(ks) ).stmts)
199
193
}
200
194
201
195
def toJS (e : cps.Expr )(using D : TransformerContext ): js.Expr = e match {
@@ -225,9 +219,18 @@ object TransformerCps extends Transformer {
225
219
js.Const (nameDef(id), toJS(binding)) :: toJS(body).run(k)
226
220
}
227
221
228
- case cps.Stmt .LetCont (id, binding, body) =>
222
+ // [[ let k(x, ks) = ...; if (...) jump k(42, ks2) else jump k(10, ks3) ]] =
223
+ // let x; if (...) { x = 42; ks = ks2 } else { x = 10; ks = ks3 } ...
224
+ case cps.Stmt .LetCont (id, Cont .ContLam (param, ks, body), body2) if canBeDirect(id, body2) =>
229
225
Binding { k =>
230
- js.Const (nameDef(id), toJS(binding)) :: requiringThunk { toJS(body)(using clearDefinitions) }.run(k)
226
+ js.Let (nameDef(param), js.Undefined ) ::
227
+ toJS(body2)(using withDirectStyle(id, param, ks)).stmts ++
228
+ toJS(body)(using directstyle(ks)).run(k)
229
+ }
230
+
231
+ case cps.Stmt .LetCont (id, binding @ Cont .ContLam (result2, ks2, body2), body) =>
232
+ Binding { k =>
233
+ js.Const (nameDef(id), toJS(binding)(using nonrecursive(ks2))) :: requiringThunk { toJS(body) }.run(k)
231
234
}
232
235
233
236
case cps.Stmt .Match (sc, Nil , None ) =>
@@ -245,19 +248,30 @@ object TransformerCps extends Transformer {
245
248
pure(js.Switch (js.Member (scrutinee, `tag`),
246
249
clauses.map { case (tag, clause) =>
247
250
val (e, binding) = toJS(scrutinee, tag, clause);
248
- (e, binding.stmts)
251
+
252
+ val stmts = binding.stmts
253
+
254
+ stmts.last match {
255
+ case terminator : (js.Stmt .Return | js.Stmt .Break | js.Stmt .Continue ) => (e, stmts)
256
+ case other => (e, stmts :+ js.Break ())
257
+ }
249
258
},
250
259
default.map { s => toJS(s).stmts }))
251
260
261
+ case cps.Stmt .Jump (k, arg, ks) if D .directStyle.exists(c => c.k == k) => D .directStyle match {
262
+ case Some (ContinuationInfo (k2, param2, ks2)) => pure(js.Assign (nameRef(param2), toJS(arg)))
263
+ case None => sys error " Should not happen"
264
+ }
265
+
252
266
case cps.Stmt .Jump (k, arg, ks) =>
253
267
pure(js.Return (maybeThunking(js.Call (nameRef(k), toJS(arg), toJS(ks)))))
254
268
255
- case cps.Stmt .App (callee @ DefInfo (id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont (ks), Cont .ContVar (k)) if ks1 == ks && k1 == k =>
269
+ case cps.Stmt .App (Recursive (id, vparams, bparams, ks1, k1, used), vargs, bargs, MetaCont (ks), Cont .ContVar (k)) if sameScope(ks, k, ks1, k1) =>
256
270
Binding { k2 =>
257
271
val stmts = mutable.ListBuffer .empty[js.Stmt ]
258
272
stmts.append(js.RawStmt (" /* prepare tail call */" ))
259
273
260
- used.used = true
274
+ used.jumped = true
261
275
262
276
// const x3 = [[ arg ]]; ...
263
277
val vtmps = (vparams zip vargs).map { (id, arg) =>
@@ -336,13 +350,13 @@ object TransformerCps extends Transformer {
336
350
}
337
351
338
352
case cps.Stmt .Reset (prog, ks, k) =>
339
- pure(js.Return (Call (RESET , toJS(prog)(using clearDefinitions ), toJS(ks), toJS(k))))
353
+ pure(js.Return (Call (RESET , toJS(prog)(using nonrecursive(prog) ), toJS(ks), toJS(k))))
340
354
341
355
case cps.Stmt .Shift (prompt, body, ks, k) =>
342
- pure(js.Return (Call (SHIFT , nameRef(prompt), noThunking { toJS(body)(using clearDefinitions ) }, toJS(ks), toJS(k))))
356
+ pure(js.Return (Call (SHIFT , nameRef(prompt), noThunking { toJS(body)(using nonrecursive(body) ) }, toJS(ks), toJS(k))))
343
357
344
358
case cps.Stmt .Resume (r, b, ks2, k2) =>
345
- pure(js.Return (js.Call (RESUME , nameRef(r), toJS(b)(using clearDefinitions ), toJS(ks2), toJS(k2))))
359
+ pure(js.Return (js.Call (RESUME , nameRef(r), toJS(b)(using nonrecursive(b) ), toJS(ks2), toJS(k2))))
346
360
347
361
case cps.Stmt .Hole () =>
348
362
pure(js.Return ($effekt.call(" hole" )))
@@ -373,7 +387,7 @@ object TransformerCps extends Transformer {
373
387
// Inlining Externs
374
388
// ----------------
375
389
376
- def inlineExtern (id : Id , args : List [cps.Pure ])(using T : TransformerContext ): js.Expr =
390
+ private def inlineExtern (id : Id , args : List [cps.Pure ])(using T : TransformerContext ): js.Expr =
377
391
T .externs.get(id) match {
378
392
case Some (cps.Extern .Def (id, params, Nil , async,
379
393
ExternBody .StringExternBody (featureFlag, Template (strings, templateArgs)))) if ! async =>
@@ -383,11 +397,96 @@ object TransformerCps extends Transformer {
383
397
case _ => js.Call (nameRef(id), args.map(toJS))
384
398
}
385
399
386
- def canInline (extern : cps.Extern ): Boolean = extern match {
400
+ private def canInline (extern : cps.Extern ): Boolean = extern match {
387
401
case cps.Extern .Def (_, _, Nil , async, ExternBody .StringExternBody (_, Template (_, _))) => ! async
388
402
case _ => false
389
403
}
390
404
405
+ private def bindingAll [R ](bs : List [(Id , js.Expr )])(body : TransformerContext ?=> R )(using C : TransformerContext ): R =
406
+ body(using C .copy(bindings = C .bindings ++ bs))
407
+
408
+ private def lookup (id : Id )(using C : TransformerContext ): js.Expr = C .bindings.getOrElse(id, nameRef(id))
409
+
410
+
411
+ // Helpers for Direct-Style Transformation
412
+ // ---------------------------------------
413
+
414
+ /**
415
+ * Used to determine whether a call with continuations [[ ks ]] (after substitution) and [[ k ]]
416
+ * is the same as the original function definition (that is [[ ks1 ]] and [[ k1 ]].
417
+ */
418
+ private def sameScope (ks : Id , k : Id , ks1 : Id , k1 : Id )(using C : TransformerContext ): Boolean =
419
+ ks1 == C .metaconts.getOrElse(ks, ks) && k1 == k
420
+
421
+ private def withDirectStyle (id : Id , param : Id , ks : Id )(using C : TransformerContext ): TransformerContext =
422
+ C .copy(directStyle = Some (ContinuationInfo (id, param, ks)))
423
+
424
+ private def recursive (id : Id , used : RecursiveUsage , block : cps.Block )(using C : TransformerContext ): TransformerContext = block match {
425
+ case cps.BlockLit (vparams, bparams, ks, k, body) =>
426
+ C .copy(recursive = Some (RecursiveDefInfo (id, vparams, bparams, ks, k, used)), directStyle = None , metacont = Some (ks))
427
+ case _ => C
428
+ }
429
+
430
+ private def nonrecursive (ks : Id )(using C : TransformerContext ): TransformerContext =
431
+ C .copy(recursive = None , directStyle = None , metacont = Some (ks))
432
+
433
+ private def nonrecursive (block : cps.BlockLit )(using C : TransformerContext ): TransformerContext = nonrecursive(block.ks)
434
+
435
+ // ks | let k1 x1 ks1 = { let k2 x2 ks2 = jump k v ks2 }; ... = jump k v ks
436
+ private def directstyle (ks : Id )(using C : TransformerContext ): TransformerContext =
437
+ val outer = C .metacont.getOrElse { sys error " Metacontinuation missing..." }
438
+ val outerSubstituted = C .metaconts.getOrElse(outer, outer)
439
+ val subst = C .metaconts.updated(ks, outerSubstituted)
440
+ C .copy(metacont = Some (ks), metaconts = subst)
441
+
442
+ private object Recursive {
443
+ def unapply (b : cps.Block )(using C : TransformerContext ): Option [(Id , List [Id ], List [Id ], Id , Id , RecursiveUsage )] = b match {
444
+ case cps.Block .BlockVar (id) => C .recursive.collect {
445
+ case RecursiveDefInfo (id2, vparams, bparams, ks, k, used) if id == id2 => (id, vparams, bparams, ks, k, used)
446
+ }
447
+ case _ => None
448
+ }
449
+ }
450
+
451
+ private def canBeDirect (k : Id , stmt : Stmt )(using T : TransformerContext ): Boolean =
452
+ def notIn (term : Stmt | Block | Expr | (Id , Clause ) | Cont ) =
453
+ val freeVars = term match {
454
+ case s : Stmt => free(s)
455
+ case b : Block => free(b)
456
+ case p : Expr => free(p)
457
+ case (id, Clause (_, body)) => free(body)
458
+ case c : Cont => free(c)
459
+ }
460
+ ! freeVars.contains(k)
461
+ stmt match {
462
+ case Stmt .Jump (k2, arg, ks2) if k2 == k => notIn(arg) && T .metacont.contains(ks2.id)
463
+ case Stmt .Jump (k2, arg, ks2) => notIn(arg)
464
+ // TODO this could be a tailcall!
465
+ case Stmt .App (callee, vargs, bargs, ks, k) => notIn(stmt)
466
+ case Stmt .Invoke (callee, method, vargs, bargs, ks, k2) => notIn(stmt)
467
+ case Stmt .If (cond, thn, els) => canBeDirect(k, thn) && canBeDirect(k, els)
468
+ case Stmt .Match (scrutinee, clauses, default) => clauses.forall {
469
+ case (id, Clause (vparams, body)) => canBeDirect(k, body)
470
+ } && default.forall(body => canBeDirect(k, body))
471
+ case Stmt .LetDef (id, binding, body) => notIn(binding) && canBeDirect(k, body)
472
+ case Stmt .LetExpr (id, binding, body) => notIn(binding) && canBeDirect(k, body)
473
+ case Stmt .LetCont (id, Cont .ContLam (result, ks2, body), body2) =>
474
+ def willBeDirectItself = canBeDirect(id, body2) && canBeDirect(k, body)(using directstyle(ks2))
475
+ def notFreeinContinuation = notIn(body) && canBeDirect(k, body2)
476
+ willBeDirectItself || notFreeinContinuation
477
+ case Stmt .Region (id, ks, body) => notIn(body)
478
+ case Stmt .Alloc (id, init, region, body) => notIn(init) && canBeDirect(k, body)
479
+ case Stmt .Var (id, init, ks2, body) => notIn(init) && canBeDirect(k, body)
480
+ case Stmt .Dealloc (ref, body) => canBeDirect(k, body)
481
+ case Stmt .Get (ref, id, body) => canBeDirect(k, body)
482
+ case Stmt .Put (ref, value, body) => notIn(value) && canBeDirect(k, body)
483
+ case Stmt .Reset (prog, ks, k) => notIn(stmt)
484
+ case Stmt .Shift (prompt, body, ks, k) => notIn(stmt)
485
+ case Stmt .Resume (resumption, body, ks, k) => notIn(stmt)
486
+ case Stmt .Hole () => true
487
+ }
488
+
489
+
391
490
392
491
// Thunking
393
492
// --------
0 commit comments