Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Unreleased

- Backport CPython PR [#137281](https://github.com/python/cpython/pull/137281),
fixing how type parameters are collected when a `Protocol` base class is parametrized
with type variables. Now, parametrized `Generic` or `Protocol` base classes always
dictate the number and the order of the type parameters. Patch by Brian Schubert,
backporting a CPython PR by Nikita Sobolev.
- Raise `TypeError` when attempting to subclass `typing_extensions.ParamSpec` on
Python 3.9. The `typing` implementation has always raised an error, and the
`typing_extensions` implementation has raised an error on Python 3.10+ since
Expand Down
40 changes: 40 additions & 0 deletions src/test_typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3580,12 +3580,14 @@ class C: pass

def test_defining_generic_protocols(self):
T = TypeVar('T')
T2 = TypeVar('T2')
S = TypeVar('S')
@runtime_checkable
class PR(Protocol[T, S]):
def meth(self): pass
class P(PR[int, T], Protocol[T]):
y = 1
self.assertEqual(P.__parameters__, (T,))
with self.assertRaises(TypeError):
issubclass(PR[int, T], PR)
with self.assertRaises(TypeError):
Expand All @@ -3594,16 +3596,23 @@ class P(PR[int, T], Protocol[T]):
PR[int]
with self.assertRaises(TypeError):
P[int, str]
with self.assertRaisesRegex(
TypeError,
re.escape('Some type variables (~S) are not listed in Protocol[~T, ~T2]'),
):
class ExtraTypeVars(P[S], Protocol[T, T2]): ...
if not TYPING_3_10_0:
with self.assertRaises(TypeError):
PR[int, 1]
with self.assertRaises(TypeError):
PR[int, ClassVar]
class C(PR[int, T]): pass
self.assertEqual(C.__parameters__, (T,))
self.assertIsInstance(C[str](), C)

def test_defining_generic_protocols_old_style(self):
T = TypeVar('T')
T2 = TypeVar('T2')
S = TypeVar('S')
@runtime_checkable
class PR(Protocol, Generic[T, S]):
Expand All @@ -3620,8 +3629,15 @@ class P(PR[int, str], Protocol):
PR[int, 1]
class P1(Protocol, Generic[T]):
def bar(self, x: T) -> str: ...
self.assertEqual(P1.__parameters__, (T,))
class P2(Generic[T], Protocol):
def bar(self, x: T) -> str: ...
self.assertEqual(P2.__parameters__, (T,))
msg = re.escape('Some type variables (~S) are not listed in Protocol[~T, ~T2]')
with self.assertRaisesRegex(TypeError, msg):
class ExtraTypeVars(P1[S], Protocol[T, T2]): ...
with self.assertRaisesRegex(TypeError, msg):
class ExtraTypeVars(P2[S], Protocol[T, T2]): ...
@runtime_checkable
class PSub(P1[str], Protocol):
x = 1
Expand All @@ -3634,9 +3650,33 @@ def bar(self, x: str) -> str:
with self.assertRaises(TypeError):
PR[int, ClassVar]

def test_protocol_parameter_order(self):
# https://github.com/python/cpython/issues/137191
T1 = TypeVar("T1")
T2 = TypeVar("T2", default=object)

class A(Protocol[T1]): ...

class B0(A[T2], Generic[T1, T2]): ...
self.assertEqual(B0.__parameters__, (T1, T2))

class B1(A[T2], Protocol, Generic[T1, T2]): ...
self.assertEqual(B1.__parameters__, (T1, T2))

class B2(A[T2], Protocol[T1, T2]): ...
self.assertEqual(B2.__parameters__, (T1, T2))

if hasattr(typing, "TypeAliasType"):
exec(textwrap.dedent(
"""
def test_pep695_protocol_parameter_order(self):
class A[T1](Protocol): ...
class B3[T1, T2](A[T2], Protocol):
@staticmethod
def get_typeparams():
return (T1, T2)
self.assertEqual(B3.__parameters__, B3.get_typeparams())

def test_pep695_generic_protocol_callable_members(self):
@runtime_checkable
class Foo[T](Protocol):
Expand Down
137 changes: 131 additions & 6 deletions src/typing_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3203,7 +3203,13 @@ def _is_unpacked_typevartuple(x) -> bool:
)


# Python 3.11+ _collect_type_vars was renamed to _collect_parameters
# - Python 3.11+ _collect_type_vars was renamed to _collect_parameters.
# Breakpoint: https://github.com/python/cpython/pull/31143
# - Python 3.13+ _collect_parameters was renamed to _collect_type_parameters.
# Breakpoint: https://github.com/python/cpython/pull/118900
# - Monkey patch Generic.__init_subclass__ on <3.15 to fix type parameter
# collection from Protocol bases with listed parameters.
# Breakpoint: https://github.com/python/cpython/pull/137281
if hasattr(typing, '_collect_type_vars'):
def _collect_type_vars(types, typevar_types=None):
"""Collect all type variable contained in types in order of
Expand Down Expand Up @@ -3253,21 +3259,82 @@ def _collect_type_vars(types, typevar_types=None):
tvars.append(collected)
return tuple(tvars)

def _generic_init_subclass(cls, *args, **kwargs):
super(Generic, cls).__init_subclass__(*args, **kwargs)
tvars = []
if '__orig_bases__' in cls.__dict__:
error = Generic in cls.__orig_bases__
else:
error = (Generic in cls.__bases__ and
cls.__name__ != 'Protocol' and
type(cls) not in (_TypedDictMeta, typing._TypedDictMeta))
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
typevar_types = (TypeVar, typing.TypeVar, ParamSpec)
if hasattr(typing, "ParamSpec"): # Python 3.10+
typevar_types += (typing.ParamSpec,)
tvars = _collect_type_vars(cls.__orig_bases__, typevar_types)
Comment on lines +3279 to +3282
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: for compatibility with both 3.9 and 3.10+, c.f. python/cpython#26091

# Look for Generic[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...].
gvars = None
basename = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (Generic, typing.Protocol, Protocol)):
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...] multiple times."
)
gvars = base.__parameters__
basename = base.__origin__.__name__
if gvars is not None:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(
f"Some type variables ({s_vars}) are"
f" not listed in {basename}[{s_args}]"
)
tvars = gvars
cls.__parameters__ = tuple(tvars)

typing._collect_type_vars = _collect_type_vars
else:
def _collect_parameters(args):
typing.Generic.__init_subclass__ = classmethod(_generic_init_subclass)
elif sys.version_info < (3, 15):
def _collect_parameters(
args,
*,
enforce_default_ordering=_marker,
validate_all=False,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: since typing only explicitly passes validate_all from inside _generic_init_subclass (which this monkey patches below), there's no need for sys._getframe hacks like what was done for enforce_default_ordering in #392.

):
Comment on lines +3315 to +3320
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I saw that there have been previous issues where monkey patching typing internals can break future versions of typing if new parameters are added (c.f. python/cpython#118900). I think there's low risk of that here, since this only monkey patches internal functions on Python <3.15, and it seems unlikely that new parameter would be backported to 3.14 at this point. But if we wanted to help guard against that, I suppose adding an ignored **kwargs parameter might be an option?

"""Collect all type variables and parameter specifications in args
in order of first appearance (lexicographic order).

Having an explicit `Generic` or `Protocol` base class determines
the exact parameter order.

For example::

assert _collect_parameters((T, Callable[P, T])) == (T, P)
>>> P = ParamSpec('P')
>>> T = TypeVar('T')
>>> _collect_parameters((T, Callable[P, T]))
(~T, ~P)
>>> _collect_parameters((list[T], Generic[P, T]))
(~P, ~T)
"""
parameters = []

# A required TypeVarLike cannot appear after a TypeVarLike with default
# if it was a direct call to `Generic[]` or `Protocol[]`
enforce_default_ordering = _has_generic_or_protocol_as_origin()
if enforce_default_ordering is _marker:
enforce_default_ordering = _has_generic_or_protocol_as_origin()
Comment on lines +3340 to +3341
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: using a placeholder default value for enforce_default_ordering allows the same function to work as _collect_parameters on Python <3.13 and as _collect_type_parameters on Python 3.13+. Another alternative would be to split this into separate <3.13 and 3.13+ implementations in respective version branches


default_encountered = False

# Also, a TypeVarLike with a default cannot appear after a TypeVarTuple
Expand Down Expand Up @@ -3302,6 +3369,17 @@ def _collect_parameters(args):
' follows type parameter with a default')

parameters.append(t)
elif (
not validate_all
and isinstance(t, typing._GenericAlias)
and t.__origin__ in (Generic, typing.Protocol, Protocol)
):
# If we see explicit `Generic[...]` or `Protocol[...]` base classes,
# we need to just copy them as-is.
# Unless `validate_all` is passed, in this case it means that
# we are doing a validation of `Generic` subclasses,
# then we collect all unique parameters to be able to inspect them.
parameters = t.__parameters__
else:
if _is_unpacked_typevartuple(t):
type_var_tuple_encountered = True
Expand All @@ -3311,8 +3389,55 @@ def _collect_parameters(args):

return tuple(parameters)

if not _PEP_696_IMPLEMENTED:
def _generic_init_subclass(cls, *args, **kwargs):
super(Generic, cls).__init_subclass__(*args, **kwargs)
tvars = []
if '__orig_bases__' in cls.__dict__:
error = Generic in cls.__orig_bases__
else:
error = (Generic in cls.__bases__ and
cls.__name__ != 'Protocol' and
type(cls) not in (_TypedDictMeta, typing._TypedDictMeta))
if error:
raise TypeError("Cannot inherit from plain Generic")
if '__orig_bases__' in cls.__dict__:
tvars = _collect_parameters(cls.__orig_bases__, validate_all=True)
# Look for Generic[T1, ..., Tn].
# If found, tvars must be a subset of it.
# If not found, tvars is it.
# Also check for and reject plain Generic,
# and reject multiple Generic[...].
gvars = None
basename = None
for base in cls.__orig_bases__:
if (isinstance(base, typing._GenericAlias) and
base.__origin__ in (Generic, typing.Protocol, Protocol)):
if gvars is not None:
raise TypeError(
"Cannot inherit from Generic[...] multiple times."
)
gvars = base.__parameters__
basename = base.__origin__.__name__
if gvars is not None:
tvarset = set(tvars)
gvarset = set(gvars)
if not tvarset <= gvarset:
s_vars = ', '.join(str(t) for t in tvars if t not in gvarset)
s_args = ', '.join(str(g) for g in gvars)
raise TypeError(
f"Some type variables ({s_vars}) are"
f" not listed in {basename}[{s_args}]"
)
tvars = gvars
cls.__parameters__ = tuple(tvars)

if _PEP_696_IMPLEMENTED:
typing._collect_type_parameters = _collect_parameters
typing._generic_init_subclass = _generic_init_subclass
else:
typing._collect_parameters = _collect_parameters
typing.Generic.__init_subclass__ = classmethod(_generic_init_subclass)


# Backport typing.NamedTuple as it exists in Python 3.13.
# In 3.11, the ability to define generic `NamedTuple`s was supported.
Expand Down
Loading