@@ -7,7 +7,7 @@ import mlscript.utils.*, shorthands.*
77import Message .MessageContext
88import utils .TraceLogger
99import syntax .Literal
10- import Keyword .{as , and , `do` , `else` , is , let , `then` }
10+ import Keyword .{as , and , `do` , `else` , is , let , `then` , where }
1111import collection .mutable .{HashMap , SortedSet }
1212import Elaborator .{ctx , Ctxl }
1313import scala .annotation .targetName
@@ -21,6 +21,7 @@ object Desugarer:
2121
2222 class ScrutineeData :
2323 val classes : HashMap [ClassSymbol , List [BlockLocalSymbol ]] = HashMap .empty
24+ val fields : HashMap [Ident , BlockLocalSymbol ] = HashMap .empty
2425 val tupleLead : HashMap [Int , BlockLocalSymbol ] = HashMap .empty
2526 val tupleLast : HashMap [Int , BlockLocalSymbol ] = HashMap .empty
2627end Desugarer
@@ -89,18 +90,13 @@ class Desugarer(val elaborator: Elaborator)
8990 def ++ (fallback : Split ): Split =
9091 if fallback == Split .End then
9192 split
92- else if split.isFull then
93- raise :
94- ErrorReport :
95- msg " The following branches are unreachable. " -> fallback.toLoc ::
96- msg " Because the previous split is full. " -> split.toLoc :: Nil
97- split
9893 else (split match
9994 case Split .Cons (head, tail) => Split .Cons (head, tail ++ fallback)
10095 case Split .Let (name, term, tail) => Split .Let (name, term, tail ++ fallback)
101- case Split .Else (_) /* impossible */ | Split .End => fallback)
96+ case Split .Else (_) | Split .End => fallback)
10297
10398 private val subScrutineeMap = HashMap .empty[BlockLocalSymbol , ScrutineeData ]
99+ private val fieldScrutineeMap = HashMap .empty[BlockLocalSymbol , ScrutineeData ]
104100
105101 extension (symbol : BlockLocalSymbol )
106102 def getSubScrutinees (cls : ClassSymbol ): List [BlockLocalSymbol ] =
@@ -113,6 +109,11 @@ class Desugarer(val elaborator: Elaborator)
113109 def getTupleLastSubScrutinee (index : Int ): BlockLocalSymbol =
114110 val data = subScrutineeMap.getOrElseUpdate(symbol, new ScrutineeData )
115111 data.tupleLast.getOrElseUpdate(index, TempSymbol (N , s " last $index" ))
112+ def getFieldScrutinee (fieldName : Ident ): BlockLocalSymbol =
113+ subScrutineeMap
114+ .getOrElseUpdate(symbol, new ScrutineeData )
115+ .fields
116+ .getOrElseUpdate(fieldName, TempSymbol (N , s " field ${fieldName.name}" ))
116117
117118
118119 def default : Split => Sequel = split => _ => split
@@ -459,9 +460,7 @@ class Desugarer(val elaborator: Elaborator)
459460 if pat.patternParams.size > 0 then
460461 error(
461462 msg " Pattern ` ${pat.nme}` expects ${" pattern argument" .pluralize(pat.patternParams.size, true )}" ->
462- pat.patternParams.foldLeft[Opt [Loc ]](N ):
463- case (N , param) => param.sym.toLoc
464- case (S (loc), param) => S (loc ++ param.sym.toLoc),
463+ Loc (pat.patternParams.iterator.map(_.sym)),
465464 msg " But no arguments were given " -> ctor.toLoc)
466465 fallback
467466 else
@@ -508,12 +507,8 @@ class Desugarer(val elaborator: Elaborator)
508507 if pat.patternParams.size != patArgs.size then
509508 error(
510509 msg " Pattern ` ${pat.nme}` expects ${" pattern argument" .pluralize(pat.patternParams.size, true )}" ->
511- pat.patternParams.foldLeft[Opt [Loc ]](N ):
512- case (N , param) => param.sym.toLoc
513- case (S (loc), param) => S (loc ++ param.sym.toLoc),
514- msg " But ${" pattern argument" .pluralize(patArgs.size, true )} were given " -> args.foldLeft[Opt [Loc ]](N ):
515- case (N , arg) => arg.toLoc
516- case (S (loc), arg) => S (loc ++ arg.toLoc))
510+ Loc (pat.patternParams.iterator.map(_.sym)),
511+ msg " But ${" pattern argument" .pluralize(patArgs.size, true )} were given " -> Loc (args))
517512 fallback
518513 else
519514 Branch (ref, Pattern .Synonym (pat, patArgs.zip(args)), sequel(ctx)) ~: fallback
@@ -589,6 +584,12 @@ class Desugarer(val elaborator: Elaborator)
589584 Branch (ref, Pattern .Lit (IntLit (- value)), sequel(ctx)) ~: fallback
590585 case App (Ident (" -" ), Tup (DecLit (value) :: Nil )) => fallback => ctx =>
591586 Branch (ref, Pattern .Lit (DecLit (- value)), sequel(ctx)) ~: fallback
587+ case App (Ident (" &" ), Tree .Tup (lhs :: rhs :: Nil )) => fallback => ctx =>
588+ val newSequel = expandMatch(scrutSymbol, rhs, sequel)(fallback)
589+ expandMatch(scrutSymbol, lhs, newSequel)(fallback)(ctx)
590+ case App (Ident (" |" ), Tree .Tup (lhs :: rhs :: Nil )) => fallback => ctx =>
591+ val newFallback = expandMatch(scrutSymbol, rhs, sequel)(fallback)(ctx)
592+ expandMatch(scrutSymbol, lhs, sequel)(newFallback)(ctx)
592593 // A single constructor pattern.
593594 case Annotated (Ident (" compile" ), app @ App (ctor : Ctor , Tup (args))) =>
594595 dealWithAppCtorCase(app, ctor, args, true )
@@ -607,18 +608,50 @@ class Desugarer(val elaborator: Elaborator)
607608 case pattern and consequent => fallback => ctx =>
608609 val innerSplit = termSplit(consequent, identity)(Split .End )
609610 expandMatch(scrutSymbol, pattern, innerSplit)(fallback)(ctx)
611+ case pattern where condition => fallback => ctx =>
612+ val sym = TempSymbol (N , " conditionTemp" )
613+ val newSequel = expandMatch(sym, Tree .BoolLit (true ), sequel)(fallback)
614+ val newNewSequel = (ctx : Ctx ) => Split .Let (sym, term(condition)(using ctx), newSequel(ctx))
615+ expandMatch(scrutSymbol, pattern, newNewSequel)(fallback)(ctx)
610616 case Jux (Ident (" .." ), Ident (_)) => fallback => _ =>
611617 raise(ErrorReport (msg " Illegal rest pattern. " -> pattern.toLoc :: Nil ))
612618 fallback
613- case InfixApp (id : Ident , Keyword .`:`, pat) => fallback => ctx =>
614- val sym = VarSymbol (id)
615- val ctx2 = ctx
616- // + (id.name -> sym) // * This binds the field's name in the context; probably surprising
617- Split .Let (sym, ref.sel(id, N ),
618- expandMatch(sym, pat, sequel)(fallback)(ctx2))
619+ case InfixApp (fieldName : Ident , Keyword .`:`, pat) => fallback => ctx =>
620+ val symbol = scrutSymbol.getFieldScrutinee(fieldName)
621+ Branch (
622+ ref,
623+ Pattern .Record ((fieldName, symbol) :: Nil ),
624+ subMatches((R (symbol), pat) :: Nil , sequel)(Split .End )(ctx)
625+ ) ~: fallback
626+ case Pun (false , fieldName) => fallback => ctx =>
627+ val symbol = scrutSymbol.getFieldScrutinee(fieldName)
628+ Branch (
629+ ref,
630+ Pattern .Record ((fieldName, symbol) :: Nil ),
631+ subMatches((R (symbol), fieldName) :: Nil , sequel)(Split .End )(ctx)
632+ ) ~: fallback
619633 case Block (st :: Nil ) => fallback => ctx =>
620634 expandMatch(scrutSymbol, st, sequel)(fallback)(ctx)
621- // case Block(sts) => fallback => ctx => // TODO
635+ case Block (sts) => fallback => ctx => // we assume this is a record
636+ sts.foldRight[Option [List [(Tree .Ident , BlockLocalSymbol , Tree )]]](S (Nil )){
637+ // this collects the record parts, or fails if some statement does not correspond
638+ // to a record field
639+ case (_, N ) => N // we only need to fail once to return N
640+ case (p, S (tl)) => p match
641+ case InfixApp (fieldName : Ident , Keyword .`:`, pat) =>
642+ S ((fieldName, scrutSymbol.getFieldScrutinee(fieldName), pat) :: tl)
643+ case Pun (false , fieldName) =>
644+ S ((fieldName, scrutSymbol.getFieldScrutinee(fieldName), fieldName) :: tl)
645+ case p =>
646+ raise(ErrorReport (msg " invalid record field pattern " -> p.toLoc :: Nil ))
647+ None
648+ }.fold(fallback)(recordContent =>
649+ Branch (
650+ ref,
651+ Pattern .Record (recordContent.map((fieldName, symbol, _) => (fieldName, symbol))),
652+ subMatches(recordContent.map((_, symbol, pat) => (R (symbol), pat)), sequel)(Split .End )(ctx)
653+ ) ~: fallback
654+ )
622655 case Bra (BracketKind .Curly | BracketKind .Round , inner) => fallback => ctx =>
623656 expandMatch(scrutSymbol, inner, sequel)(fallback)(ctx)
624657 case pattern => fallback => _ =>
0 commit comments