Skip to content

Commit 87a080d

Browse files
marzipankaiserLena Käufel
authored andcommitted
Allow lambda case patterns for multiple parameters (effekt-lang#914)
Resolves effekt-lang#761 by allowing multiple patterns in a lambda case seperated by `,` like: ``` def foo[A,B,C](){ fn: (A, B) => C }: C = ... //... foo(){ case true, y => ... case x, y => ... } ``` ## Implementation This becomes a `Match` with multiple scrutinees, which is resolved in the pattern matching compiler. Typer checks that the clauses have the correct number of patterns (which will be assumed later).
1 parent 9b9266c commit 87a080d

15 files changed

+196
-37
lines changed

effekt/shared/src/main/scala/effekt/Namer.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,8 @@ object Namer extends Phase[Parsed, NameResolved] {
654654
}
655655
}
656656
patterns.flatMap { resolve }
657+
case source.MultiPattern(patterns) =>
658+
patterns.flatMap { resolve }
657659
}
658660

659661
def resolve(p: source.MatchGuard)(using Context): Unit = p match {

effekt/shared/src/main/scala/effekt/RecursiveDescent.scala

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ class RecursiveDescent(positions: Positions, tokens: Seq[Token], source: Source)
406406
val default = when(`else`) { Some(stmt()) } { None }
407407
val body = semi() ~> stmts()
408408
val clause = MatchClause(p, guards, body).withRangeOf(p, sc)
409-
val matching = Match(sc, List(clause), default).withRangeOf(startMarker, sc)
409+
val matching = Match(List(sc), List(clause), default).withRangeOf(startMarker, sc)
410410
Return(matching)
411411
}
412412

@@ -757,8 +757,13 @@ class RecursiveDescent(positions: Positions, tokens: Seq[Token], source: Source)
757757

758758
def matchClause(): MatchClause =
759759
nonterminal:
760+
val patterns = `case` ~> some(matchPattern, `,`)
761+
val pattern = patterns match {
762+
case List(pat) => pat
763+
case pats => MultiPattern(pats)
764+
}
760765
MatchClause(
761-
`case` ~> matchPattern(),
766+
pattern,
762767
manyWhile(`and` ~> matchGuard(), `and`),
763768
// allow a statement enclosed in braces or without braces
764769
// both is allowed since match clauses are already delimited by `case`
@@ -802,7 +807,7 @@ class RecursiveDescent(positions: Positions, tokens: Seq[Token], source: Source)
802807
while (peek(`match`)) {
803808
val clauses = `match` ~> braces { manyWhile(matchClause(), `case`) }
804809
val default = when(`else`) { Some(stmt()) } { None }
805-
sc = Match(sc, clauses, default)
810+
sc = Match(List(sc), clauses, default)
806811
}
807812
sc
808813

@@ -944,14 +949,18 @@ class RecursiveDescent(positions: Positions, tokens: Seq[Token], source: Source)
944949
peek.kind match {
945950
// { case ... => ... }
946951
case `case` => someWhile(matchClause(), `case`) match { case cs =>
952+
val arity = cs match {
953+
case MatchClause(MultiPattern(ps), _, _) :: _ => ps.length
954+
case _ => 1
955+
}
947956
// TODO positions should be improved here and fresh names should be generated for the scrutinee
948957
// also mark the temp name as synthesized to prevent it from being listed in VSCode
949-
val name = "__tmpRes"
958+
val names = List.tabulate(arity){ n => s"__arg${n}" }
950959
BlockLiteral(
951960
Nil,
952-
List(ValueParam(IdDef(name), None)),
961+
names.map{ name => ValueParam(IdDef(name), None) },
953962
Nil,
954-
Return(Match(Var(IdRef(Nil, name)), cs, None))) : BlockLiteral
963+
Return(Match(names.map{ name => Var(IdRef(Nil, name)) }, cs, None))) : BlockLiteral
955964
}
956965
case _ =>
957966
// { (x: Int) => ... }
@@ -1453,8 +1462,36 @@ class RecursiveDescent(positions: Positions, tokens: Seq[Token], source: Source)
14531462
// case _ => ()
14541463
// }
14551464

1456-
positions.setStart(res, source.offsetToPosition(start))
1457-
positions.setFinish(res, source.offsetToPosition(end))
1465+
val startPos = source.offsetToPosition(start)
1466+
val endPos = source.offsetToPosition(end)
1467+
1468+
// recursively add positions to subtrees that are not yet annotated
1469+
// this is better than nothing and means we have positions for desugared stuff
1470+
def annotatePositions(res: Any): Unit = res match {
1471+
case l: List[_] =>
1472+
if (positions.getRange(l).isEmpty) {
1473+
positions.setStart(l, startPos)
1474+
positions.setFinish(l, endPos)
1475+
l.foreach(annotatePositions)
1476+
}
1477+
case t: Tree =>
1478+
val recurse = positions.getRange(t).isEmpty
1479+
if(positions.getStart(t).isEmpty) positions.setStart(t, startPos)
1480+
if(positions.getFinish(t).isEmpty) positions.setFinish(t, endPos)
1481+
t match {
1482+
case p: Product if recurse =>
1483+
p.productIterator.foreach { c =>
1484+
annotatePositions(c)
1485+
}
1486+
case _ => ()
1487+
}
1488+
case _ => ()
1489+
}
1490+
annotatePositions(res)
1491+
1492+
// still annotate, in case it is not Tree
1493+
positions.setStart(res, startPos)
1494+
positions.setFinish(res, endPos)
14581495

14591496
res
14601497
}

effekt/shared/src/main/scala/effekt/Typer.scala

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -287,18 +287,43 @@ object Typer extends Phase[NameResolved, Typechecked] {
287287

288288
Result(ret, (effs -- handled) ++ handlerEffs)
289289

290-
case tree @ source.Match(sc, clauses, default) =>
290+
case tree @ source.Match(scs, clauses, default) =>
291291

292-
// (1) Check scrutinee
292+
// (1) Check scrutinees
293293
// for example. tpe = List[Int]
294-
val Result(tpe, effs) = checkExpr(sc, None)
294+
val results = scs.map{ sc => checkExpr(sc, None) }
295295

296-
var resEff = effs
296+
var resEff = ConcreteEffects.union(results.map{ case Result(tpe, effs) => effs })
297+
298+
// check that number of patterns matches number of scrutinees
299+
val arity = scs.length
300+
clauses.foreach {
301+
case cls @ source.MatchClause(source.MultiPattern(patterns), guards, body) =>
302+
if (patterns.length != arity) {
303+
Context.at(cls){
304+
Context.error(pp"Number of patterns (${patterns.length}) does not match number of parameters / scrutinees (${arity}).")
305+
}
306+
}
307+
case cls @ source.MatchClause(pattern, guards, body) =>
308+
if (arity != 1) {
309+
Context.at(cls) {
310+
Context.error(pp"Number of patterns (1) does not match number of parameters / scrutinees (${arity}).")
311+
}
312+
}
313+
}
297314

298315
val tpes = clauses.map {
299316
case source.MatchClause(p, guards, body) =>
300-
// (3) infer types for pattern
301-
Context.bind(checkPattern(tpe, p))
317+
// (3) infer types for pattern(s)
318+
p match {
319+
case source.MultiPattern(ps) =>
320+
(results zip ps).foreach { case (Result(tpe, effs), p) =>
321+
Context.bind(checkPattern(tpe, p))
322+
}
323+
case p =>
324+
val Result(tpe, effs) = results.head
325+
Context.bind(checkPattern(tpe, p))
326+
}
302327
// infer types for guards
303328
val Result((), guardEffs) = checkGuards(guards)
304329
// check body of the clause
@@ -592,6 +617,8 @@ object Typer extends Phase[NameResolved, Typechecked] {
592617
}
593618

594619
bindings
620+
case source.MultiPattern(patterns) =>
621+
Context.panic("Multi-pattern should have been split at the match and not occur nested.")
595622
} match { case res => Context.annotateInferredType(pattern, sc); res }
596623

597624
def checkGuard(guard: MatchGuard)(using Context, Captures): Result[Map[Symbol, ValueType]] = guard match {

effekt/shared/src/main/scala/effekt/core/Transformer.scala

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -427,14 +427,14 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
427427
Context.bind(loopCall)
428428

429429
// Empty match (matching on Nothing)
430-
case source.Match(sc, Nil, None) =>
430+
case source.Match(List(sc), Nil, None) =>
431431
val scrutinee: ValueVar = Context.bind(transformAsPure(sc))
432432
Context.bind(core.Match(scrutinee, Nil, None))
433433

434-
case source.Match(sc, cs, default) =>
434+
case source.Match(scs, cs, default) =>
435435
// (1) Bind scrutinee and all clauses so we do not have to deal with sharing on demand.
436-
val scrutinee: ValueVar = Context.bind(transformAsPure(sc))
437-
val clauses = cs.zipWithIndex.map((c, i) => preprocess(s"k${i}", scrutinee, c))
436+
val scrutinees: List[ValueVar] = scs.map{ sc => Context.bind(transformAsPure(sc)) }
437+
val clauses = cs.zipWithIndex.map((c, i) => preprocess(s"k${i}", scrutinees, c))
438438
val defaultClause = default.map(stmt => preprocess("k_els", Nil, Nil, transform(stmt))).toList
439439
val compiledMatch = PatternMatchingCompiler.compile(clauses ++ defaultClause)
440440
Context.bind(compiledMatch)
@@ -653,8 +653,14 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
653653
})
654654
}
655655

656-
def preprocess(label: String, sc: ValueVar, clause: source.MatchClause)(using Context): Clause =
657-
preprocess(label, List((sc, clause.pattern)), clause.guards, transform(clause.body))
656+
def preprocess(label: String, scs: List[ValueVar], clause: source.MatchClause)(using Context): Clause = {
657+
val patterns = (clause.pattern, scs) match {
658+
case (source.MultiPattern(ps), scs) => scs.zip(ps)
659+
case (pattern, List(sc)) => List((sc, clause.pattern))
660+
case (_, _) => Context.abort("Malformed multi-match")
661+
}
662+
preprocess(label, patterns, clause.guards, transform(clause.body))
663+
}
658664

659665
def preprocess(label: String, patterns: List[(ValueVar, source.MatchPattern)], guards: List[source.MatchGuard], body: core.Stmt)(using Context): Clause = {
660666
import PatternMatchingCompiler.*
@@ -663,6 +669,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
663669
case p @ source.AnyPattern(id) => List(ValueParam(p.symbol))
664670
case source.TagPattern(id, patterns) => patterns.flatMap(boundInPattern)
665671
case _: source.LiteralPattern | _: source.IgnorePattern => Nil
672+
case source.MultiPattern(patterns) => patterns.flatMap(boundInPattern)
666673
}
667674
def boundInGuard(g: source.MatchGuard): List[core.ValueParam] = g match {
668675
case MatchGuard.BooleanGuard(condition) => Nil
@@ -672,6 +679,7 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
672679
case source.AnyPattern(id) => List()
673680
case p @ source.TagPattern(id, patterns) => Context.annotation(Annotations.TypeParameters, p) ++ patterns.flatMap(boundTypesInPattern)
674681
case _: source.LiteralPattern | _: source.IgnorePattern => Nil
682+
case source.MultiPattern(patterns) => patterns.flatMap(boundTypesInPattern)
675683
}
676684
def boundTypesInGuard(g: source.MatchGuard): List[Id] = g match {
677685
case MatchGuard.BooleanGuard(condition) => Nil
@@ -708,6 +716,8 @@ object Transformer extends Phase[Typechecked, CoreTransformed] {
708716
Pattern.Ignore()
709717
case source.LiteralPattern(source.Literal(value, tpe)) =>
710718
Pattern.Literal(Literal(value, transform(tpe)), equalsFor(tpe))
719+
case source.MultiPattern(patterns) =>
720+
Context.panic("Multi-pattern should have been split on toplevel / nested MultiPattern")
711721
}
712722

713723
def transformGuard(p: source.MatchGuard): List[Condition] =

effekt/shared/src/main/scala/effekt/source/Tree.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ enum Term extends Tree {
382382
// Control Flow
383383
case If(guards: List[MatchGuard], thn: Stmt, els: Stmt)
384384
case While(guards: List[MatchGuard], block: Stmt, default: Option[Stmt])
385-
case Match(scrutinee: Term, clauses: List[MatchClause], default: Option[Stmt])
385+
case Match(scrutinees: List[Term], clauses: List[MatchClause], default: Option[Stmt])
386386

387387
/**
388388
* Handling effects
@@ -504,6 +504,15 @@ enum MatchPattern extends Tree {
504504
* A pattern that matches a single literal value
505505
*/
506506
case LiteralPattern(l: Literal)
507+
508+
/**
509+
* A pattern for multiple values
510+
*
511+
* case a, b => ...
512+
*
513+
* Currently should *only* occur in lambda-cases during Parsing
514+
*/
515+
case MultiPattern(patterns: List[MatchPattern]) extends MatchPattern
507516
}
508517
export MatchPattern.*
509518

effekt/shared/src/main/scala/effekt/typer/BoxUnboxInference.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ object BoxUnboxInference extends Phase[NameResolved, NameResolved] {
6868
case While(guards, body, default) =>
6969
While(guards.map(rewrite), rewrite(body), default.map(rewrite))
7070

71-
case Match(sc, clauses, default) =>
72-
Match(rewriteAsExpr(sc), clauses.map(rewrite), default.map(rewrite))
71+
case Match(scs, clauses, default) =>
72+
Match(scs.map(rewriteAsExpr), clauses.map(rewrite), default.map(rewrite))
7373

7474
case s @ Select(recv, name) if s.definition.isInstanceOf[Field] =>
7575
Select(rewriteAsExpr(recv), name)

effekt/shared/src/main/scala/effekt/typer/ConcreteEffects.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ object ConcreteEffects {
5252
def apply(effs: Effects)(using Context): ConcreteEffects = apply(effs.toList)
5353

5454
def empty: ConcreteEffects = fromList(Nil)
55+
56+
def union(effs: IterableOnce[ConcreteEffects]): ConcreteEffects = {
57+
ConcreteEffects.fromList(effs.iterator.flatMap{ e => e.effects }.toList)
58+
}
5559
}
5660

5761
val Pure = ConcreteEffects.empty

effekt/shared/src/main/scala/effekt/typer/ExhaustivityChecker.scala

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,20 +92,28 @@ object ExhaustivityChecker {
9292

9393
// Scrutinees are identified by tracing from the original scrutinee.
9494
enum Trace {
95-
case Root(scrutinee: source.Term)
95+
case Root(scrutinees: source.Term)
9696
case Child(c: Constructor, field: Field, outer: Trace)
9797
}
9898

9999

100-
def preprocess(root: source.Term, cl: source.MatchClause)(using Context): Clause = cl match {
101-
case source.MatchClause(pattern, guards, body) =>
100+
def preprocess(roots: List[source.Term], cl: source.MatchClause)(using Context): Clause = (roots, cl) match {
101+
case (List(root), source.MatchClause(pattern, guards, body)) =>
102102
Clause.normalized(Condition.Patterns(Map(Trace.Root(root) -> preprocessPattern(pattern))) :: guards.map(preprocessGuard), cl)
103+
case (roots, source.MatchClause(MultiPattern(patterns), guards, body)) =>
104+
val rootConds: Map[Trace, Pattern] = (roots zip patterns).map { case (root, pattern) =>
105+
Trace.Root(root) -> preprocessPattern(pattern)
106+
}.toMap
107+
Clause.normalized(Condition.Patterns(rootConds) :: guards.map(preprocessGuard), cl)
108+
case (_, _) => Context.abort("Malformed multi-match")
103109
}
104110
def preprocessPattern(p: source.MatchPattern)(using Context): Pattern = p match {
105111
case AnyPattern(id) => Pattern.Any()
106112
case IgnorePattern() => Pattern.Any()
107113
case p @ TagPattern(id, patterns) => Pattern.Tag(p.definition, patterns.map(preprocessPattern))
108114
case LiteralPattern(lit) => Pattern.Literal(lit.value, lit.tpe)
115+
case MultiPattern(patterns) =>
116+
Context.panic("Multi-pattern should have been split in preprocess already / nested MultiPattern")
109117
}
110118
def preprocessGuard(g: source.MatchGuard)(using Context): Condition = g match {
111119
case MatchGuard.BooleanGuard(condition) =>
@@ -121,7 +129,7 @@ object ExhaustivityChecker {
121129
* - non exhaustive pattern match should generate a list of patterns, so the IDE can insert them
122130
* - redundant cases should generate a list of cases that can be deleted.
123131
*/
124-
class Exhaustivity(allClauses: List[source.MatchClause]) {
132+
class Exhaustivity(allClauses: List[source.MatchClause], originalScrutinees: List[source.Term]) {
125133

126134
// Redundancy Information
127135
// ----------------------
@@ -152,7 +160,8 @@ object ExhaustivityChecker {
152160
def reportNonExhaustive()(using C: ErrorReporter): Unit = {
153161
@tailrec
154162
def traceToCase(at: Trace, acc: String): String = at match {
155-
case Trace.Root(_) => acc
163+
case Trace.Root(_) if originalScrutinees.length == 1 => acc
164+
case Trace.Root(e) => originalScrutinees.map { f => if e == f then acc else "_" }.mkString(", ")
156165
case Trace.Child(childCtor, field, outer) =>
157166
val newAcc = s"${childCtor.name}(${childCtor.fields.map { f => if f == field then acc else "_" }.mkString(", ")})"
158167
traceToCase(outer, newAcc)
@@ -191,13 +200,23 @@ object ExhaustivityChecker {
191200
}
192201
}
193202

194-
def checkExhaustive(scrutinee: source.Term, cls: List[source.MatchClause])(using C: Context): Unit = {
195-
val initialClauses: List[Clause] = cls.map(preprocess(scrutinee, _))
196-
given E: Exhaustivity = new Exhaustivity(cls)
197-
checkScrutinee(Trace.Root(scrutinee), Context.inferredTypeOf(scrutinee), initialClauses)
203+
def checkExhaustive(scrutinees: List[source.Term], cls: List[source.MatchClause])(using C: Context): Unit = {
204+
val initialClauses: List[Clause] = cls.map(preprocess(scrutinees, _))
205+
given E: Exhaustivity = new Exhaustivity(cls, scrutinees)
206+
checkScrutinees(scrutinees.map(Trace.Root(_)), scrutinees.map{ scrutinee => Context.inferredTypeOf(scrutinee) }, initialClauses)
198207
E.report()
199208
}
200209

210+
def checkScrutinees(scrutinees: List[Trace], tpes: List[ValueType], clauses: List[Clause])(using E: Exhaustivity): Unit = {
211+
(scrutinees, tpes) match {
212+
case (List(scrutinee), List(tpe)) => checkScrutinee(scrutinee, tpe, clauses)
213+
case _ =>
214+
clauses match {
215+
case Nil => E.missingDefault(tpes.head, scrutinees.head)
216+
case head :: tail => matchClauses(head, tail)
217+
}
218+
}
219+
}
201220

202221
def checkScrutinee(scrutinee: Trace, tpe: ValueType, clauses: List[Clause])(using E: Exhaustivity): Unit = {
203222

effekt/shared/src/main/scala/effekt/typer/Wellformedness.scala

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,17 @@ object Wellformedness extends Phase[Typechecked, Typechecked], Visit[WFContext]
161161
pp"The return type ${tpe} of the region body is not allowed to refer to region ${reg.capture}."
162162
})
163163

164-
case tree @ source.Match(scrutinee, clauses, default) => Context.at(tree) {
164+
case tree @ source.Match(scrutinees, clauses, default) => Context.at(tree) {
165165
// TODO copy annotations from default to synthesized defaultClause (in particular positions)
166-
val defaultClause = default.toList.map(body => source.MatchClause(source.IgnorePattern(), Nil, body))
167-
ExhaustivityChecker.checkExhaustive(scrutinee, clauses ++ defaultClause)
166+
val defaultPattern = scrutinees match {
167+
case List(_) => source.IgnorePattern()
168+
case scs => source.MultiPattern(List.fill(scs.length){source.IgnorePattern()})
169+
}
170+
171+
val defaultClause = default.toList.map(body => source.MatchClause(defaultPattern, Nil, body))
172+
ExhaustivityChecker.checkExhaustive(scrutinees, clauses ++ defaultClause)
168173

169-
query(scrutinee)
174+
scrutinees foreach { query }
170175
clauses foreach { query }
171176
default foreach query
172177

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
def foo[A, B](a: A){ body: A => B }: B = body(a)
2+
3+
def main() = {
4+
foo(true){ // ERROR Non-exhaustive
5+
case false => ()
6+
}
7+
}

0 commit comments

Comments
 (0)