diff --git a/CHANGELOG.md b/CHANGELOG.md index 733505a5..64693af2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ # Unreleased +- Fix `__init__` in subclasses of protocols. - Fix incorrect behaviour on Python 3.9 and Python 3.10 that meant that calling `isinstance` with `typing_extensions.Concatenate[...]` or `typing_extensions.Unpack[...]` as the first argument could have a different diff --git a/src/test_typing_extensions.py b/src/test_typing_extensions.py index 88fa699e..b7968c4a 100644 --- a/src/test_typing_extensions.py +++ b/src/test_typing_extensions.py @@ -20,6 +20,7 @@ import typing import warnings from collections import defaultdict +from dataclasses import dataclass from functools import lru_cache from pathlib import Path from unittest import TestCase, main, skipIf, skipUnless @@ -3697,12 +3698,59 @@ def __init__(self, x: T) -> None: def test_init_called(self): T = TypeVar('T') + class P(Protocol[T]): pass + class C(P[T]): def __init__(self): self.test = 'OK' + self.assertEqual(C[int]().test, 'OK') + class B: + def __init__(self): + self.test = 'OK' + + class D1(B, P[T]): + pass + + self.assertEqual(D1[int]().test, 'OK') + + class D2(P[T], B): + pass + + self.assertEqual(D2[int]().test, 'OK') + + def test_super_call_init(self): + class P(Protocol): + x: int + + class Foo(P): + def __init__(self): + super().__init__() + + Foo() # Previously triggered RecursionError + + def test_inherit_from_protocol(self): + # Dataclasses inheriting from protocol should preserve their own `__init__`. + # See bpo-45081. + + class P(Protocol): + a: int + + @dataclass + class C(P): + a: int + + self.assertEqual(C(5).a, 5) + + @dataclass + class D(P): + def __init__(self, a): + self.a = a * 2 + + self.assertEqual(D(5).a, 10) + def test_protocols_bad_subscripts(self): T = TypeVar('T') S = TypeVar('S') diff --git a/src/typing_extensions.py b/src/typing_extensions.py index bd67a80a..6557dde3 100644 --- a/src/typing_extensions.py +++ b/src/typing_extensions.py @@ -664,10 +664,6 @@ def _allow_reckless_class_checks(depth=2): """ return _caller(depth) in {'abc', 'functools', None} - def _no_init(self, *args, **kwargs): - if type(self)._is_protocol: - raise TypeError('Protocols cannot be instantiated') - def _type_check_issubclass_arg_1(arg): """Raise TypeError if `arg` is not an instance of `type` in `issubclass(arg, )`. @@ -831,7 +827,7 @@ def __init_subclass__(cls, *args, **kwargs): # Prohibit instantiation for protocol classes if cls._is_protocol and cls.__init__ is Protocol.__init__: - cls.__init__ = _no_init + cls.__init__ = typing._no_init_or_replace_init # Breakpoint: https://github.com/python/cpython/pull/113401