From cc20ac8340a367fe2793f38d9b2d8416ff48aea3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcial=20Gai=C3=9Fert?= Date: Wed, 9 Jul 2025 19:15:03 +0200 Subject: [PATCH 1/2] Add failing test case --- examples/pos/patternmatching/matching-overloaded.check | 1 + examples/pos/patternmatching/matching-overloaded.effekt | 9 +++++++++ 2 files changed, 10 insertions(+) create mode 100644 examples/pos/patternmatching/matching-overloaded.check create mode 100644 examples/pos/patternmatching/matching-overloaded.effekt diff --git a/examples/pos/patternmatching/matching-overloaded.check b/examples/pos/patternmatching/matching-overloaded.check new file mode 100644 index 000000000..01e50b183 --- /dev/null +++ b/examples/pos/patternmatching/matching-overloaded.check @@ -0,0 +1 @@ +X() \ No newline at end of file diff --git a/examples/pos/patternmatching/matching-overloaded.effekt b/examples/pos/patternmatching/matching-overloaded.effekt new file mode 100644 index 000000000..6024102f0 --- /dev/null +++ b/examples/pos/patternmatching/matching-overloaded.effekt @@ -0,0 +1,9 @@ +type A { X() } +type B { X() } + +def main() = { + val a: A = X() + a match { + case X() => println("X()") + } +} \ No newline at end of file From aef7ae6da60abc56e596dc46eaa09ab44e06b0d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcial=20Gai=C3=9Fert?= Date: Wed, 9 Jul 2025 19:15:50 +0200 Subject: [PATCH 2/2] Initial draft of overload resolution for tag patterns --- .../shared/src/main/scala/effekt/Namer.scala | 19 +++- .../shared/src/main/scala/effekt/Typer.scala | 100 ++++++++++++++---- .../src/main/scala/effekt/source/Tree.scala | 2 +- .../src/main/scala/effekt/symbols/Scope.scala | 3 + .../main/scala/effekt/symbols/symbols.scala | 10 ++ .../effekt/typer/ExhaustivityChecker.scala | 2 +- 6 files changed, 106 insertions(+), 30 deletions(-) diff --git a/effekt/shared/src/main/scala/effekt/Namer.scala b/effekt/shared/src/main/scala/effekt/Namer.scala index 8ee9b2d5a..9473baee8 100644 --- a/effekt/shared/src/main/scala/effekt/Namer.scala +++ b/effekt/shared/src/main/scala/effekt/Namer.scala @@ -691,11 +691,8 @@ object Namer extends Phase[Parsed, NameResolved] { Context.assignSymbol(id, p) List(p) case source.TagPattern(id, patterns, _) => - Context.resolveTerm(id) match { - case symbol: Constructor => () - case _ => Context.at(id) { - Context.error("Can only pattern match on constructors of data types.") - } + if (!Context.resolveOverloadedTag(id)) { + Context.error("Can only pattern match on constructors of data types.") } patterns.flatMap { resolve } case source.MultiPattern(patterns, _) => @@ -1010,6 +1007,18 @@ trait NamerOps extends ContextOps { Context: Context => sym } + private[namer] def resolveOverloadedTag(id: IdRef): Boolean = at(id) { + val syms = scope.lookupTerms(id).flatMap { ssyms => + val conss = ssyms.collect { + case c: Constructor => c + } + if conss.isEmpty then Nil else List(conss) + } + + if (syms.nonEmpty) { assignSymbol(id, MatchTarget(syms)); true } else { false} + + } + private[namer] def addHole(h: Hole): Unit = val src = module.source val holesSoFar = annotationOption(Annotations.HolesForFile, src).getOrElse(Nil) diff --git a/effekt/shared/src/main/scala/effekt/Typer.scala b/effekt/shared/src/main/scala/effekt/Typer.scala index 49ac20c7e..0b096d733 100644 --- a/effekt/shared/src/main/scala/effekt/Typer.scala +++ b/effekt/shared/src/main/scala/effekt/Typer.scala @@ -567,6 +567,45 @@ object Typer extends Phase[NameResolved, Typechecked] { // // + private def resolveTagOverload(id: source.IdRef, + successes: List[List[(Constructor, Map[Symbol, ValueType], TyperState)]], + failures: List[(Constructor, EffektMessages)] + )(using Context): Map[Symbol, ValueType] = { + + successes foreachAborting { + // continue in outer scope + case Nil => () + + // Exactly one successful result in the current scope + case List((sym, tpe, st)) => + // use the typer state after this checking pass + Context.restoreTyperstate(st) + // reassign symbol of fun to resolved calltarget symbol + Context.annotateSymbol(id, sym) + + return tpe + + // Ambiguous reference + case results => + val successfulOverloads = results.map { (sym, res, st) => (sym, findFunctionTypeFor(sym)._1) } + Context.abort(AmbiguousOverloadError(successfulOverloads, Context.rangeOf(id))) + } + + failures match { + case Nil => + Context.abort("Cannot typecheck call.") + + // exactly one error + case List((sym, errs)) => + Context.abortWith(errs) + + case failed => + // reraise all and abort + val failures = failed.map { case (block, msgs) => (block, findFunctionTypeFor(block)._1, msgs) } + Context.abort(FailedOverloadError(failures, Context.currentRange)) + } + } + def checkPattern(sc: ValueType, pattern: MatchPattern)(using Context, Captures): Map[Symbol, ValueType] = Context.focusing(pattern) { case source.IgnorePattern(_) => Map.empty case p @ source.AnyPattern(id, _) => Map(p.symbol -> sc) @@ -575,40 +614,55 @@ object Typer extends Phase[NameResolved, Typechecked] { Map.empty case p @ source.TagPattern(id, patterns, _) => - // symbol of the constructor we match against - val sym: Constructor = p.definition - - val universals = sym.tparams.take(sym.tpe.tparams.size) - val existentials = sym.tparams.drop(sym.tpe.tparams.size) + def checkTagPattern(sym: Constructor, patterns: List[MatchPattern]): Map[Symbol, ValueType] = { + val universals = sym.tparams.take(sym.tpe.tparams.size) + val existentials = sym.tparams.drop(sym.tpe.tparams.size) - // create fresh unification variables - val freshUniversals = universals.map { t => Context.freshTypeVar(t, pattern) } - // create fresh **bound** variables - val freshExistentials = existentials.map { t => TypeVar.TypeParam(t.name) } + // create fresh unification variables + val freshUniversals = universals.map { t => Context.freshTypeVar(t, pattern) } + // create fresh **bound** variables + val freshExistentials = existentials.map { t => TypeVar.TypeParam(t.name) } - Context.annotate(Annotations.TypeParameters, p, freshExistentials) + Context.annotate(Annotations.TypeParameters, p, freshExistentials) - val targs = (freshUniversals ++ freshExistentials).map { t => ValueTypeRef(t) } + val targs = (freshUniversals ++ freshExistentials).map { t => ValueTypeRef(t) } - // (4) Compute blocktype of this constructor with rigid type vars - // i.e. Cons : `(?t1, List[?t1]) => List[?t1]` - val (vps, _, ret, _) = Context.instantiate(sym.toType, targs, Nil) + // (4) Compute blocktype of this constructor with rigid type vars + // i.e. Cons : `(?t1, List[?t1]) => List[?t1]` + val (vps, _, ret, _) = Context.instantiate(sym.toType, targs, Nil) - // (5) given a scrutinee of `List[Int]`, we learn `?t1 -> Int` - matchPattern(sc, ret, p) + // (5) given a scrutinee of `List[Int]`, we learn `?t1 -> Int` + matchPattern(sc, ret, p) - // (6) check nested patterns - var bindings = Map.empty[Symbol, ValueType] + // (6) check nested patterns + var bindings = Map.empty[Symbol, ValueType] - if (patterns.size != vps.size) + if (patterns.size != vps.size) Context.abort(s"Wrong number of pattern arguments, given ${patterns.size}, expected ${vps.size}.") - (patterns zip vps) foreach { - case (pat, par: ValueType) => - bindings ++= checkPattern(par, pat) + (patterns zip vps) foreach { + case (pat, par: ValueType) => + bindings ++= checkPattern(par, pat) + } + + bindings } - bindings + // symbol of the constructor we match against + val scopes = p.definition match { + case MatchTarget(syms) => syms + case sym: Constructor => List(Set(sym)) + } + + val bindingss = scopes map { scope => tryEach(scope.toList) { constructor => + checkTagPattern(constructor, patterns) + }} + + val successes = bindingss.map { scope => scope._1 } + val errors = bindingss.flatMap { scope => scope._2 } + + resolveTagOverload(id, successes, errors) + case source.MultiPattern(patterns, _) => Context.panic("Multi-pattern should have been split at the match and not occur nested.") } match { case res => Context.annotateInferredType(pattern, sc); res } diff --git a/effekt/shared/src/main/scala/effekt/source/Tree.scala b/effekt/shared/src/main/scala/effekt/source/Tree.scala index 1e99012b1..238583fce 100644 --- a/effekt/shared/src/main/scala/effekt/source/Tree.scala +++ b/effekt/shared/src/main/scala/effekt/source/Tree.scala @@ -781,7 +781,7 @@ object Named { case Handler => symbols.BlockTypeConstructor.Interface case OpClause => symbols.Operation case Implementation => symbols.BlockTypeConstructor.Interface - case TagPattern => symbols.Constructor + case TagPattern => symbols.Constructor | symbols.MatchTarget } extension [T <: Definitions](t: T & Definition) { diff --git a/effekt/shared/src/main/scala/effekt/symbols/Scope.scala b/effekt/shared/src/main/scala/effekt/symbols/Scope.scala index ec47ada3f..fad044f5c 100644 --- a/effekt/shared/src/main/scala/effekt/symbols/Scope.scala +++ b/effekt/shared/src/main/scala/effekt/symbols/Scope.scala @@ -214,6 +214,9 @@ object scopes { else syms.head }} + def lookupTerms(id: IdRef)(using E: ErrorReporter): List[Set[TermSymbol]] = + all(id.path, scope) { _.terms.getOrElse(id.name, Set.empty) } + def lookupType(id: IdRef)(using E: ErrorReporter): TypeSymbol = lookupTypeOption(id.path, id.name) getOrElse { E.abort(pp"Could not resolve type ${id}") } diff --git a/effekt/shared/src/main/scala/effekt/symbols/symbols.scala b/effekt/shared/src/main/scala/effekt/symbols/symbols.scala index 0af96f81f..c3fa42988 100644 --- a/effekt/shared/src/main/scala/effekt/symbols/symbols.scala +++ b/effekt/shared/src/main/scala/effekt/symbols/symbols.scala @@ -214,6 +214,16 @@ case class CallTarget(symbols: List[Set[BlockSymbol]]) extends BlockSymbol { val decl = NoSource } +/** + * Synthetic symbol representing potentially multiple tags for matching + * + * Refined by typer. + */ +case class MatchTarget(symbols: List[Set[Constructor]]) extends BlockSymbol { + val name = NoName + val decl = NoSource +} + /** * Introduced by Transformer */ diff --git a/effekt/shared/src/main/scala/effekt/typer/ExhaustivityChecker.scala b/effekt/shared/src/main/scala/effekt/typer/ExhaustivityChecker.scala index 8dde6c09e..d0124a644 100644 --- a/effekt/shared/src/main/scala/effekt/typer/ExhaustivityChecker.scala +++ b/effekt/shared/src/main/scala/effekt/typer/ExhaustivityChecker.scala @@ -110,7 +110,7 @@ object ExhaustivityChecker { def preprocessPattern(p: source.MatchPattern)(using Context): Pattern = p match { case AnyPattern(id, _) => Pattern.Any() case IgnorePattern(_) => Pattern.Any() - case p @ TagPattern(id, patterns, _) => Pattern.Tag(p.definition, patterns.map(preprocessPattern)) + case p @ TagPattern(id, patterns, _) => Pattern.Tag(p.definition.asInstanceOf[Constructor], patterns.map(preprocessPattern)) case LiteralPattern(lit, _) => Pattern.Literal(lit.value, lit.tpe) case MultiPattern(patterns, _) => Context.panic("Multi-pattern should have been split in preprocess already / nested MultiPattern")