diff --git a/CHANGELOG.md b/CHANGELOG.md index 733505a5..9a63ecb2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Unreleased +- Backport CPython PR [#137281](https://github.com/python/cpython/pull/137281), + fixing how type parameters are collected when a `Protocol` base class is parametrized + with type variables. Now, parametrized `Generic` or `Protocol` base classes always + dictate the number and the order of the type parameters. Patch by Brian Schubert, + backporting a CPython PR by Nikita Sobolev. - Fix incorrect behaviour on Python 3.9 and Python 3.10 that meant that calling `isinstance` with `typing_extensions.Concatenate[...]` or `typing_extensions.Unpack[...]` as the first argument could have a different diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 88fa699e..8387060f 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -3599,12 +3599,14 @@ class C: pass def test_defining_generic_protocols(self): T = TypeVar('T') + T2 = TypeVar('T2') S = TypeVar('S') @runtime_checkable class PR(Protocol[T, S]): def meth(self): pass class P(PR[int, T], Protocol[T]): y = 1 + self.assertEqual(P.__parameters__, (T,)) with self.assertRaises(TypeError): issubclass(PR[int, T], PR) with self.assertRaises(TypeError): @@ -3613,16 +3615,23 @@ class P(PR[int, T], Protocol[T]): PR[int] with self.assertRaises(TypeError): P[int, str] + with self.assertRaisesRegex( + TypeError, + re.escape('Some type variables (~S) are not listed in Protocol[~T, ~T2]'), + ): + class ExtraTypeVars(P[S], Protocol[T, T2]): ... if not TYPING_3_10_0: with self.assertRaises(TypeError): PR[int, 1] with self.assertRaises(TypeError): PR[int, ClassVar] class C(PR[int, T]): pass + self.assertEqual(C.__parameters__, (T,)) self.assertIsInstance(C[str](), C) def test_defining_generic_protocols_old_style(self): T = TypeVar('T') + T2 = TypeVar('T2') S = TypeVar('S') @runtime_checkable class PR(Protocol, Generic[T, S]): @@ -3639,8 +3648,15 @@ class P(PR[int, str], Protocol): PR[int, 1] class P1(Protocol, Generic[T]): def bar(self, x: T) -> str: ... + self.assertEqual(P1.__parameters__, (T,)) class P2(Generic[T], Protocol): def bar(self, x: T) -> str: ... + self.assertEqual(P2.__parameters__, (T,)) + msg = re.escape('Some type variables (~S) are not listed in Protocol[~T, ~T2]') + with self.assertRaisesRegex(TypeError, msg): + class ExtraTypeVars(P1[S], Protocol[T, T2]): ... + with self.assertRaisesRegex(TypeError, msg): + class ExtraTypeVars(P2[S], Protocol[T, T2]): ... @runtime_checkable class PSub(P1[str], Protocol): x = 1 @@ -3653,9 +3669,33 @@ def bar(self, x: str) -> str: with self.assertRaises(TypeError): PR[int, ClassVar] + def test_protocol_parameter_order(self): + # https://github.com/python/cpython/issues/137191 + T1 = TypeVar("T1") + T2 = TypeVar("T2", default=object) + + class A(Protocol[T1]): ... + + class B0(A[T2], Generic[T1, T2]): ... + self.assertEqual(B0.__parameters__, (T1, T2)) + + class B1(A[T2], Protocol, Generic[T1, T2]): ... + self.assertEqual(B1.__parameters__, (T1, T2)) + + class B2(A[T2], Protocol[T1, T2]): ... + self.assertEqual(B2.__parameters__, (T1, T2)) + if hasattr(typing, "TypeAliasType"): exec(textwrap.dedent( """ + def test_pep695_protocol_parameter_order(self): + class A[T1](Protocol): ... + class B3[T1, T2](A[T2], Protocol): + @staticmethod + def get_typeparams(): + return (T1, T2) + self.assertEqual(B3.__parameters__, B3.get_typeparams()) + def test_pep695_generic_protocol_callable_members(self): @runtime_checkable class Foo[T](Protocol): diff --git a/src/typing_extensions.py b/src/typing_extensions.py index bd67a80a..bbcffcbe 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -3208,7 +3208,13 @@ def _is_unpacked_typevartuple(x) -> bool: ) -# Python 3.11+ _collect_type_vars was renamed to _collect_parameters +# - Python 3.11+ _collect_type_vars was renamed to _collect_parameters. +# Breakpoint: https://github.com/python/cpython/pull/31143 +# - Python 3.13+ _collect_parameters was renamed to _collect_type_parameters. +# Breakpoint: https://github.com/python/cpython/pull/118900 +# - Monkey patch Generic.__init_subclass__ on <3.15 to fix type parameter +# collection from Protocol bases with listed parameters. +# Breakpoint: https://github.com/python/cpython/pull/137281 if hasattr(typing, '_collect_type_vars'): def _collect_type_vars(types, typevar_types=None): """Collect all type variable contained in types in order of @@ -3258,21 +3264,82 @@ def _collect_type_vars(types, typevar_types=None): tvars.append(collected) return tuple(tvars) + def _generic_init_subclass(cls, *args, **kwargs): + super(Generic, cls).__init_subclass__(*args, **kwargs) + tvars = [] + if '__orig_bases__' in cls.__dict__: + error = Generic in cls.__orig_bases__ + else: + error = (Generic in cls.__bases__ and + cls.__name__ != 'Protocol' and + type(cls) not in (_TypedDictMeta, typing._TypedDictMeta)) + if error: + raise TypeError("Cannot inherit from plain Generic") + if '__orig_bases__' in cls.__dict__: + typevar_types = (TypeVar, typing.TypeVar, ParamSpec) + if hasattr(typing, "ParamSpec"): # Python 3.10+ + typevar_types += (typing.ParamSpec,) + tvars = _collect_type_vars(cls.__orig_bases__, typevar_types) + # Look for Generic[T1, ..., Tn]. + # If found, tvars must be a subset of it. + # If not found, tvars is it. + # Also check for and reject plain Generic, + # and reject multiple Generic[...]. + gvars = None + basename = None + for base in cls.__orig_bases__: + if (isinstance(base, typing._GenericAlias) and + base.__origin__ in (Generic, typing.Protocol, Protocol)): + if gvars is not None: + raise TypeError( + "Cannot inherit from Generic[...] multiple times." + ) + gvars = base.__parameters__ + basename = base.__origin__.__name__ + if gvars is not None: + tvarset = set(tvars) + gvarset = set(gvars) + if not tvarset <= gvarset: + s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) + s_args = ', '.join(str(g) for g in gvars) + raise TypeError( + f"Some type variables ({s_vars}) are" + f" not listed in {basename}[{s_args}]" + ) + tvars = gvars + cls.__parameters__ = tuple(tvars) + typing._collect_type_vars = _collect_type_vars -else: - def _collect_parameters(args): + typing.Generic.__init_subclass__ = classmethod(_generic_init_subclass) +elif sys.version_info < (3, 15): + def _collect_parameters( + args, + *, + enforce_default_ordering=_marker, + validate_all=False, + ): """Collect all type variables and parameter specifications in args in order of first appearance (lexicographic order). + Having an explicit `Generic` or `Protocol` base class determines + the exact parameter order. + For example:: - assert _collect_parameters((T, Callable[P, T])) == (T, P) + >>> P = ParamSpec('P') + >>> T = TypeVar('T') + >>> _collect_parameters((T, Callable[P, T])) + (~T, ~P) + >>> _collect_parameters((list[T], Generic[P, T])) + (~P, ~T) """ parameters = [] # A required TypeVarLike cannot appear after a TypeVarLike with default # if it was a direct call to `Generic[]` or `Protocol[]` - enforce_default_ordering = _has_generic_or_protocol_as_origin() + if enforce_default_ordering is _marker: + enforce_default_ordering = _has_generic_or_protocol_as_origin() + default_encountered = False # Also, a TypeVarLike with a default cannot appear after a TypeVarTuple @@ -3307,6 +3374,17 @@ def _collect_parameters(args): ' follows type parameter with a default') parameters.append(t) + elif ( + not validate_all + and isinstance(t, typing._GenericAlias) + and t.__origin__ in (Generic, typing.Protocol, Protocol) + ): + # If we see explicit `Generic[...]` or `Protocol[...]` base classes, + # we need to just copy them as-is. + # Unless `validate_all` is passed, in this case it means that + # we are doing a validation of `Generic` subclasses, + # then we collect all unique parameters to be able to inspect them. + parameters = t.__parameters__ else: if _is_unpacked_typevartuple(t): type_var_tuple_encountered = True @@ -3316,8 +3394,55 @@ def _collect_parameters(args): return tuple(parameters) - if not _PEP_696_IMPLEMENTED: + def _generic_init_subclass(cls, *args, **kwargs): + super(Generic, cls).__init_subclass__(*args, **kwargs) + tvars = [] + if '__orig_bases__' in cls.__dict__: + error = Generic in cls.__orig_bases__ + else: + error = (Generic in cls.__bases__ and + cls.__name__ != 'Protocol' and + type(cls) not in (_TypedDictMeta, typing._TypedDictMeta)) + if error: + raise TypeError("Cannot inherit from plain Generic") + if '__orig_bases__' in cls.__dict__: + tvars = _collect_parameters(cls.__orig_bases__, validate_all=True) + # Look for Generic[T1, ..., Tn]. + # If found, tvars must be a subset of it. + # If not found, tvars is it. + # Also check for and reject plain Generic, + # and reject multiple Generic[...]. + gvars = None + basename = None + for base in cls.__orig_bases__: + if (isinstance(base, typing._GenericAlias) and + base.__origin__ in (Generic, typing.Protocol, Protocol)): + if gvars is not None: + raise TypeError( + "Cannot inherit from Generic[...] multiple times." + ) + gvars = base.__parameters__ + basename = base.__origin__.__name__ + if gvars is not None: + tvarset = set(tvars) + gvarset = set(gvars) + if not tvarset <= gvarset: + s_vars = ', '.join(str(t) for t in tvars if t not in gvarset) + s_args = ', '.join(str(g) for g in gvars) + raise TypeError( + f"Some type variables ({s_vars}) are" + f" not listed in {basename}[{s_args}]" + ) + tvars = gvars + cls.__parameters__ = tuple(tvars) + + if _PEP_696_IMPLEMENTED: + typing._collect_type_parameters = _collect_parameters + typing._generic_init_subclass = _generic_init_subclass + else: typing._collect_parameters = _collect_parameters + typing.Generic.__init_subclass__ = classmethod(_generic_init_subclass) + # Backport typing.NamedTuple as it exists in Python 3.13. # In 3.11, the ability to define generic `NamedTuple`s was supported.