Skip to content

Commit a753cb2

Browse files
CAG2MarkAnsonYeungLPTK
authored
Class Lifter (#266)
Co-authored-by: Anson Yeung <[email protected]> Co-authored-by: Lionel Parreaux <[email protected]>
1 parent 3bf1467 commit a753cb2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+4958
-366
lines changed

core/shared/main/scala/utils/algorithms.scala

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mlscript.utils
33
import scala.annotation.tailrec
44
import scala.collection.immutable.SortedMap
55

6+
67
object algorithms {
78
final class CyclicGraphError(message: String) extends Exception(message)
89

@@ -29,4 +30,145 @@ object algorithms {
2930
}
3031
sort(toPred, Seq())
3132
}
33+
34+
/**
  * Partitions a graph into its strongly connected components using Tarjan's algorithm.
  * The node type must have a sound and efficient `equals`/`hashCode`, as nodes are used as
  * set elements and map keys.
  *
  * Components are prepended to the result as they are completed, so the returned list is in
  * topological order of the condensed graph: a component appears before every component it
  * has an edge to. Inside a component, the node from which the component was discovered comes
  * first, followed by the other members in the order they were reached.
  *
  * Note: the traversal is recursive (not tail-recursive), so the call depth grows with the
  * length of the longest path followed by the depth-first search.
  *
  * @param edges The edges of the graph, as `(source, target)` pairs.
  * @param nodes Any additional nodes that are not necessarily in the edges list. (Overlap is fine)
  * @return A list of strongly connected components of the graph.
  */
def partitionScc[A](edges: Iterable[(A, A)], nodes: Iterable[A]): List[List[A]] = {
  // Mutable per-node bookkeeping for Tarjan's algorithm.
  final class SccNode(val node: A, val id: Int) {
    var num: Int = -1             // discovery index; -1 until the node is visited
    var lowlink: Int = -1         // smallest discovery index known to be reachable via the stack
    var onStack: Boolean = false  // true exactly while the node sits on `stack`
    def visited: Boolean = num >= 0
  }

  // pre-process: collect every distinct node and assign each one an id
  val edgesSet = edges.toSet
  val nodesUniq = (edgesSet.flatMap { case (a, b) => Set(a, b) } ++ nodes.toSet).toList
  val nodesN = nodesUniq.zipWithIndex.map { case (node, idx) => new SccNode(node, idx) }
  val nodeOf = nodesN.map(n => n.node -> n).toMap

  // adjacency lists keyed by source id; nodes without outgoing edges map to Nil
  val neighbours = edges
    .map { case (a, b) => (nodeOf(a).id, nodeOf(b)) }
    .groupBy(_._1)
    .map { case (from, out) => from -> out.map(_._2) }
    .withDefault(_ => Nil)

  // Tarjan's algorithm

  var stack: List[SccNode] = List.empty
  var sccs: List[List[A]] = List.empty
  var counter = 0

  def dfs(node: SccNode): Unit = {
    node.num = counter
    node.lowlink = counter
    counter += 1
    stack = node :: stack
    node.onStack = true

    for (n <- neighbours(node.id)) {
      if (!n.visited) {
        dfs(n)
        node.lowlink = n.lowlink.min(node.lowlink)
      } else if (n.onStack) {
        // Edge back into the component currently being built.
        node.lowlink = n.num.min(node.lowlink)
      }
      // Otherwise `n` belongs to an already completed component: nothing to do.
    }

    if (node.lowlink == node.num) {
      // `node` is the root of a component: everything pushed after it belongs to that component.
      val (pushedLater, fromNode) = stack.span(_.id != node.id)
      val members = node :: pushedLater.reverse
      members.foreach { m => m.onStack = false }
      stack = fromNode.drop(1) // also removes `node` itself
      sccs = members.map(_.node) :: sccs
    }
  }

  for (n <- nodesN) {
    if (!n.visited) dfs(n)
  }
  sccs
}
110+
111+
112+
/**
  * Describes a graph after it has been condensed into its strongly connected components.
  * Every component is identified by an integer id, which is the key used in all four maps.
  * The node type must be able to be hashed efficiently as it will be used as a key.
  *
  * @tparam A The type of the nodes of the original graph.
  * @param sccs The nodes that make up each component, keyed by component id.
  * @param edges For each component id, the ids of the components it has edges to. Together with
  *              `sccs`, this forms an acyclic graph.
  * @param inDegs The in-degrees of the components in that acyclic graph.
  * @param outDegs The out-degrees of the components in that acyclic graph.
  */
case class SccsInfo[A](
  sccs: Map[Int, List[A]],
  edges: Map[Int, Iterable[Int]],
  inDegs: Map[Int, Int],
  outDegs: Map[Int, Int],
)
127+
128+
/**
  * Condenses a graph into its strongly connected components and summarises the resulting
  * component graph. The node type must be able to be hashed efficiently as it will be used as a key.
  *
  * A component's id is its position in the list returned by `partitionScc`. Edges whose two
  * endpoints fall into the same component are dropped. The `edges` map of the result has an entry
  * for every component (empty when it has no outgoing edge), whereas `inDegs` and `outDegs` only
  * contain entries for components that have at least one incoming, respectively outgoing, edge.
  *
  * @param edges The edges of the graph.
  * @param nodes Any additional nodes that are not necessarily in the edges list. (Overlap is fine)
  * @return The partitioned graph and info about it.
  */
def sccsWithInfo[A](edges: Iterable[(A, A)], nodes: Iterable[A]): SccsInfo[A] = {
  // Component id -> members of that component.
  val components = partitionScc(edges, nodes).zipWithIndex.map { case (scc, id) => id -> scc }.toMap

  // Reverse index: node -> id of the component containing it.
  val componentOf = components.flatMap { case (id, scc) => scc.map(node => node -> id) }.toMap

  // Original edges expressed over component ids, keeping only those that cross components.
  val crossEdges = edges
    .map { case (from, to) => (componentOf(from), componentOf(to)) }
    .filter { case (from, to) => from != to }

  val outgoing = crossEdges.groupBy(_._1)

  // Every component gets an entry; components without outgoing edges map to Nil.
  val condensedEdges: Map[Int, Iterable[Int]] = components.map { case (id, _) =>
    id -> outgoing.get(id).fold[Iterable[Int]](Nil)(_.map(_._2))
  }

  val inDegs = crossEdges.groupBy(_._2).map { case (target, incoming) => target -> incoming.size }
  val outDegs = outgoing.map { case (source, out) => source -> out.size }

  SccsInfo(components, condensedEdges, inDegs, outDegs)
}
32174
}

hkmc2/shared/src/main/scala/hkmc2/Config.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def config(using Config): Config = summon
1111
case class Config(
1212
sanityChecks: Opt[SanityChecks],
1313
effectHandlers: Opt[EffectHandlers],
14+
liftDefns: Opt[LiftDefns],
1415
):
1516

1617
def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety)
@@ -24,6 +25,7 @@ object Config:
2425
sanityChecks = N, // TODO make the default S
2526
// sanityChecks = S(SanityChecks(light = true)),
2627
effectHandlers = N,
28+
liftDefns = N,
2729
)
2830

2931
case class SanityChecks(light: Bool)
@@ -35,6 +37,8 @@ object Config:
3537
val default: StackSafety = StackSafety(
3638
stackLimit = 500,
3739
)
40+
41+
// Settings type for the `liftDefns` field of `Config`; it currently carries no options.
case class LiftDefns() // there may be other settings in the future, having it as a case class now
3842

3943
end Config
4044

hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala

Lines changed: 67 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,37 @@ sealed abstract class Block extends Product with AutoLocated:
9191
case TryBlock(sub, finallyDo, rest) => sub.freeVars ++ finallyDo.freeVars ++ rest.freeVars
9292
case Assign(lhs, rhs, rest) => Set(lhs) ++ rhs.freeVars ++ rest.freeVars
9393
case AssignField(lhs, nme, rhs, rest) => lhs.freeVars ++ rhs.freeVars ++ rest.freeVars
94+
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVars ++ fld.freeVars ++ rhs.freeVars ++ rest.freeVars
9495
case Define(defn, rest) => defn.freeVars ++ rest.freeVars
9596
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) =>
9697
(bod.freeVars - lhs) ++ rst.freeVars ++ hdr.flatMap(_.freeVars)
9798
case HandleBlockReturn(res) => res.freeVars
9899
case End(msg) => Set.empty
99100

101+
// Variant of `freeVars` for the LLIR backend.
// TODO: unlike `freeVars`, this skips `fun` in `Call` and `cls` in `Instantiate`. Those are needed in some
// other places, but adding them here breaks some LLIR tests. Supposedly, once the IR uses the new symbol system,
// this should no longer happen. This version should be removed once that is resolved.
lazy val freeVarsLLIR: Set[Local] = this match
  case Match(scrut, arms, dflt, rest) =>
    scrut.freeVarsLLIR ++ dflt.toList.flatMap(_.freeVarsLLIR) ++ rest.freeVarsLLIR
      ++ arms.flatMap:
        // Use the LLIR variant for the pattern too, so the whole traversal stays consistent.
        (pat, arm) => arm.freeVarsLLIR -- pat.freeVarsLLIR
  case Return(res, implct) => res.freeVarsLLIR
  case Throw(exc) => exc.freeVarsLLIR
  case Label(label, body, rest) => (body.freeVarsLLIR - label) ++ rest.freeVarsLLIR
  case Break(label) => Set(label)
  case Continue(label) => Set(label)
  case Begin(sub, rest) => sub.freeVarsLLIR ++ rest.freeVarsLLIR
  case TryBlock(sub, finallyDo, rest) => sub.freeVarsLLIR ++ finallyDo.freeVarsLLIR ++ rest.freeVarsLLIR
  case Assign(lhs, rhs, rest) => Set(lhs) ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
  case AssignField(lhs, nme, rhs, rest) => lhs.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
  case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
    lhs.freeVarsLLIR ++ fld.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
  case Define(defn, rest) => defn.freeVarsLLIR ++ rest.freeVarsLLIR
  case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) =>
    // Handlers must go through `freeVarsLLIR` as well; using `freeVars` here would bring back the
    // `Call`/`Instantiate` heads that this variant deliberately leaves out.
    (bod.freeVarsLLIR - lhs) ++ rst.freeVarsLLIR ++ hdr.flatMap(_.freeVarsLLIR)
  case HandleBlockReturn(res) => res.freeVarsLLIR
  case End(msg) => Set.empty
124+
100125
lazy val subBlocks: Ls[Block] = this match
101126
case Match(p, arms, dflt, rest) => p.subBlocks ++ arms.map(_._2) ++ dflt.toList :+ rest
102127
case Begin(sub, rest) => sub :: rest :: Nil
@@ -122,15 +147,19 @@ sealed abstract class Block extends Product with AutoLocated:
122147
// Note that this returns the definitions in reverse order, with the bottommost definition appearing
123148
// last. This is so that when using defns.foldLeft later to add the definitions to the front of a block,
124149
// we don't need to reverse the list again to preserve the order of the definitions.
125-
def floatOutDefns =
150+
def floatOutDefns(
151+
ignore: Defn => Bool = _ => false,
152+
preserve: Defn => Bool = _ => false
153+
) =
126154
var defns: List[Defn] = Nil
127155
val transformer = new BlockTransformerShallow(SymbolSubst()):
128156
override def applyBlock(b: Block): Block = b match
129-
case Define(defn, rest) => defn match
157+
case Define(defn, rest) if !ignore(defn) => defn match
130158
case v: ValDefn => super.applyBlock(b)
131159
case _ =>
132160
defns ::= defn
133-
applyBlock(rest)
161+
if preserve(defn) then super.applyBlock(b)
162+
else applyBlock(rest)
134163
case _ => super.applyBlock(b)
135164

136165
(transformer.applyBlock(this), defns)
@@ -281,10 +310,20 @@ sealed abstract class Defn:
281310
lazy val freeVars: Set[Local] = this match
282311
case FunDefn(own, sym, params, body) => body.freeVars -- params.flatMap(_.paramSyms) - sym
283312
case ValDefn(owner, k, sym, rhs) => rhs.freeVars
284-
case ClsLikeDefn(own, isym, sym, k, paramsOpt, parentSym, methods, privateFields, publicFields, preCtor, ctor) =>
313+
case ClsLikeDefn(own, isym, sym, k, paramsOpt, auxParams, parentSym,
314+
methods, privateFields, publicFields, preCtor, ctor) =>
285315
preCtor.freeVars
286316
++ ctor.freeVars ++ methods.flatMap(_.freeVars)
287-
-- privateFields -- publicFields.map(_.sym)
317+
-- privateFields -- publicFields.map(_.sym) -- auxParams.flatMap(_.paramSyms)
318+
319+
// LLIR variant of `freeVars` for definitions: same shape, but built from the `freeVarsLLIR`
// of the nested blocks, results and methods.
lazy val freeVarsLLIR: Set[Local] = this match
  case FunDefn(_, sym, params, body) =>
    // Neither the parameters nor the function's own symbol count as free.
    val bound = params.flatMap(_.paramSyms)
    body.freeVarsLLIR -- bound - sym
  case ValDefn(_, _, _, rhs) =>
    rhs.freeVarsLLIR
  case ClsLikeDefn(_, _, _, _, _, auxParams, _, methods, privateFields, publicFields, preCtor, ctor) =>
    val used = preCtor.freeVarsLLIR
      ++ ctor.freeVarsLLIR ++ methods.flatMap(_.freeVarsLLIR)
    // Fields and auxiliary parameters are bound by the class itself.
    used -- privateFields -- publicFields.map(_.sym) -- auxParams.flatMap(_.paramSyms)
288327

289328
final case class FunDefn(
290329
owner: Opt[InnerSymbol],
@@ -304,10 +343,11 @@ final case class ValDefn(
304343

305344
final case class ClsLikeDefn(
306345
owner: Opt[InnerSymbol],
307-
isym: MemberSymbol[? <: ClassLikeDef],
346+
isym: MemberSymbol[? <: ClassLikeDef] & InnerSymbol,
308347
sym: BlockMemberSymbol,
309348
k: syntax.ClsLikeKind,
310349
paramsOpt: Opt[ParamList],
350+
auxParams: List[ParamList],
311351
parentPath: Opt[Path],
312352
methods: Ls[FunDefn],
313353
privateFields: Ls[TermSymbol],
@@ -325,6 +365,7 @@ final case class Handler(
325365
params: Ls[ParamList],
326366
body: Block,
327367
):
368+
lazy val freeVarsLLIR: Set[Local] = body.freeVarsLLIR -- params.flatMap(_.paramSyms) - sym - resumeSym
328369
lazy val freeVars: Set[Local] = body.freeVars -- params.flatMap(_.paramSyms) - sym - resumeSym
329370

330371
/* Represents either unreachable code (for functions that must return a result)
@@ -341,6 +382,11 @@ enum Case:
341382
case Cls(_, path) => path.freeVars
342383
case Tup(_, _) => Set.empty
343384

385+
// LLIR variant of `freeVars` for patterns: only a class pattern can mention variables,
// namely those of the path designating the class.
lazy val freeVarsLLIR: Set[Local] = this match
  case Cls(_, path) => path.freeVarsLLIR
  case Lit(_) | Tup(_, _) => Set.empty
389+
344390
sealed abstract class Result:
345391

346392
// TODO rm Lam from values and thus the need for this method
@@ -353,14 +399,26 @@ sealed abstract class Result:
353399
case _ => Nil
354400

355401
lazy val freeVars: Set[Local] = this match
356-
case Call(fun, args) => args.flatMap(_.value.freeVars).toSet
357-
case Instantiate(cls, args) => args.flatMap(_.freeVars).toSet
402+
case Call(fun, args) => fun.freeVars ++ args.flatMap(_.value.freeVars).toSet
403+
case Instantiate(cls, args) => cls.freeVars ++ args.flatMap(_.freeVars).toSet
358404
case Select(qual, name) => qual.freeVars
359405
case Value.Ref(l) => Set(l)
360406
case Value.This(sym) => Set.empty
361407
case Value.Lit(lit) => Set.empty
362408
case Value.Lam(params, body) => body.freeVars -- params.paramSyms
363409
case Value.Arr(elems) => elems.flatMap(_.value.freeVars).toSet
410+
case DynSelect(qual, fld, arrayIdx) => qual.freeVars ++ fld.freeVars
411+
412+
// LLIR variant of `freeVars` for results. Unlike `freeVars`, the callee of a `Call` and the
// class of an `Instantiate` are not included; only their arguments are traversed.
lazy val freeVarsLLIR: Set[Local] = this match
  case Value.Ref(l) => Set(l)
  case Value.This(_) | Value.Lit(_) => Set.empty
  case Value.Lam(params, body) => body.freeVarsLLIR -- params.paramSyms
  case Value.Arr(elems) => elems.flatMap(_.value.freeVarsLLIR).toSet
  case Select(qual, _) => qual.freeVarsLLIR
  case DynSelect(qual, fld, _) => qual.freeVarsLLIR ++ fld.freeVarsLLIR
  case Call(_, args) => args.flatMap(_.value.freeVarsLLIR).toSet
  case Instantiate(_, args) => args.flatMap(_.freeVarsLLIR).toSet
364422

365423
// type Local = LocalSymbol
366424
type Local = Symbol
@@ -375,6 +433,7 @@ case class Instantiate(cls: Path, args: Ls[Path]) extends Result
375433

376434
sealed abstract class Path extends Result:
377435
def selN(id: Tree.Ident): Path = Select(this, id)(N)
436+
def sel(id: Tree.Ident, sym: FieldSymbol): Path = Select(this, id)(S(sym))
378437
def selSN(id: Str): Path = selN(new Tree.Ident(id))
379438
def asArg = Arg(false, this)
380439

0 commit comments

Comments
 (0)