diff --git a/Doc/reference/compound_stmts.rst b/Doc/reference/compound_stmts.rst index e95fa3a6424e23..36bd911c05f09a 100644 --- a/Doc/reference/compound_stmts.rst +++ b/Doc/reference/compound_stmts.rst @@ -1098,6 +1098,11 @@ The same keyword should not be repeated in class patterns. The following is the logical flow for matching a class pattern against a subject value: +#. If ``name_or_attr`` is a union type, apply the subsequent steps in order to + each of its members, returning the first successful match or raising the first + encountered exception. + This mirrors the behavior of :func:`isinstance` with union types. + #. If ``name_or_attr`` is not an instance of the builtin :class:`type` , raise :exc:`TypeError`. diff --git a/Lib/test/test_patma.py b/Lib/test/test_patma.py index 5d0857b059ea23..cb6662d3eaaff1 100644 --- a/Lib/test/test_patma.py +++ b/Lib/test/test_patma.py @@ -15,6 +15,13 @@ class Point: y: int +@dataclasses.dataclass +class Point3D: + x: int + y: int + z: int + + class TestCompiler(unittest.TestCase): def test_refleaks(self): @@ -2888,6 +2895,84 @@ class B(A): ... h = 1 self.assertEqual(h, 1) + def test_patma_union_type(self): + IntOrStr = int | str + w = None + match 0: + case IntOrStr(): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_no_match(self): + StrOrBytes = str | bytes + w = None + match 0: + case StrOrBytes(): + w = 0 + self.assertIsNone(w) + + def test_union_type_positional_subpattern(self): + IntOrStr = int | str + w = None + match 0: + case IntOrStr(y): + w = y + self.assertEqual(w, 0) + + def test_union_type_keyword_subpattern(self): + EitherPoint = Point | Point3D + p = Point(x=1, y=2) + w = None + match p: + case EitherPoint(x=1, y=2): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_arg(self): + p = Point(x=1, y=2) + IntOrStr = int | str + w = None + match p: + case Point(IntOrStr(), IntOrStr()): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_kwarg(self): + p = Point(x=1, y=2) + IntOrStr = int | str + w = None + match p: + case Point(x=IntOrStr(), y=IntOrStr()): + w = 0 + self.assertEqual(w, 0) + + def test_patma_union_arg_no_match(self): + p = Point(x=1, y=2) + StrOrBytes = str | bytes + w = None + match p: + case Point(StrOrBytes(), StrOrBytes()): + w = 0 + self.assertIsNone(w) + + def test_patma_union_kwarg_no_match(self): + p = Point(x=1, y=2) + StrOrBytes = str | bytes + w = None + match p: + case Point(x=StrOrBytes(), y=StrOrBytes()): + w = 0 + self.assertIsNone(w) + + def test_union_type_match_second_member(self): + EitherPoint = Point | Point3D + p = Point3D(x=1, y=2, z=3) + w = None + match p: + case EitherPoint(x=1, y=2, z=3): + w = 0 + self.assertEqual(w, 0) + class TestSyntaxErrors(unittest.TestCase): @@ -3230,8 +3315,28 @@ def test_mapping_pattern_duplicate_key_edge_case3(self): pass """) + class TestTypeErrors(unittest.TestCase): + def test_generic_type(self): + t = list[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + + def test_legacy_generic_type(self): + from typing import List + t = List[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + def test_accepts_positional_subpatterns_0(self): class Class: __match_args__ = () @@ -3341,6 +3446,121 @@ def test_class_pattern_not_type(self): w = 0 self.assertIsNone(w) + def test_class_or_union_not_specialform(self): + from typing import Literal + name = type(Literal).__name__ + msg = rf"called match pattern must be a class or a union of classes \(got {name}\)" + w = None + with self.assertRaisesRegex(TypeError, msg): + match 1: + case Literal(): + w = 0 + self.assertIsNone(w) + + def test_typing_union(self): + from typing import Union + IntOrStr = Union[int, str] # identical to int | str since gh-105499 + w = False + match 1: + case IntOrStr(): + w = True + self.assertIs(w, True) + + def test_expanded_union_mirrors_isinstance_success(self): + ListOfInt = list[int] + t = int | ListOfInt + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case int() | ListOfInt(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIs(result, True) + self.assertIs(reference, True) + + def test_expanded_union_mirrors_isinstance_failure(self): + ListOfInt = list[int] + t = ListOfInt | int + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case ListOfInt() | int(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIsInstance(result, TypeError) + self.assertIsInstance(reference, TypeError) + + def test_union_type_mirrors_isinstance_success(self): + t = int | list[int] + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case t(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIs(result, True) + self.assertIs(reference, True) + + def test_union_type_mirrors_isinstance_failure(self): + t = list[int] | int + + try: # get the isinstance result + reference = isinstance(1, t) + except TypeError as exc: + reference = exc + + try: # get the match-case result + match 1: + case t(): + result = True + case _: + result = False + except TypeError as exc: + result = exc + + # we should ge the same result + self.assertIsInstance(result, TypeError) + self.assertIsInstance(reference, TypeError) + + def test_generic_union_type(self): + from collections.abc import Sequence, Set + t = Sequence[str] | Set[str] + w = None + with self.assertRaises(TypeError): + match ["s"]: + case t(): + w = 0 + self.assertIsNone(w) + def test_regular_protocol(self): from typing import Protocol class P(Protocol): ... diff --git a/Python/ceval.c b/Python/ceval.c index 291e753dec0ce5..fd102ea4516aab 100644 --- a/Python/ceval.c +++ b/Python/ceval.c @@ -39,6 +39,7 @@ #include "pycore_template.h" // _PyTemplate_Build() #include "pycore_traceback.h" // _PyTraceBack_FromFrame #include "pycore_tuple.h" // _PyTuple_ITEMS() +#include "pycore_unionobject.h" // _PyUnion_Check() #include "pycore_uop_ids.h" // Uops #include "dictobject.h" @@ -725,9 +726,27 @@ PyObject* _PyEval_MatchClass(PyThreadState *tstate, PyObject *subject, PyObject *type, Py_ssize_t nargs, PyObject *kwargs) { + // Recurse on unions. + if (_PyUnion_Check(type)) { + // get union members + PyObject *members = _Py_union_args(type); + const Py_ssize_t n = PyTuple_GET_SIZE(members); + + // iterate over union members and return first match + for (Py_ssize_t i = 0; i < n; i++) { + PyObject *member = PyTuple_GET_ITEM(members, i); + PyObject *attrs = _PyEval_MatchClass(tstate, subject, member, nargs, kwargs); + // match found + if (attrs != NULL) { + return attrs; + } + } + // no match found + return NULL; + } if (!PyType_Check(type)) { - const char *e = "called match pattern must be a class"; - _PyErr_Format(tstate, PyExc_TypeError, e); + const char *e = "called match pattern must be a class or a union of classes (got %s)"; + _PyErr_Format(tstate, PyExc_TypeError, e, Py_TYPE(type)->tp_name); return NULL; } assert(PyTuple_CheckExact(kwargs));