Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions mypy/checkpattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mypy.maptype import map_instance_to_supertype
from mypy.meet import narrow_declared_type
from mypy.messages import MessageBuilder
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, Var
from mypy.options import Options
from mypy.patterns import (
AsPattern,
Expand All @@ -37,6 +37,7 @@
)
from mypy.types import (
AnyType,
FunctionLike,
Instance,
LiteralType,
NoneType,
Expand Down Expand Up @@ -538,27 +539,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
# Check class type
#
type_info = o.class_ref.node
if type_info is None:
typ: Type = AnyType(TypeOfAny.from_error)
elif isinstance(type_info, TypeAlias) and not type_info.no_args:
typ = self.chk.expr_checker.accept(o.class_ref)
p_typ = get_proper_type(typ)
if isinstance(type_info, TypeAlias) and not type_info.no_args:
self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
return self.early_non_match()
elif isinstance(type_info, TypeInfo):
typ = fill_typevars_with_any(type_info)
elif isinstance(type_info, TypeAlias):
typ = type_info.target
elif (
isinstance(type_info, Var)
and type_info.type is not None
and isinstance(get_proper_type(type_info.type), AnyType)
):
typ = type_info.type
else:
if isinstance(type_info, Var) and type_info.type is not None:
name = type_info.type.str_with_options(self.options)
else:
name = type_info.name
self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o)
elif isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
typ = fill_typevars_with_any(p_typ.type_object())
elif not isinstance(p_typ, AnyType):
self.msg.fail(
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
typ.str_with_options(self.options)
),
o,
)
return self.early_non_match()

new_type, rest_type = self.chk.conditional_types_with_intersection(
Expand Down Expand Up @@ -697,6 +691,8 @@ def should_self_match(self, typ: Type) -> bool:
typ = get_proper_type(typ)
if isinstance(typ, TupleType):
typ = typ.partial_fallback
if isinstance(typ, AnyType):
return False
if isinstance(typ, Instance) and typ.type.get("__match_args__") is not None:
# Named tuples and other subtypes of builtins that define __match_args__
# should not self match.
Expand Down
121 changes: 60 additions & 61 deletions mypyc/test-data/irbuild-match.test
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,9 @@ def f():
def f():
r0, r1 :: object
r2 :: bool
i :: int
r3 :: object
r4 :: str
r5, r6 :: object
r3, i, r4 :: object
r5 :: str
r6 :: object
r7 :: object[1]
r8 :: object_ptr
r9, r10 :: object
Expand All @@ -576,21 +575,22 @@ L0:
r2 = CPy_TypeCheck(r1, r0)
if r2 goto L1 else goto L3 :: bool
L1:
i = 246
r3 = object 123
i = r3
L2:
r3 = builtins :: module
r4 = 'print'
r5 = CPyObject_GetAttr(r3, r4)
r6 = box(int, i)
r7 = [r6]
r4 = builtins :: module
r5 = 'print'
r6 = CPyObject_GetAttr(r4, r5)
r7 = [i]
r8 = load_address r7
r9 = PyObject_Vectorcall(r5, r8, 1, 0)
keep_alive r6
r9 = PyObject_Vectorcall(r6, r8, 1, 0)
keep_alive i
goto L4
L3:
L4:
r10 = box(None, 1)
return r10

[case testMatchClassPatternWithPositionalArgs_python3_10]
class Position:
__match_args__ = ("x", "y", "z")
Expand All @@ -599,7 +599,7 @@ class Position:
y: int
z: int

def f(x):
def f(x) -> None:
match x:
case Position(1, 2, 3):
print("matched")
Expand Down Expand Up @@ -641,7 +641,7 @@ def f(x):
r28 :: object
r29 :: object[1]
r30 :: object_ptr
r31, r32 :: object
r31 :: object
L0:
r0 = __main__.Position :: type
r1 = PyObject_IsInstance(x, r0)
Expand Down Expand Up @@ -687,8 +687,8 @@ L4:
goto L6
L5:
L6:
r32 = box(None, 1)
return r32
return 1

[case testMatchClassPatternWithKeywordPatterns_python3_10]
class Position:
x: int
Expand Down Expand Up @@ -848,7 +848,7 @@ class C:
a: int
b: int

def f(x):
def f(x) -> None:
match x:
case C(1, 2) as y:
print("matched")
Expand Down Expand Up @@ -885,7 +885,7 @@ def f(x):
r22 :: object
r23 :: object[1]
r24 :: object_ptr
r25, r26 :: object
r25 :: object
L0:
r0 = __main__.C :: type
r1 = PyObject_IsInstance(x, r0)
Expand Down Expand Up @@ -925,8 +925,8 @@ L4:
goto L6
L5:
L6:
r26 = box(None, 1)
return r26
return 1

[case testMatchClassPatternPositionalCapture_python3_10]
class C:
__match_args__ = ("x",)
Expand All @@ -953,15 +953,14 @@ def f(x):
r2 :: bit
r3 :: bool
r4 :: str
r5 :: object
r6, num :: int
r7 :: str
r8 :: object
r9 :: str
r10 :: object
r11 :: object[1]
r12 :: object_ptr
r13, r14 :: object
r5, num :: object
r6 :: str
r7 :: object
r8 :: str
r9 :: object
r10 :: object[1]
r11 :: object_ptr
r12, r13 :: object
L0:
r0 = __main__.C :: type
r1 = PyObject_IsInstance(x, r0)
Expand All @@ -971,22 +970,22 @@ L0:
L1:
r4 = 'x'
r5 = CPyObject_GetAttr(x, r4)
r6 = unbox(int, r5)
num = r6
num = r5
L2:
r7 = 'matched'
r8 = builtins :: module
r9 = 'print'
r10 = CPyObject_GetAttr(r8, r9)
r11 = [r7]
r12 = load_address r11
r13 = PyObject_Vectorcall(r10, r12, 1, 0)
keep_alive r7
r6 = 'matched'
r7 = builtins :: module
r8 = 'print'
r9 = CPyObject_GetAttr(r7, r8)
r10 = [r6]
r11 = load_address r10
r12 = PyObject_Vectorcall(r9, r11, 1, 0)
keep_alive r6
goto L4
L3:
L4:
r14 = box(None, 1)
return r14
r13 = box(None, 1)
return r13

[case testMatchMappingEmpty_python3_10]
def f(x):
match x:
Expand Down Expand Up @@ -1601,35 +1600,35 @@ def f(x):
def f(x):
x, r0 :: object
r1 :: bool
r2, y :: int
r3 :: str
r4 :: object
r5 :: str
r6 :: object
r7 :: object[1]
r8 :: object_ptr
r9, r10 :: object
y :: object
r2 :: str
r3 :: object
r4 :: str
r5 :: object
r6 :: object[1]
r7 :: object_ptr
r8, r9 :: object
L0:
r0 = load_address PyLong_Type
r1 = CPy_TypeCheck(x, r0)
if r1 goto L1 else goto L3 :: bool
L1:
r2 = unbox(int, x)
y = r2
y = x
L2:
r3 = 'matched'
r4 = builtins :: module
r5 = 'print'
r6 = CPyObject_GetAttr(r4, r5)
r7 = [r3]
r8 = load_address r7
r9 = PyObject_Vectorcall(r6, r8, 1, 0)
keep_alive r3
r2 = 'matched'
r3 = builtins :: module
r4 = 'print'
r5 = CPyObject_GetAttr(r3, r4)
r6 = [r2]
r7 = load_address r6
r8 = PyObject_Vectorcall(r5, r7, 1, 0)
keep_alive r2
goto L4
L3:
L4:
r10 = box(None, 1)
return r10
r9 = box(None, 1)
return r9

[case testMatchSequenceCaptureAll_python3_10]
def f(x):
match x:
Expand Down
16 changes: 16 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -2990,3 +2990,19 @@ def foo(e: Literal[0, 1]) -> None:
...

defer = unknown_module.foo

[case testMatchErrorsIncorrectName]
class A:
pass

match 5:
case A.blah(): # E: "type[A]" has no attribute "blah"
pass

[case testMatchAllowsAnyClassArgsForAny]
match 5:
case BlahBlah(a, b): # E: Name "BlahBlah" is not defined
reveal_type(a) # N: Revealed type is "Any"
reveal_type(b) # N: Revealed type is "Any"
case BlahBlah(c=c): # E: Name "BlahBlah" is not defined
reveal_type(c) # N: Revealed type is "Any"