Skip to content

Commit bf4c763

Browse files
committed
Merge remote-tracking branch 'upstream/hkmc2' into logic-subtyping
2 parents 0298463 + fd6cc07 commit bf4c763

File tree

163 files changed

+8358
-1007
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

163 files changed

+8358
-1007
lines changed

core/shared/main/scala/utils/algorithms.scala

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mlscript.utils
33
import scala.annotation.tailrec
44
import scala.collection.immutable.SortedMap
55

6+
67
object algorithms {
78
final class CyclicGraphError(message: String) extends Exception(message)
89

@@ -29,4 +30,145 @@ object algorithms {
2930
}
3031
sort(toPred, Seq())
3132
}
33+
34+
/**
  * Partitions a directed graph into its strongly connected components (SCCs) using
  * Tarjan's algorithm. The node type is used as a hash key, so it must have a consistent
  * `equals`/`hashCode`.
  *
  * The traversal is recursive, so the recursion depth grows with the longest DFS path
  * in the graph.
  *
  * @param edges The directed edges `(from, to)` of the graph. Duplicates are allowed.
  * @param nodes Any additional nodes that are not necessarily in the edges list. (Overlap is fine)
  * @return The strongly connected components. Components are listed so that a component
  *         always appears before the components it has edges into; within a component,
  *         the node at which the DFS entered the component comes first.
  */
def partitionScc[A](edges: Iterable[(A, A)], nodes: Iterable[A]): List[List[A]] = {
  // Mutable per-node bookkeeping. A plain class (not a case class) so that equality stays
  // by reference while the fields are being mutated.
  final class SccNode(val node: A, val id: Int) {
    var num: Int = -1            // DFS discovery index
    var lowlink: Int = -1        // smallest discovery index reachable through the DFS stack
    var visited: Boolean = false
    var popped: Boolean = false  // true once the node has been assigned to a component
  }

  // pre-process: assign each distinct node an id
  val edgesSet = edges.toSet
  val nodesUniq = (edgesSet.flatMap { case (a, b) => Set(a, b) } ++ nodes.toSet).toList
  val nodesN = nodesUniq.zipWithIndex.map { case (node, idx) => new SccNode(node, idx) }
  val nodeOf = nodesN.map(n => n.node -> n).toMap

  // adjacency lists keyed by source id; nodes without outgoing edges map to Nil
  val neighbours = edges
    .map { case (a, b) => (nodeOf(a).id, nodeOf(b)) }
    .groupBy(_._1)
    .map { case (from, outs) => from -> outs.map(_._2) }
    .withDefault(_ => Nil)

  // Tarjan's algorithm

  var stack: List[SccNode] = Nil
  var sccs: List[List[A]] = Nil
  var counter = 0

  def dfs(node: SccNode): Unit = {
    node.num = counter
    node.lowlink = counter
    node.visited = true
    stack = node :: stack
    counter += 1

    for (n <- neighbours(node.id)) {
      if (!n.visited) {
        dfs(n)
        node.lowlink = node.lowlink.min(n.lowlink)
      } else if (!n.popped) {
        // visited but not yet assigned to a component, i.e. still on the stack
        node.lowlink = node.lowlink.min(n.num)
      }
    }

    if (node.lowlink == node.num) {
      // `node` is the root of a component: everything above it on the stack belongs to it
      val (above, fromRoot) = stack.span(_ ne node)
      val members = node :: above.reverse
      members.foreach(_.popped = true)
      stack = fromRoot.drop(1)
      sccs = members.map(_.node) :: sccs
    }
  }

  for (n <- nodesN) {
    if (!n.visited) dfs(n)
  }
  sccs
}
110+
111+
112+
/**
  * Description of a graph after it has been collapsed into its strongly connected
  * components, as produced by `sccsWithInfo`. Each component is identified by an
  * integer id.
  *
  * @param sccs    The nodes of each component, keyed by component id.
  * @param edges   For each component id, the ids of the components it has edges into.
  *                Together with `sccs`, this forms an acyclic graph.
  * @param inDegs  Number of incoming cross-component edges per component id.
  * @param outDegs Number of outgoing cross-component edges per component id.
  */
case class SccsInfo[A](
  sccs: Map[Int, List[A]],
  edges: Map[Int, Iterable[Int]],
  inDegs: Map[Int, Int],
  outDegs: Map[Int, Int],
)
127+
128+
/**
  * Computes the strongly connected components of a graph (via `partitionScc`) and summarises
  * the condensed graph formed by those components. The node type is used as a hash key.
  *
  * Only edges between two different components are reported. `edges` has an entry for every
  * component (empty when it has no outgoing edge), whereas `inDegs` and `outDegs` only contain
  * components that have at least one incoming or outgoing cross-component edge, respectively.
  *
  * @param edges The edges of the graph.
  * @param nodes Any additional nodes that are not necessarily in the edges list. (Overlap is fine)
  * @return The partitioned graph and info about it.
  */
def sccsWithInfo[A](edges: Iterable[(A, A)], nodes: Iterable[A]): SccsInfo[A] = {
  // number the components in the order returned by partitionScc
  val componentById = partitionScc(edges, nodes).zipWithIndex.map { case (scc, id) => id -> scc }.toMap

  // node -> id of the component containing it
  val componentOf = componentById.flatMap { case (id, scc) => scc.map(_ -> id) }

  // original edges lifted to component ids, keeping only those that cross components
  val crossEdges = edges
    .map { case (from, to) => (componentOf(from), componentOf(to)) }
    .filter { case (from, to) => from != to }

  val outgoing = crossEdges.groupBy(_._1)
  val incoming = crossEdges.groupBy(_._2)

  // every component gets an (initially empty) entry, overridden where outgoing edges exist
  val noEdges = componentById.map { case (id, _) => id -> Nil }
  val targets = outgoing.map { case (id, es) => id -> es.map(_._2) }.toMap
  val componentEdges = noEdges ++ targets

  val inDegrees = incoming.map { case (id, es) => id -> es.size }
  val outDegrees = outgoing.map { case (id, es) => id -> es.size }

  SccsInfo(componentById, componentEdges, inDegrees, outDegrees)
}
32174
}

core/shared/main/scala/utils/package.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,8 @@ package object utils {
228228
def TODO(msg: Any, cond: Bool): Unit = if (cond) TODO(msg)
229229
def die: Nothing = lastWords("Program reached an unexpected state.")
230230
def lastWords(msg: String): Nothing = throw new Exception(s"Internal Error: $msg")
231-
def wat(msg: String, wat: Any): Nothing = lastWords(s"$msg ($wat)")
231+
def wat(msg: String, obj: Any): Nothing = lastWords(s"$msg ($obj)")
232+
def wat(obj: Any): Nothing = wat(s"unexpected value", obj)
232233

233234
/** To make Scala unexhaustivity warnings believed to be spurious go away,
234235
* while clearly indicating the intent. */

hkmc2/shared/src/main/scala/hkmc2/Config.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def config(using Config): Config = summon
1111
case class Config(
1212
sanityChecks: Opt[SanityChecks],
1313
effectHandlers: Opt[EffectHandlers],
14+
liftDefns: Opt[LiftDefns],
1415
):
1516

1617
def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety)
@@ -24,6 +25,7 @@ object Config:
2425
sanityChecks = N, // TODO make the default S
2526
// sanityChecks = S(SanityChecks(light = true)),
2627
effectHandlers = N,
28+
liftDefns = N,
2729
)
2830

2931
case class SanityChecks(light: Bool)
@@ -35,6 +37,8 @@ object Config:
3537
val default: StackSafety = StackSafety(
3638
stackLimit = 500,
3739
)
40+
41+
case class LiftDefns() // there may be other settings in the future, having it as a case class now
3842

3943
end Config
4044

hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
247247
val cr = freshVar(new TempSymbol(S(unq), "ctx"))
248248
constrain(tryMkMono(ty, body), BbCtx.codeTy(tv, cr))
249249
(tv, cr, eff)
250-
case blk @ Term.Blk(LetDecl(sym, _) :: DefineVar(sym2, rhs) :: Nil, body) if sym2 is sym => // TODO: more than one!!
250+
case blk @ Term.Blk(LetDecl(sym, _) :: DefineVar(sym2, rhs) :: Nil, body)
251+
if sym2 is sym => // TODO: more than one!!
251252
val (rhsTy, rhsCtx, rhsEff) = typeCode(rhs)(using ctx)
252253
val nestCtx = ctx.nextLevel
253254
given BbCtx = nestCtx

hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala

Lines changed: 78 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ case class Program(
2222

2323
sealed abstract class Block extends Product with AutoLocated:
2424

25+
def ~(that: Block): Block = Begin(this, that)
26+
2527
protected def children: Ls[Located] = ??? // Maybe extending AutoLocated is unnecessary
2628

2729
lazy val definedVars: Set[Local] = this match
@@ -91,12 +93,37 @@ sealed abstract class Block extends Product with AutoLocated:
9193
case TryBlock(sub, finallyDo, rest) => sub.freeVars ++ finallyDo.freeVars ++ rest.freeVars
9294
case Assign(lhs, rhs, rest) => Set(lhs) ++ rhs.freeVars ++ rest.freeVars
9395
case AssignField(lhs, nme, rhs, rest) => lhs.freeVars ++ rhs.freeVars ++ rest.freeVars
96+
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVars ++ fld.freeVars ++ rhs.freeVars ++ rest.freeVars
9497
case Define(defn, rest) => defn.freeVars ++ rest.freeVars
9598
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) =>
9699
(bod.freeVars - lhs) ++ rst.freeVars ++ hdr.flatMap(_.freeVars)
97100
case HandleBlockReturn(res) => res.freeVars
98101
case End(msg) => Set.empty
99102

103+
// TODO: freeVarsLLIR skips `fun` and `cls` in `Call` and `Instantiate` respectively, which is needed in some
104+
// other places. However, adding them breaks some LLIR tests. Supposedly, once the IR uses the new symbol system,
105+
// this should no longer happen. This version should be removed once that is resolved.
106+
// Free variables of this block as computed for the LLIR pipeline: same traversal as
// `freeVars`, but delegating to the `freeVarsLLIR` variant of every sub-term.
lazy val freeVarsLLIR: Set[Local] = this match
  case Match(scrut, arms, dflt, rest) =>
    scrut.freeVarsLLIR ++ dflt.toList.flatMap(_.freeVarsLLIR) ++ rest.freeVarsLLIR
      ++ arms.flatMap:
        (pat, arm) => arm.freeVarsLLIR -- pat.freeVars
  case Return(res, implct) => res.freeVarsLLIR
  case Throw(exc) => exc.freeVarsLLIR
  case Label(label, body, rest) => (body.freeVarsLLIR - label) ++ rest.freeVarsLLIR
  case Break(label) => Set(label)
  case Continue(label) => Set(label)
  case Begin(sub, rest) => sub.freeVarsLLIR ++ rest.freeVarsLLIR
  case TryBlock(sub, finallyDo, rest) => sub.freeVarsLLIR ++ finallyDo.freeVarsLLIR ++ rest.freeVarsLLIR
  case Assign(lhs, rhs, rest) => Set(lhs) ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
  case AssignField(lhs, nme, rhs, rest) => lhs.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
  case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVarsLLIR ++ fld.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR
  case Define(defn, rest) => defn.freeVarsLLIR ++ rest.freeVarsLLIR
  case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) =>
    // Use the LLIR variant for handlers too, so the whole traversal stays within freeVarsLLIR.
    (bod.freeVarsLLIR - lhs) ++ rst.freeVarsLLIR ++ hdr.flatMap(_.freeVarsLLIR)
  case HandleBlockReturn(res) => res.freeVarsLLIR
  case End(msg) => Set.empty
126+
100127
lazy val subBlocks: Ls[Block] = this match
101128
case Match(p, arms, dflt, rest) => p.subBlocks ++ arms.map(_._2) ++ dflt.toList :+ rest
102129
case Begin(sub, rest) => sub :: rest :: Nil
@@ -122,15 +149,19 @@ sealed abstract class Block extends Product with AutoLocated:
122149
// Note that this returns the definitions in reverse order, with the bottommost definition appearing
123150
// last. This is so that using defns.foldLeft later to add the definitions to the front of a block,
124151
// we don't need to reverse the list again to preserve the order of the definitions.
125-
def floatOutDefns =
152+
def floatOutDefns(
153+
ignore: Defn => Bool = _ => false,
154+
preserve: Defn => Bool = _ => false
155+
) =
126156
var defns: List[Defn] = Nil
127157
val transformer = new BlockTransformerShallow(SymbolSubst()):
128158
override def applyBlock(b: Block): Block = b match
129-
case Define(defn, rest) => defn match
159+
case Define(defn, rest) if !ignore(defn) => defn match
130160
case v: ValDefn => super.applyBlock(b)
131161
case _ =>
132162
defns ::= defn
133-
applyBlock(rest)
163+
if preserve(defn) then super.applyBlock(b)
164+
else applyBlock(rest)
134165
case _ => super.applyBlock(b)
135166

136167
(transformer.applyBlock(this), defns)
@@ -281,10 +312,20 @@ sealed abstract class Defn:
281312
lazy val freeVars: Set[Local] = this match
282313
case FunDefn(own, sym, params, body) => body.freeVars -- params.flatMap(_.paramSyms) - sym
283314
case ValDefn(owner, k, sym, rhs) => rhs.freeVars
284-
case ClsLikeDefn(own, isym, sym, k, paramsOpt, parentSym, methods, privateFields, publicFields, preCtor, ctor) =>
315+
case ClsLikeDefn(own, isym, sym, k, paramsOpt, auxParams, parentSym,
316+
methods, privateFields, publicFields, preCtor, ctor) =>
285317
preCtor.freeVars
286318
++ ctor.freeVars ++ methods.flatMap(_.freeVars)
287-
-- privateFields -- publicFields.map(_.sym)
319+
-- privateFields -- publicFields.map(_.sym) -- auxParams.flatMap(_.paramSyms)
320+
321+
// LLIR variant of `freeVars`: symbols used by the definition's body minus those it binds itself.
lazy val freeVarsLLIR: Set[Local] = this match
  case FunDefn(_, sym, params, body) =>
    // parameters and the function's own symbol are bound inside the body
    val bound = params.flatMap(_.paramSyms)
    body.freeVarsLLIR -- bound - sym
  case ValDefn(_, _, _, rhs) =>
    rhs.freeVarsLLIR
  case ClsLikeDefn(_, _, _, _, _, auxParams, _,
      methods, privateFields, publicFields, preCtor, ctor) =>
    val used = preCtor.freeVarsLLIR ++ ctor.freeVarsLLIR ++ methods.flatMap(_.freeVarsLLIR)
    // fields and auxiliary parameters are bound by the class itself
    used -- privateFields -- publicFields.map(_.sym) -- auxParams.flatMap(_.paramSyms)
288329

289330
final case class FunDefn(
290331
owner: Opt[InnerSymbol],
@@ -304,10 +345,11 @@ final case class ValDefn(
304345

305346
final case class ClsLikeDefn(
306347
owner: Opt[InnerSymbol],
307-
isym: MemberSymbol[? <: ClassLikeDef],
348+
isym: MemberSymbol[? <: ClassLikeDef] & InnerSymbol,
308349
sym: BlockMemberSymbol,
309350
k: syntax.ClsLikeKind,
310351
paramsOpt: Opt[ParamList],
352+
auxParams: List[ParamList],
311353
parentPath: Opt[Path],
312354
methods: Ls[FunDefn],
313355
privateFields: Ls[TermSymbol],
@@ -325,6 +367,7 @@ final case class Handler(
325367
params: Ls[ParamList],
326368
body: Block,
327369
):
370+
lazy val freeVarsLLIR: Set[Local] = body.freeVarsLLIR -- params.flatMap(_.paramSyms) - sym - resumeSym
328371
lazy val freeVars: Set[Local] = body.freeVars -- params.flatMap(_.paramSyms) - sym - resumeSym
329372

330373
/* Represents either unreachable code (for functions that must return a result)
@@ -340,6 +383,13 @@ enum Case:
340383
case Lit(_) => Set.empty
341384
case Cls(_, path) => path.freeVars
342385
case Tup(_, _) => Set.empty
386+
387+
// LLIR variant of `freeVars`: only class patterns reference anything, through their class path.
lazy val freeVarsLLIR: Set[Local] = this match
  case Cls(_, path) => path.freeVarsLLIR
  case Lit(_) | Tup(_, _) => Set.empty
391+
392+
sealed trait TrivialResult extends Result
343393

344394
sealed abstract class Result:
345395

@@ -353,14 +403,26 @@ sealed abstract class Result:
353403
case _ => Nil
354404

355405
lazy val freeVars: Set[Local] = this match
356-
case Call(fun, args) => args.flatMap(_.value.freeVars).toSet
357-
case Instantiate(cls, args) => args.flatMap(_.freeVars).toSet
406+
case Call(fun, args) => fun.freeVars ++ args.flatMap(_.value.freeVars).toSet
407+
case Instantiate(cls, args) => cls.freeVars ++ args.flatMap(_.freeVars).toSet
358408
case Select(qual, name) => qual.freeVars
359409
case Value.Ref(l) => Set(l)
360410
case Value.This(sym) => Set.empty
361411
case Value.Lit(lit) => Set.empty
362412
case Value.Lam(params, body) => body.freeVars -- params.paramSyms
363413
case Value.Arr(elems) => elems.flatMap(_.value.freeVars).toSet
414+
case DynSelect(qual, fld, arrayIdx) => qual.freeVars ++ fld.freeVars
415+
416+
// LLIR variant of `freeVars`. Unlike `freeVars`, the callee of a `Call` and the class of an
// `Instantiate` are deliberately not counted.
lazy val freeVarsLLIR: Set[Local] = this match
  case Call(fun, args) => args.flatMap(_.value.freeVarsLLIR).toSet
  case Instantiate(cls, args) => args.flatMap(_.freeVarsLLIR).toSet
  case Select(qual, name) => qual.freeVarsLLIR
  case Value.Ref(l) => Set(l)
  case Value.This(sym) => Set.empty
  case Value.Lit(lit) => Set.empty
  case Value.Lam(params, body) => body.freeVarsLLIR -- params.paramSyms
  case Value.Arr(elems) => elems.flatMap(_.value.freeVarsLLIR).toSet
  // Records: both the (optional) key path and the value path may reference variables.
  // Without this case a record value would fall through the match and raise a MatchError.
  case Value.Rcd(elems) =>
    elems.flatMap(arg => arg.idx.toList.flatMap(_.freeVarsLLIR) ++ arg.value.freeVarsLLIR).toSet
  case DynSelect(qual, fld, arrayIdx) => qual.freeVarsLLIR ++ fld.freeVarsLLIR
364426

365427
// type Local = LocalSymbol
366428
type Local = Symbol
@@ -373,8 +435,9 @@ case class Call(fun: Path, args: Ls[Arg])(val isMlsFun: Bool, val mayRaiseEffect
373435

374436
case class Instantiate(cls: Path, args: Ls[Path]) extends Result
375437

376-
sealed abstract class Path extends Result:
438+
sealed abstract class Path extends TrivialResult:
377439
def selN(id: Tree.Ident): Path = Select(this, id)(N)
440+
def sel(id: Tree.Ident, sym: FieldSymbol): Path = Select(this, id)(S(sym))
378441
def selSN(id: Str): Path = selN(new Tree.Ident(id))
379442
def asArg = Arg(false, this)
380443

@@ -389,9 +452,15 @@ enum Value extends Path:
389452
case Lit(lit: Literal)
390453
case Lam(params: ParamList, body: Block)
391454
case Arr(elems: Ls[Arg])
455+
case Rcd(elems: Ls[RcdArg])
392456

393457
case class Arg(spread: Bool, value: Path)
394458

459+
// * `RcdArg(S(idx), value)` represents a key-value pair in a record `(idx): value`
460+
// * `RcdArg(N, value)` represents a spread element in a record `...value`
461+
case class RcdArg(idx: Opt[Path], value: Path):
462+
def spread: Bool = idx.isEmpty
463+
395464
extension (k: Block => Block)
396465

397466
def chain(other: Block => Block): Block => Block = b => k(other(b))

0 commit comments

Comments
 (0)