Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions effekt/shared/src/main/scala/effekt/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -691,11 +691,8 @@ object Namer extends Phase[Parsed, NameResolved] {
Context.assignSymbol(id, p)
List(p)
case source.TagPattern(id, patterns, _) =>
Context.resolveTerm(id) match {
case symbol: Constructor => ()
case _ => Context.at(id) {
Context.error("Can only pattern match on constructors of data types.")
}
if (!Context.resolveOverloadedTag(id)) {
Context.error("Can only pattern match on constructors of data types.")
}
patterns.flatMap { resolve }
case source.MultiPattern(patterns, _) =>
Expand Down Expand Up @@ -1010,6 +1007,18 @@ trait NamerOps extends ContextOps { Context: Context =>
sym
}

private[namer] def resolveOverloadedTag(id: IdRef): Boolean = at(id) {
val syms = scope.lookupTerms(id).flatMap { ssyms =>
val conss = ssyms.collect {
case c: Constructor => c
}
if conss.isEmpty then Nil else List(conss)
}

if (syms.nonEmpty) { assignSymbol(id, MatchTarget(syms)); true } else { false}

}

private[namer] def addHole(h: Hole): Unit =
val src = module.source
val holesSoFar = annotationOption(Annotations.HolesForFile, src).getOrElse(Nil)
Expand Down
100 changes: 77 additions & 23 deletions effekt/shared/src/main/scala/effekt/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,45 @@ object Typer extends Phase[NameResolved, Typechecked] {
//</editor-fold>

//<editor-fold desc="pattern matching">
private def resolveTagOverload(id: source.IdRef,
successes: List[List[(Constructor, Map[Symbol, ValueType], TyperState)]],
failures: List[(Constructor, EffektMessages)]
)(using Context): Map[Symbol, ValueType] = {

successes foreachAborting {
// continue in outer scope
case Nil => ()

// Exactly one successful result in the current scope
case List((sym, tpe, st)) =>
// use the typer state after this checking pass
Context.restoreTyperstate(st)
// reassign symbol of fun to resolved calltarget symbol
Context.annotateSymbol(id, sym)

return tpe

// Ambiguous reference
case results =>
val successfulOverloads = results.map { (sym, res, st) => (sym, findFunctionTypeFor(sym)._1) }
Context.abort(AmbiguousOverloadError(successfulOverloads, Context.rangeOf(id)))
}

failures match {
case Nil =>
Context.abort("Cannot typecheck call.")

// exactly one error
case List((sym, errs)) =>
Context.abortWith(errs)

case failed =>
// reraise all and abort
val failures = failed.map { case (block, msgs) => (block, findFunctionTypeFor(block)._1, msgs) }
Context.abort(FailedOverloadError(failures, Context.currentRange))
}
}

def checkPattern(sc: ValueType, pattern: MatchPattern)(using Context, Captures): Map[Symbol, ValueType] = Context.focusing(pattern) {
case source.IgnorePattern(_) => Map.empty
case p @ source.AnyPattern(id, _) => Map(p.symbol -> sc)
Expand All @@ -575,40 +614,55 @@ object Typer extends Phase[NameResolved, Typechecked] {
Map.empty
case p @ source.TagPattern(id, patterns, _) =>

// symbol of the constructor we match against
val sym: Constructor = p.definition

val universals = sym.tparams.take(sym.tpe.tparams.size)
val existentials = sym.tparams.drop(sym.tpe.tparams.size)
def checkTagPattern(sym: Constructor, patterns: List[MatchPattern]): Map[Symbol, ValueType] = {
val universals = sym.tparams.take(sym.tpe.tparams.size)
val existentials = sym.tparams.drop(sym.tpe.tparams.size)

// create fresh unification variables
val freshUniversals = universals.map { t => Context.freshTypeVar(t, pattern) }
// create fresh **bound** variables
val freshExistentials = existentials.map { t => TypeVar.TypeParam(t.name) }
// create fresh unification variables
val freshUniversals = universals.map { t => Context.freshTypeVar(t, pattern) }
// create fresh **bound** variables
val freshExistentials = existentials.map { t => TypeVar.TypeParam(t.name) }

Context.annotate(Annotations.TypeParameters, p, freshExistentials)
Context.annotate(Annotations.TypeParameters, p, freshExistentials)

val targs = (freshUniversals ++ freshExistentials).map { t => ValueTypeRef(t) }
val targs = (freshUniversals ++ freshExistentials).map { t => ValueTypeRef(t) }

// (4) Compute blocktype of this constructor with rigid type vars
// i.e. Cons : `(?t1, List[?t1]) => List[?t1]`
val (vps, _, ret, _) = Context.instantiate(sym.toType, targs, Nil)
// (4) Compute blocktype of this constructor with rigid type vars
// i.e. Cons : `(?t1, List[?t1]) => List[?t1]`
val (vps, _, ret, _) = Context.instantiate(sym.toType, targs, Nil)

// (5) given a scrutinee of `List[Int]`, we learn `?t1 -> Int`
matchPattern(sc, ret, p)
// (5) given a scrutinee of `List[Int]`, we learn `?t1 -> Int`
matchPattern(sc, ret, p)

// (6) check nested patterns
var bindings = Map.empty[Symbol, ValueType]
// (6) check nested patterns
var bindings = Map.empty[Symbol, ValueType]

if (patterns.size != vps.size)
if (patterns.size != vps.size)
Context.abort(s"Wrong number of pattern arguments, given ${patterns.size}, expected ${vps.size}.")

(patterns zip vps) foreach {
case (pat, par: ValueType) =>
bindings ++= checkPattern(par, pat)
(patterns zip vps) foreach {
case (pat, par: ValueType) =>
bindings ++= checkPattern(par, pat)
}

bindings
}

bindings
// symbol of the constructor we match against
val scopes = p.definition match {
case MatchTarget(syms) => syms
case sym: Constructor => List(Set(sym))
}

val bindingss = scopes map { scope => tryEach(scope.toList) { constructor =>
checkTagPattern(constructor, patterns)
}}

val successes = bindingss.map { scope => scope._1 }
val errors = bindingss.flatMap { scope => scope._2 }

resolveTagOverload(id, successes, errors)

case source.MultiPattern(patterns, _) =>
Context.panic("Multi-pattern should have been split at the match and not occur nested.")
} match { case res => Context.annotateInferredType(pattern, sc); res }
Expand Down
2 changes: 1 addition & 1 deletion effekt/shared/src/main/scala/effekt/source/Tree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -781,7 +781,7 @@ object Named {
case Handler => symbols.BlockTypeConstructor.Interface
case OpClause => symbols.Operation
case Implementation => symbols.BlockTypeConstructor.Interface
case TagPattern => symbols.Constructor
case TagPattern => symbols.Constructor | symbols.MatchTarget
}

extension [T <: Definitions](t: T & Definition) {
Expand Down
3 changes: 3 additions & 0 deletions effekt/shared/src/main/scala/effekt/symbols/Scope.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,9 @@ object scopes {
else syms.head
}}

def lookupTerms(id: IdRef)(using E: ErrorReporter): List[Set[TermSymbol]] =
all(id.path, scope) { _.terms.getOrElse(id.name, Set.empty) }

def lookupType(id: IdRef)(using E: ErrorReporter): TypeSymbol =
lookupTypeOption(id.path, id.name) getOrElse { E.abort(pp"Could not resolve type ${id}") }

Expand Down
10 changes: 10 additions & 0 deletions effekt/shared/src/main/scala/effekt/symbols/symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,16 @@ case class CallTarget(symbols: List[Set[BlockSymbol]]) extends BlockSymbol {
val decl = NoSource
}

/**
* Synthetic symbol representing potentially multiple tags for matching
*
* Refined by typer.
*/
case class MatchTarget(symbols: List[Set[Constructor]]) extends BlockSymbol {
val name = NoName
val decl = NoSource
}

/**
* Introduced by Transformer
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ object ExhaustivityChecker {
def preprocessPattern(p: source.MatchPattern)(using Context): Pattern = p match {
case AnyPattern(id, _) => Pattern.Any()
case IgnorePattern(_) => Pattern.Any()
case p @ TagPattern(id, patterns, _) => Pattern.Tag(p.definition, patterns.map(preprocessPattern))
case p @ TagPattern(id, patterns, _) => Pattern.Tag(p.definition.asInstanceOf[Constructor], patterns.map(preprocessPattern))
case LiteralPattern(lit, _) => Pattern.Literal(lit.value, lit.tpe)
case MultiPattern(patterns, _) =>
Context.panic("Multi-pattern should have been split in preprocess already / nested MultiPattern")
Expand Down
1 change: 1 addition & 0 deletions examples/pos/patternmatching/matching-overloaded.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
X()
9 changes: 9 additions & 0 deletions examples/pos/patternmatching/matching-overloaded.effekt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
type A { X() }
type B { X() }

def main() = {
val a: A = X()
a match {
case X() => println("X()")
}
}
Loading