Skip to content

Commit f151d58

Browse files
authored
Move UCS normalization before lowering (#308)
1 parent 55e4008 commit f151d58

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+1574
-1654
lines changed

hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -295,19 +295,24 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
295295
val nestCtx1 = ctx.nest
296296
val nestCtx2 = ctx.nest
297297
val patTy = pattern match
298-
case Pattern.ClassLike(sym, _, _, _) =>
299-
val (clsTy, tv, emptyTy) = sym.asCls.flatMap(_.defn) match
300-
case S(cls) =>
301-
(ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(sym))), (freshVar(new TempSymbol(S(scrutinee), "scrut"))), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty)))
302-
case _ =>
303-
error(msg"Cannot match ${scrutinee.toString} as ${sym.toString}" -> split.toLoc :: Nil)
304-
(Bot, Bot, Bot)
305-
scrutinee match // * refine
306-
case Ref(sym: LocalSymbol) =>
307-
nestCtx1 += sym -> clsTy
308-
nestCtx2 += sym -> tv
309-
case _ => () // TODO: refine all variables holding this value?
310-
clsTy | (tv & Type.mkNegType(emptyTy))
298+
case pat: Pattern.ClassLike =>
299+
pat.constructor.symbol.flatMap(_.asCls) match
300+
case S(sym) =>
301+
val (clsTy, tv, emptyTy) = sym.defn.map(sym -> _) match
302+
case S((sym, cls)) =>
303+
(ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(sym))), (freshVar(new TempSymbol(S(scrutinee), "scrut"))), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty)))
304+
case _ =>
305+
error(msg"Cannot match ${scrutinee.toString} as ${sym.toString}" -> split.toLoc :: Nil)
306+
(Bot, Bot, Bot)
307+
scrutinee match // * refine
308+
case Ref(sym: LocalSymbol) =>
309+
nestCtx1 += sym -> clsTy
310+
nestCtx2 += sym -> tv
311+
case _ => () // TODO: refine all variables holding this value?
312+
clsTy | (tv & Type.mkNegType(emptyTy))
313+
case N =>
314+
error(msg"Not a valid class: ${pat.constructor.describe}" -> pat.constructor.toLoc :: Nil)
315+
Bot
311316
case Pattern.Lit(lit) => lit match
312317
case _: Tree.BoolLit => BbCtx.boolTy
313318
case _: Tree.IntLit => BbCtx.intTy

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
220220
args(fs)(args => k(Value.Arr(args)))
221221
case ref @ st.Ref(sym) =>
222222
sym match
223-
case ctx.builtins.source.bms | ctx.builtins.js.bms | ctx.builtins.debug.bms =>
223+
case ctx.builtins.source.bms | ctx.builtins.js.bms | ctx.builtins.debug.bms | ctx.builtins.annotations.bms =>
224224
raise:
225225
ErrorReport(
226226
msg"Module '${sym.nme}' is virtual (i.e., \"compiler fiction\"); cannot be used directly" -> t.toLoc ::
@@ -466,29 +466,36 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
466466
)
467467
pat match
468468
case Pattern.Lit(lit) => mkMatch(Case.Lit(lit) -> go(tail, topLevel = false))
469-
case Pattern.ClassLike(cls: ClassSymbol, _trm, _args0, _refined)
470-
// Do not elaborate `_trm` when the `cls` is virtual.
471-
if Elaborator.ctx.builtins.virtualClasses contains cls =>
472-
// [invariant:0] Some classes (e.g., `Int`) from `Prelude` do
473-
// not exist at runtime. If we do lowering on `trm`, backends
474-
// (e.g., `JSBuilder`) will not be able to handle the corresponding selections.
475-
// In this case the second parameter of `Case.Cls` will not be used.
476-
// So we make it `Predef.unreachable` here.
477-
mkMatch(Case.Cls(cls, unreachableFn) -> go(tail, topLevel = false))
478-
case Pattern.ClassLike(cls, trm, args0, _refined) =>
479-
subTerm_nonTail(trm): st =>
480-
val args = args0.getOrElse(Nil)
481-
val clsParams = cls match
482-
case cls: ClassSymbol => cls.tree.clsParams
483-
case _: ModuleSymbol => Nil
484-
assert(args0.isEmpty || clsParams.length === args.length)
469+
case Pattern.ClassLike(ctor, argsOpt, _mode, _refined) =>
470+
/** Make a continuation that creates the match. */
471+
def k(ctorSym: ClassLikeSymbol, clsParams: Ls[TermSymbol])(st: Path): Block =
472+
val args = argsOpt.map(_.map(_.scrutinee)).getOrElse(Nil)
473+
// Normalization should reject cases where the user provides
474+
// more sub-patterns than there are actual class parameters.
475+
assert(argsOpt.isEmpty || args.length <= clsParams.length)
485476
def mkArgs(args: Ls[TermSymbol -> BlockLocalSymbol])(using Subst): Case -> Block = args match
486477
case Nil =>
487-
Case.Cls(cls, st) -> go(tail, topLevel = false)
478+
Case.Cls(ctorSym, st) -> go(tail, topLevel = false)
488479
case (param, arg) :: args =>
489480
val (cse, blk) = mkArgs(args)
490481
(cse, Assign(arg, Select(sr, param.id/*FIXME incorrect Ident?*/)(S(param)), blk))
491-
mkMatch(mkArgs(clsParams.iterator.zip(args).collect { case (s1, S(s2)) => (s1, s2) }.toList))
482+
mkMatch(mkArgs(clsParams.iterator.zip(args).toList))
483+
ctor.symbol.flatMap(_.asClsOrMod) match
484+
case S(cls: ClassSymbol) if ctx.builtins.virtualClasses contains cls =>
485+
// [invariant:0] Some classes (e.g., `Int`) from `Prelude` do
486+
            // not exist at runtime. If we do lowering on `ctor`, backends
487+
// (e.g., `JSBuilder`) will not be able to handle the corresponding selections.
488+
// In this case the second parameter of `Case.Cls` will not be used.
489+
// So we do not elaborate `ctor` when the `cls` is virtual
490+
            // and use `Predef.unreachable` in its place here.
491+
k(cls, Nil)(unreachableFn)
492+
case S(cls: ClassSymbol) => subTerm_nonTail(ctor)(k(cls, cls.tree.clsParams))
493+
case S(mod: ModuleSymbol) => subTerm_nonTail(ctor)(k(mod, Nil))
494+
case N =>
495+
            // Normalization has already checked that the constructor
496+
// resolves to a class or module. Branches with unresolved
497+
// constructors should have been removed.
498+
lastWords("Pattern.ClassLike: constructor is neither a class nor a module")
492499
case Pattern.Tuple(len, inf) => mkMatch(Case.Tup(len, inf) -> go(tail, topLevel = false))
493500
case Pattern.Record(entries) =>
494501
val objectSym = ctx.builtins.Object
@@ -513,11 +520,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
513520
Throw(Instantiate(Select(Value.Ref(State.globalThisSymbol), Tree.Ident("Error"))(N),
514521
Value.Lit(syntax.Tree.StrLit("match error")) :: Nil)) // TODO add failed-match scrutinee info
515522

516-
if k.isInstanceOf[TailOp] && isIf then go(iftrm.normalized, topLevel = true)
523+
val normalize = ucs.Normalization()
524+
val normalized = tl.scoped("ucs:normalize"):
525+
normalize(iftrm.desugared)
526+
tl.scoped("ucs:normalized"):
527+
tl.log(s"Normalized:\n${Split.display(normalized)}")
528+
529+
if k.isInstanceOf[TailOp] && isIf then go(normalized, topLevel = true)
517530
else
518531
val body = if isWhile
519-
then Label(lbl, go(iftrm.normalized, topLevel = true), End())
520-
else go(iftrm.normalized, topLevel = true)
532+
then Label(lbl, go(normalized, topLevel = true), End())
533+
else go(normalized, topLevel = true)
521534
Begin(
522535
body,
523536
if usesResTmp then k(Value.Ref(l))

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ object Elaborator:
156156
object debug extends VirtualModule(assumeBuiltinMod("debug")):
157157
val printStack = assumeObject("printStack")
158158
val getLocals = assumeObject("getLocals")
159+
object annotations extends VirtualModule(assumeBuiltinMod("annotations")):
160+
val compile = assumeObject("compile")
159161
def getBuiltinOp(op: Str): Opt[Str] =
160162
if getBuiltin(op).isDefined then builtinBinOps.get(op) else N
161163
/** Classes that do not use `instanceof` in pattern matching. */
@@ -210,10 +212,22 @@ object Elaborator:
210212
BlockMemberSymbol(id.name, Nil, true)
211213
val matchResultClsSymbol =
212214
val id = new Ident("MatchResult")
213-
ClassSymbol(TypeDef(syntax.Cls, App(id, Tup(Ident("captures") :: Nil)), N, N), id)
215+
val td = TypeDef(syntax.Cls, App(id, Tup(Ident("captures") :: Nil)), N, N)
216+
val cs = ClassSymbol(td, id)
217+
val flag = FldFlags.empty.copy(value = true)
218+
val ps = PlainParamList(Param(flag, VarSymbol(Ident("captures")), N, Modulefulness(N)(false)) :: Nil)
219+
cs.defn = S(ClassDef.Parameterized(N, syntax.Cls, cs, BlockMemberSymbol(cs.name, Nil),
220+
Nil, ps, N, ObjBody(Blk(Nil, Term.Lit(UnitLit(false)))), N, Nil))
221+
cs
214222
val matchFailureClsSymbol =
215223
val id = new Ident("MatchFailure")
216-
ClassSymbol(TypeDef(syntax.Cls, App(id, Tup(Ident("errors") :: Nil)), N, N), id)
224+
val td = TypeDef(syntax.Cls, App(id, Tup(Ident("errors") :: Nil)), N, N)
225+
val cs = ClassSymbol(td, id)
226+
val flag = FldFlags.empty.copy(value = true)
227+
val ps = PlainParamList(Param(flag, VarSymbol(Ident("errors")), N, Modulefulness(N)(false)) :: Nil)
228+
cs.defn = S(ClassDef.Parameterized(N, syntax.Cls, cs, BlockMemberSymbol(cs.name, Nil),
229+
Nil, ps, N, ObjBody(Blk(Nil, Term.Lit(UnitLit(false)))), N, Nil))
230+
cs
217231
val builtinOpsMap =
218232
val baseBuiltins = builtins.map: op =>
219233
op -> BuiltinSymbol(op,
@@ -263,14 +277,6 @@ extends Importer:
263277
N
264278
case _ => N
265279

266-
/** To perform a reverse lookup for a term that references a symbol in the current context. */
267-
def reference(target: ClassSymbol | ModuleSymbol): Ctxl[Opt[Term]] =
268-
def go(ctx: Ctx): Opt[Term] =
269-
ctx.env.values.collectFirst:
270-
case elem if elem.symbol.flatMap(_.asClsLike).contains(target) => elem.ref(target.id)
271-
.orElse(ctx.parent.flatMap(go))
272-
go(ctx).map(Term.SynthSel(_, Ident("class"))(S(target)))
273-
274280
def cls(trm: Term, inAppPrefix: Bool)
275281
: Ctxl[Term]
276282
= trace[Term](s"Elab class ${trm}", r => s"~> $r"):
@@ -467,10 +473,7 @@ extends Importer:
467473
val des = new ucs.Desugarer(this)(tree)
468474
scoped("ucs:desugared"):
469475
log(s"Desugared:\n${Split.display(des)}")
470-
val nor = new ucs.Normalization(this)(des)
471-
scoped("ucs:normalized"):
472-
log(s"Normalized:\n${Split.display(nor)}")
473-
Term.IfLike(Keyword.`if`, des)(nor)
476+
Term.IfLike(Keyword.`if`, des)
474477
case InfixApp(lhs, Keyword.`then`, rhs) =>
475478
raise:
476479
ErrorReport(msg"Unexpected infix use of 'then' keyword here" -> tree.toLoc :: Nil)
@@ -597,23 +600,17 @@ extends Importer:
597600
val desugared = new ucs.Desugarer(this)(tree)
598601
scoped("ucs:desugared"):
599602
log(s"Desugared:\n${Split.display(desugared)}")
600-
val normalized = new ucs.Normalization(this)(desugared)
601-
scoped("ucs:normalized"):
602-
log(s"Normalized:\n${Split.display(normalized)}")
603-
Term.IfLike(kw, desugared)(normalized)
603+
Term.IfLike(kw, desugared)
604604
case Tree.Quoted(body) => Term.Quoted(subterm(body))
605605
case Tree.Unquoted(body) => Term.Unquoted(subterm(body))
606606
case tree @ Tree.Case(_, branches) =>
607607
val scrut = VarSymbol(Ident("caseScrut"))
608608
val des = new ucs.Desugarer(this)(tree, scrut)
609609
scoped("ucs:desugared"):
610610
log(s"Desugared:\n${Split.display(des)}")
611-
val nor = new ucs.Normalization(this)(des)
612-
scoped("ucs:normalized"):
613-
log(s"Normalized:\n${Split.display(nor)}")
614611
Term.Lam(PlainParamList(
615612
Param(FldFlags.empty, scrut, N, Modulefulness.none) :: Nil
616-
), Term.IfLike(Keyword.`if`, des)(nor))
613+
), Term.IfLike(Keyword.`if`, des))
617614
case Modified(Keyword.`return`, kwLoc, body) =>
618615
ctx.getRetHandler match
619616
case ReturnHandler.Required(sym) =>

hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@ package semantics
33

44
import mlscript.utils.*, shorthands.*
55
import syntax.*, Tree.Ident
6+
import Elaborator.{Ctx, ctx}
67
import ucs.DeBrujinSplit
78

9+
import Pattern.*
10+
811
/** Flat patterns for pattern matching */
912
enum Pattern extends AutoLocated:
1013

@@ -13,37 +16,84 @@ enum Pattern extends AutoLocated:
1316
/** An individual argument is None when it is not matched, i.e. when an underscore is used there.
1417
* The whole argument list is None when no argument list is being matched at all, as in `x is Some then ...`. */
1518
case ClassLike(
16-
sym: ClassSymbol | ModuleSymbol,
17-
trm: Term,
18-
args: Opt[List[Opt[BlockLocalSymbol]]],
19-
var refined: Bool,
19+
val constructor: Term,
20+
val arguments: Opt[Ls[Argument]],
21+
val mode: MatchMode,
22+
var refined: Bool
2023
)(val tree: Tree)
2124

22-
case Synonym(symbol: PatternSymbol, patternArguments: Ls[(split: DeBrujinSplit, tree: Tree)])
23-
2425
case Tuple(size: Int, inf: Bool)
2526

2627
case Record(entries: List[(Ident -> BlockLocalSymbol)])
28+
2729

2830
def subTerms: Ls[Term] = this match
29-
case ClassLike(_, t, _, _) => t :: Nil
30-
case _: (Lit | Synonym | Tuple | Record) => Nil
31+
case p: ClassLike => p.constructor :: (p.mode match
32+
case MatchMode.Default | _: MatchMode.StringPrefix => Nil
33+
case MatchMode.Annotated(annotation) => annotation :: Nil)
34+
case _: (Lit | Tuple | Record) => Nil
3135

3236
def children: Ls[Located] = this match
3337
case Lit(literal) => literal :: Nil
34-
case ClassLike(_, t, args, _) =>
35-
t :: args.fold(Nil)(_.collect { case S(symbol) => symbol })
36-
case Synonym(_, arguments) => arguments.map(_.tree)
38+
case ClassLike(ctor, scruts, _, _) => ctor :: scruts.fold(Nil)(_.map(_.scrutinee))
3739
case Tuple(fields, _) => Nil
3840
case Record(entries) => entries.flatMap { case (nme, als) => nme :: als :: Nil }
3941

4042
def showDbg: Str = this match
4143
case Lit(literal) => literal.idStr
42-
case ClassLike(sym, t, ps, rfd) => (if rfd then "refined " else "") +
43-
sym.nme + ps.fold("")(_.iterator.map(_.fold("_")(_.toString)).mkString("(", ", ", ")"))
44-
case Synonym(symbol, arguments) =>
45-
symbol.nme + arguments.iterator.map(_.tree.showDbg).mkString("(", ", ", ")")
44+
case ClassLike(ctor, args, _, rfd) =>
45+
def showCtor(ctor: Term): Str = ctor match
46+
// This prints the symbol name without `refNum` and "member:" prefix.
47+
case Term.Ref(sym: BlockMemberSymbol) => sym.nme
48+
// This prints the symbol without `refNum`.
49+
case Term.Ref(sym) => sym.toString
50+
case Term.Sel(p, i) => s"${showCtor(p)}.${i.name}"
51+
case Term.SynthSel(p, i) => s"${showCtor(p)}.${i.name}"
52+
case _ => ctor.showDbg
53+
(if rfd then "refined " else "") + showCtor(ctor) +
54+
args.fold("")(_.iterator.map(_.scrutinee.nme).mkString("(", ", ", ")"))
4655
case Tuple(size, inf) => "[]" + (if inf then ">=" else "=") + size
4756
case Record(Nil) => "{}"
4857
case Record(entries) =>
4958
entries.iterator.map(_.name + ": " + _).mkString("{ ", ", ", " }")
59+
60+
object Pattern:
61+
/** Represent the type of arguments in `ClassLike` patterns. This type alias
62+
* is used to reduce repetition in the code.
63+
*
64+
* - Field `pattern` is for error messages.
65+
* - Field `split` is for pattern compilation.
66+
   *   **TODO(ucs/rp)**: Replace with a suitable representation when implementing
67+
* the new pattern compilation.
68+
*/
69+
type Argument = (scrutinee: BlockLocalSymbol, pattern: Tree, split: Opt[DeBrujinSplit])
70+
71+
/** A class-like pattern whose symbol is resolved to a class. */
72+
object Class:
73+
def unapply(p: Pattern): Opt[ClassSymbol] = p match
74+
case p: Pattern.ClassLike => p.constructor.symbol.flatMap(_.asCls)
75+
case _ => N
76+
77+
/** A class-like pattern whose symbol is resolved to a module. */
78+
object Module:
79+
def unapply(p: Pattern): Opt[ModuleSymbol] = p match
80+
case p: Pattern.ClassLike => p.constructor.symbol.flatMap(_.asModOrObj)
81+
case _ => N
82+
83+
enum MatchMode:
84+
/** The default mode. If the constructor resolves to:
85+
* - a `ClassSymbol`, then check if the scrutinee is an instance;
86+
* - a `ModuleSymbol`, then check if the scrutinee is the object;
87+
* - a `PatternSymbol`, then call `unapply` on the pattern.
88+
*/
89+
case Default
90+
/** Call `unapplyStringPrefix` instead of `unapply`. */
91+
case StringPrefix(prefix: TempSymbol, postfix: TempSymbol)
92+
    /** The pattern is annotated. The normalization will interpret the pattern
93+
* matching behavior based on the resolved symbol
94+
*/
95+
case Annotated(annotation: Term)
96+
97+
object ClassLike:
98+
def apply(constructor: Term, arguments: Opt[Ls[BlockLocalSymbol]]): ClassLike =
99+
ClassLike(constructor, arguments.map(_.map(s => (s, Tree.Dummy, N))), MatchMode.Default, false)(Tree.Dummy)

hkmc2/shared/src/main/scala/hkmc2/semantics/Resolver.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ class Resolver(tl: TraceLogger)
289289
case Split.Else(default) =>
290290
traverse(default, expect = Any)
291291
case Split.End =>
292-
split(t.normalized)
292+
split(t.desugared)
293293

294294
case Term.New(cls, args, rft) =>
295295
traverse(cls, expect = Any)

hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ abstract class Symbol(using State) extends Located:
4040
case _ => N
4141
def asMod: Opt[ModuleSymbol] = asModOrObj.filter(_.tree.k is Mod)
4242
def asObj: Opt[ModuleSymbol] = asModOrObj.filter(_.tree.k is Obj)
43+
def asClsOrMod: Opt[ClassSymbol | ModuleSymbol] = asCls orElse asModOrObj
4344
/*
4445
def asTrm: Opt[TermSymbol] = this match
4546
case trm: TermSymbol => S(trm)
@@ -224,6 +225,7 @@ case class ErrorSymbol(val nme: Str, tree: Tree)(using State) extends MemberSymb
224225

225226
sealed trait ClassLikeSymbol extends Symbol:
226227
self: MemberSymbol[? <: ClassDef | ModuleDef] =>
228+
val tree: Tree.TypeDef
227229
def subst(using sub: SymbolSubst): ClassLikeSymbol
228230

229231

hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ sealed trait ResolvableImpl:
6464
case S(td: ClassLikeDef) => S(td)
6565
case _ => N
6666

67-
def withIArgs(iargsLs: Ls[Term.Tup]): Term =
67+
def withIArgs(iargsLs: Ls[Term.Tup]): this.type =
6868
if !(this.iargsLs.isEmpty || this.iargsLs.get == iargsLs) then
6969
lastWords:
7070
s"the implicit arguments for term ${t.showDbg} " +
@@ -89,7 +89,7 @@ enum Term extends Statement:
8989
case SynthSel(prefix: Term, nme: Tree.Ident)(var sym: Opt[FieldSymbol]) extends Term with ResolvableImpl
9090
case DynSel(prefix: Term, fld: Term, arrayIdx: Bool)
9191
case Tup(fields: Ls[Elem])(val tree: Tree.Tup)
92-
case IfLike(kw: Keyword.`if`.type | Keyword.`while`.type, desugared: Split)(val normalized: Split)
92+
case IfLike(kw: Keyword.`if`.type | Keyword.`while`.type, desugared: Split)
9393
case Lam(params: ParamList, body: Term)
9494
case FunTy(lhs: Term, rhs: Term, eff: Opt[Term])
9595
case Forall(tvs: Ls[QuantVar], outer: Opt[VarSymbol], body: Term)

0 commit comments

Comments
 (0)