Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Unreleased

- Backport CPython PR [#137281](https://github.com/python/cpython/pull/137281),
fixing how type parameters are collected when a `Protocol` base class is parametrized
with type variables. Now, parametrized `Generic` or `Protocol` base classes always
dictate the number and the order of the type parameters. Patch by Brian Schubert,
backporting a CPython PR by Nikita Sobolev.
- Fix incorrect behaviour on Python 3.9 and Python 3.10 that meant that
calling `isinstance` with `typing_extensions.Concatenate[...]` or
`typing_extensions.Unpack[...]` as the first argument could have a different
Expand Down
40 changes: 40 additions & 0 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3599,12 +3599,14 @@ class C: pass

def test_defining_generic_protocols(self):
T = TypeVar('T')
T2 = TypeVar('T2')
S = TypeVar('S')
@runtime_checkable
class PR(Protocol[T, S]):
def meth(self): pass
class P(PR[int, T], Protocol[T]):
y = 1
self.assertEqual(P.__parameters__, (T,))
with self.assertRaises(TypeError):
issubclass(PR[int, T], PR)
with self.assertRaises(TypeError):
Expand All @@ -3613,16 +3615,23 @@ class P(PR[int, T], Protocol[T]):
PR[int]
with self.assertRaises(TypeError):
P[int, str]
with self.assertRaisesRegex(
TypeError,
re.escape('Some type variables (~S) are not listed in Protocol[~T, ~T2]'),
):
class ExtraTypeVars(P[S], Protocol[T, T2]): ...
if not TYPING_3_10_0:
with self.assertRaises(TypeError):
PR[int, 1]
with self.assertRaises(TypeError):
PR[int, ClassVar]
class C(PR[int, T]): pass
self.assertEqual(C.__parameters__, (T,))
self.assertIsInstance(C[str](), C)

def test_defining_generic_protocols_old_style(self):
T = TypeVar('T')
T2 = TypeVar('T2')
S = TypeVar('S')
@runtime_checkable
class PR(Protocol, Generic[T, S]):
Expand All @@ -3639,8 +3648,15 @@ class P(PR[int, str], Protocol):
PR[int, 1]
class P1(Protocol, Generic[T]):
def bar(self, x: T) -> str: ...
self.assertEqual(P1.__parameters__, (T,))
class P2(Generic[T], Protocol):
def bar(self, x: T) -> str: ...
self.assertEqual(P2.__parameters__, (T,))
msg = re.escape('Some type variables (~S) are not listed in Protocol[~T, ~T2]')
with self.assertRaisesRegex(TypeError, msg):
class ExtraTypeVars(P1[S], Protocol[T, T2]): ...
with self.assertRaisesRegex(TypeError, msg):
class ExtraTypeVars(P2[S], Protocol[T, T2]): ...
@runtime_checkable
class PSub(P1[str], Protocol):
x = 1
Expand All @@ -3653,9 +3669,33 @@ def bar(self, x: str) -> str:
with self.assertRaises(TypeError):
PR[int, ClassVar]

def test_protocol_parameter_order(self):
# https://github.com/python/cpython/issues/137191
T1 = TypeVar("T1")
T2 = TypeVar("T2", default=object)

class A(Protocol[T1]): ...

class B0(A[T2], Generic[T1, T2]): ...
self.assertEqual(B0.__parameters__, (T1, T2))

class B1(A[T2], Protocol, Generic[T1, T2]): ...
self.assertEqual(B1.__parameters__, (T1, T2))

class B2(A[T2], Protocol[T1, T2]): ...
self.assertEqual(B2.__parameters__, (T1, T2))

if hasattr(typing, "TypeAliasType"):
exec(textwrap.dedent(
"""
def test_pep695_protocol_parameter_order(self):
class A[T1](Protocol): ...
class B3[T1, T2](A[T2], Protocol):
@staticmethod
def get_typeparams():
return (T1, T2)
self.assertEqual(B3.__parameters__, B3.get_typeparams())

def test_pep695_generic_protocol_callable_members(self):
@runtime_checkable
class Foo[T](Protocol):
Expand Down
137 changes: 131 additions & 6 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3208,7 +3208,13 @@ def _is_unpacked_typevartuple(x) -> bool:
)


# Python 3.11+ _collect_type_vars was renamed to _collect_parameters
# - Python 3.11+ _collect_type_vars was renamed to _collect_parameters.
# Breakpoint: https://github.com/python/cpython/pull/31143
# - Python 3.13+ _collect_parameters was renamed to _collect_type_parameters.
# Breakpoint: https://github.com/python/cpython/pull/118900
# - Monkey patch Generic.__init_subclass__ on <3.15 to fix type parameter
# collection from Protocol bases with listed parameters.
# Breakpoint: https://github.com/python/cpython/pull/137281
if hasattr(typing, '_collect_type_vars'):
def _collect_type_vars(types, typevar_types=None):
"""Collect all type variable contained in types in order of
Expand Down Expand Up @@ -3258,21 +3264,82 @@ def _collect_type_vars(types, typevar_types=None):
tvars.append(collected)
return tuple(tvars)

def _generic_init_subclass(cls, *args, **kwargs):
super(Generic, cls).__init_subclass__(*args, **kwargs)
tvars = []
if '__orig_bases__' in cls.__dict__:
error = Generic in cls.__orig_bases__
else:
error = (Generic in cls.__bases__ and
cls.__name__ != 'Protocol' and
type(cls) not in (_TypedDictMeta, typing._TypedDictMeta))
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
typevar_types = (TypeVar, typing.TypeVar, ParamSpec)
if hasattr(typing, "ParamSpec"): # Python 3.10+
typevar_types += (typing.ParamSpec,)
tvars = _collect_type_vars(cls.__orig_bases__, typevar_types)
Comment on lines +3279 to +3282
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: for compatibility with both 3.9 and 3.10+, c.f. python/cpython#26091

# Look for Generic[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...].
gvars = None
basename = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (Generic, typing.Protocol, Protocol)):
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...] multiple times."
)
gvars = base.__parameters__
basename = base.__origin__.__name__
if gvars is not None:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(
f"Some type variables ({s_vars}) are"
f" not listed in {basename}[{s_args}]"
)
tvars = gvars
cls.__parameters__ = tuple(tvars)

typing._collect_type_vars = _collect_type_vars
else:
def _collect_parameters(args):
typing.Generic.__init_subclass__ = classmethod(_generic_init_subclass)
elif sys.version_info < (3, 15):
def _collect_parameters(
args,
*,
enforce_default_ordering=_marker,
validate_all=False,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: since typing only explicitly passes validate_all from inside _generic_init_subclass (which this monkey patches below), there's no need for sys._getframe hacks like what was done for enforce_default_ordering in #392.

):
Comment on lines +3315 to +3320
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw that there have been previous issues where monkey patching typing internals can break future versions of typing if new parameters are added (c.f. python/cpython#118900). I think there's low risk of that here, since this only monkey patches internal functions on Python <3.15, and it seems unlikely that new parameter would be backported to 3.14 at this point. But if we wanted to help guard against that, I suppose adding an ignored **kwargs parameter might be an option?

"""Collect all type variables and parameter specifications in args
in order of first appearance (lexicographic order).

Having an explicit `Generic` or `Protocol` base class determines
the exact parameter order.

For example::

assert _collect_parameters((T, Callable[P, T])) == (T, P)
>>> P = ParamSpec('P')
>>> T = TypeVar('T')
>>> _collect_parameters((T, Callable[P, T]))
(~T, ~P)
>>> _collect_parameters((list[T], Generic[P, T]))
(~P, ~T)
"""
parameters = []

# A required TypeVarLike cannot appear after a TypeVarLike with default
# if it was a direct call to `Generic[]` or `Protocol[]`
enforce_default_ordering = _has_generic_or_protocol_as_origin()
if enforce_default_ordering is _marker:
enforce_default_ordering = _has_generic_or_protocol_as_origin()
Comment on lines +3340 to +3341
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: using a placeholder default value for enforce_default_ordering allows the same function to work as _collect_parameters on Python <3.13 and as _collect_type_parameters on Python 3.13+. Another alternative would be to split this into separate <3.13 and 3.13+ implementations in respective version branches


default_encountered = False

# Also, a TypeVarLike with a default cannot appear after a TypeVarTuple
Expand Down Expand Up @@ -3307,6 +3374,17 @@ def _collect_parameters(args):
' follows type parameter with a default')

parameters.append(t)
elif (
not validate_all
and isinstance(t, typing._GenericAlias)
and t.__origin__ in (Generic, typing.Protocol, Protocol)
):
# If we see explicit `Generic[...]` or `Protocol[...]` base classes,
# we need to just copy them as-is.
# Unless `validate_all` is passed, in this case it means that
# we are doing a validation of `Generic` subclasses,
# then we collect all unique parameters to be able to inspect them.
parameters = t.__parameters__
else:
if _is_unpacked_typevartuple(t):
type_var_tuple_encountered = True
Expand All @@ -3316,8 +3394,55 @@ def _collect_parameters(args):

return tuple(parameters)

if not _PEP_696_IMPLEMENTED:
def _generic_init_subclass(cls, *args, **kwargs):
super(Generic, cls).__init_subclass__(*args, **kwargs)
tvars = []
if '__orig_bases__' in cls.__dict__:
error = Generic in cls.__orig_bases__
else:
error = (Generic in cls.__bases__ and
cls.__name__ != 'Protocol' and
type(cls) not in (_TypedDictMeta, typing._TypedDictMeta))
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
tvars = _collect_parameters(cls.__orig_bases__, validate_all=True)
# Look for Generic[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...].
gvars = None
basename = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (Generic, typing.Protocol, Protocol)):
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...] multiple times."
)
gvars = base.__parameters__
basename = base.__origin__.__name__
if gvars is not None:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(
f"Some type variables ({s_vars}) are"
f" not listed in {basename}[{s_args}]"
)
tvars = gvars
cls.__parameters__ = tuple(tvars)

if _PEP_696_IMPLEMENTED:
typing._collect_type_parameters = _collect_parameters
typing._generic_init_subclass = _generic_init_subclass
else:
typing._collect_parameters = _collect_parameters
typing.Generic.__init_subclass__ = classmethod(_generic_init_subclass)


# Backport typing.NamedTuple as it exists in Python 3.13.
# In 3.11, the ability to define generic `NamedTuple`s was supported.
Expand Down
Loading