Skip to content

Commit 83dbe19

Browse files
committed
More cases covered
- nested concatenate - Unpack subscription
1 parent b89c272 commit 83dbe19

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

src/test_typing_extensions.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5520,17 +5520,20 @@ def test_substitution(self):
55205520
T = TypeVar('T')
55215521
P = ParamSpec('P')
55225522
Ts = TypeVarTuple("Ts")
5523+
U1 = Unpack[Tuple[int, str]]
5524+
U2 = Unpack[Ts]
55235525

55245526
C1 = Concatenate[str, T, ...]
55255527
self.assertEqual(C1[int], Concatenate[str, int, ...])
55265528

55275529
C2 = Concatenate[str, P]
55285530
self.assertEqual(C2[...], Concatenate[str, ...])
55295531
self.assertEqual(C2[int], (str, int))
5530-
U1 = Unpack[Tuple[int, str]]
5531-
U2 = Unpack[Ts]
5532+
self.assertEqual(C2[int, ...], (str, int, ...))
5533+
55325534
self.assertEqual(C2[U1], (str, int, str))
55335535
self.assertEqual(C2[U2], (str, Unpack[Ts]))
5536+
self.assertEqual(C2["U1"], (str, typing.ForwardRef("U1")))
55345537

55355538
if (3, 12, 0) <= sys.version_info < (3, 12, 4):
55365539
with self.assertRaises(AssertionError):
@@ -5541,7 +5544,22 @@ def test_substitution(self):
55415544

55425545
C3 = Concatenate[str, T, P]
55435546
self.assertEqual(C3[int, [bool]], (str, int, bool))
5547+
self.assertEqual(C3[int, ...], Concatenate[str, int, ...])
5548+
self.assertEqual(C3[int, Concatenate[str, P]], Concatenate[str, int, str, P])
5549+
5550+
@skipIf((3, 10) <= sys.version_info < (3, 12), reason="no backport yet")
5551+
def test_invalid_substitution(self):
5552+
T = TypeVar('T')
5553+
Ts = TypeVarTuple("Ts")
5554+
U1 = Unpack[Tuple[int, str]]
5555+
U2 = Unpack[Ts]
5556+
5557+
C1 = Concatenate[str, T, ...]
5558+
with self.assertRaisesRegex(TypeError, "Too many arguments"):
5559+
C1[U1]
55445560

5561+
with self.assertRaisesRegex(TypeError, r"Unpack\[Ts\] is not valid as type argument"):
5562+
C1[U2]
55455563

55465564
class TypeGuardTests(BaseTestCase):
55475565
def test_basics(self):

src/typing_extensions.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1766,13 +1766,17 @@ def __call__(self, *args, **kwargs):
17661766
if not hasattr(typing, 'Concatenate'):
17671767
# Inherits from list as a workaround for Callable checks in Python < 3.9.2.
17681768

1769-
#3.9.0-1
1769+
# 3.9.0-1
17701770
if not hasattr(typing, '_type_convert'):
17711771
def _type_convert(arg, module=None, *, allow_special_forms=False):
17721772
"""For converting None to type(None), and strings to ForwardRef."""
17731773
if arg is None:
17741774
return type(None)
17751775
if isinstance(arg, str):
1776+
if sys.version_info <= (3, 9, 6):
1777+
return ForwardRef(arg)
1778+
if sys.version_info <= (3, 9, 7):
1779+
return ForwardRef(arg, module=module)
17761780
return ForwardRef(arg, module=module, is_class=allow_special_forms)
17771781
return arg
17781782
else:
@@ -1812,10 +1816,10 @@ def __parameters__(self):
18121816
# 3.8; needed for typing._subs_tvars
18131817
# 3.9 used by __getitem__ below
18141818
def copy_with(self, params):
1815-
if isinstance(params[-1], (list, tuple)):
1816-
return (*params[:-1], *params[-1])
18171819
if isinstance(params[-1], _ConcatenateGenericAlias):
18181820
params = (*params[:-1], *params[-1].__args__)
1821+
elif isinstance(params[-1], (list, tuple)):
1822+
return (*params[:-1], *params[-1])
18191823
elif (not(params[-1] is ... or isinstance(params[-1], ParamSpec))):
18201824
raise TypeError("The last parameter to Concatenate should be a "
18211825
"ParamSpec variable or ellipsis.")
@@ -1847,10 +1851,21 @@ def __getitem__(self, args):
18471851
if len(params) == 1 and not _is_param_expr(args[0]):
18481852
assert i == 0
18491853
args = (args,)
1850-
# Convert lists to tuples to help other libraries cache the results.
1851-
elif isinstance(args[i], list):
1854+
# This class inherits from list do not convert
1855+
elif (
1856+
isinstance(args[i], list)
1857+
and not isinstance(args[i], _ConcatenateGenericAlias)
1858+
):
18521859
args = (*args[:i], tuple(args[i]), *args[i+1:])
18531860

1861+
alen = len(args)
1862+
plen = len(params)
1863+
if alen != plen:
1864+
raise TypeError(
1865+
f"Too {'many' if alen > plen else 'few'} arguments for {self};"
1866+
f" actual {alen}, expected {plen}"
1867+
)
1868+
18541869
subst = dict(zip(self.__parameters__, args))
18551870
# determine new args
18561871
new_args = []
@@ -1860,6 +1875,16 @@ def __getitem__(self, args):
18601875
continue
18611876
if isinstance(arg, TypeVar):
18621877
arg = subst[arg]
1878+
if (
1879+
(isinstance(arg, typing._GenericAlias) and _is_unpack(arg))
1880+
or (
1881+
hasattr(_types, "GenericAlias")
1882+
and isinstance(arg, _types.GenericAlias)
1883+
and getattr(arg, "__unpacked__", False)
1884+
)
1885+
):
1886+
raise TypeError(f"{arg} is not valid as type argument")
1887+
18631888
elif isinstance(arg,
18641889
typing._GenericAlias
18651890
if not hasattr(_types, "GenericAlias") else

0 commit comments

Comments
 (0)