Skip to content

Commit e03d3c1

Browse files
authored
Check class references to catch non-existant classes in match cases (#20042)
Fixes #20018.
1 parent 5b7279b commit e03d3c1

File tree

3 files changed

+92
-81
lines changed

3 files changed

+92
-81
lines changed

mypy/checkpattern.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mypy.maptype import map_instance_to_supertype
1515
from mypy.meet import narrow_declared_type
1616
from mypy.messages import MessageBuilder
17-
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, TypeInfo, Var
17+
from mypy.nodes import ARG_POS, Context, Expression, NameExpr, TypeAlias, Var
1818
from mypy.options import Options
1919
from mypy.patterns import (
2020
AsPattern,
@@ -37,6 +37,7 @@
3737
)
3838
from mypy.types import (
3939
AnyType,
40+
FunctionLike,
4041
Instance,
4142
LiteralType,
4243
NoneType,
@@ -538,27 +539,20 @@ def visit_class_pattern(self, o: ClassPattern) -> PatternType:
538539
# Check class type
539540
#
540541
type_info = o.class_ref.node
541-
if type_info is None:
542-
typ: Type = AnyType(TypeOfAny.from_error)
543-
elif isinstance(type_info, TypeAlias) and not type_info.no_args:
542+
typ = self.chk.expr_checker.accept(o.class_ref)
543+
p_typ = get_proper_type(typ)
544+
if isinstance(type_info, TypeAlias) and not type_info.no_args:
544545
self.msg.fail(message_registry.CLASS_PATTERN_GENERIC_TYPE_ALIAS, o)
545546
return self.early_non_match()
546-
elif isinstance(type_info, TypeInfo):
547-
typ = fill_typevars_with_any(type_info)
548-
elif isinstance(type_info, TypeAlias):
549-
typ = type_info.target
550-
elif (
551-
isinstance(type_info, Var)
552-
and type_info.type is not None
553-
and isinstance(get_proper_type(type_info.type), AnyType)
554-
):
555-
typ = type_info.type
556-
else:
557-
if isinstance(type_info, Var) and type_info.type is not None:
558-
name = type_info.type.str_with_options(self.options)
559-
else:
560-
name = type_info.name
561-
self.msg.fail(message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(name), o)
547+
elif isinstance(p_typ, FunctionLike) and p_typ.is_type_obj():
548+
typ = fill_typevars_with_any(p_typ.type_object())
549+
elif not isinstance(p_typ, AnyType):
550+
self.msg.fail(
551+
message_registry.CLASS_PATTERN_TYPE_REQUIRED.format(
552+
typ.str_with_options(self.options)
553+
),
554+
o,
555+
)
562556
return self.early_non_match()
563557

564558
new_type, rest_type = self.chk.conditional_types_with_intersection(
@@ -697,6 +691,8 @@ def should_self_match(self, typ: Type) -> bool:
697691
typ = get_proper_type(typ)
698692
if isinstance(typ, TupleType):
699693
typ = typ.partial_fallback
694+
if isinstance(typ, AnyType):
695+
return False
700696
if isinstance(typ, Instance) and typ.type.get("__match_args__") is not None:
701697
# Named tuples and other subtypes of builtins that define __match_args__
702698
# should not self match.

mypyc/test-data/irbuild-match.test

Lines changed: 60 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -563,10 +563,9 @@ def f():
563563
def f():
564564
r0, r1 :: object
565565
r2 :: bool
566-
i :: int
567-
r3 :: object
568-
r4 :: str
569-
r5, r6 :: object
566+
r3, i, r4 :: object
567+
r5 :: str
568+
r6 :: object
570569
r7 :: object[1]
571570
r8 :: object_ptr
572571
r9, r10 :: object
@@ -576,21 +575,22 @@ L0:
576575
r2 = CPy_TypeCheck(r1, r0)
577576
if r2 goto L1 else goto L3 :: bool
578577
L1:
579-
i = 246
578+
r3 = object 123
579+
i = r3
580580
L2:
581-
r3 = builtins :: module
582-
r4 = 'print'
583-
r5 = CPyObject_GetAttr(r3, r4)
584-
r6 = box(int, i)
585-
r7 = [r6]
581+
r4 = builtins :: module
582+
r5 = 'print'
583+
r6 = CPyObject_GetAttr(r4, r5)
584+
r7 = [i]
586585
r8 = load_address r7
587-
r9 = PyObject_Vectorcall(r5, r8, 1, 0)
588-
keep_alive r6
586+
r9 = PyObject_Vectorcall(r6, r8, 1, 0)
587+
keep_alive i
589588
goto L4
590589
L3:
591590
L4:
592591
r10 = box(None, 1)
593592
return r10
593+
594594
[case testMatchClassPatternWithPositionalArgs_python3_10]
595595
class Position:
596596
__match_args__ = ("x", "y", "z")
@@ -599,7 +599,7 @@ class Position:
599599
y: int
600600
z: int
601601

602-
def f(x):
602+
def f(x) -> None:
603603
match x:
604604
case Position(1, 2, 3):
605605
print("matched")
@@ -641,7 +641,7 @@ def f(x):
641641
r28 :: object
642642
r29 :: object[1]
643643
r30 :: object_ptr
644-
r31, r32 :: object
644+
r31 :: object
645645
L0:
646646
r0 = __main__.Position :: type
647647
r1 = PyObject_IsInstance(x, r0)
@@ -687,8 +687,8 @@ L4:
687687
goto L6
688688
L5:
689689
L6:
690-
r32 = box(None, 1)
691-
return r32
690+
return 1
691+
692692
[case testMatchClassPatternWithKeywordPatterns_python3_10]
693693
class Position:
694694
x: int
@@ -848,7 +848,7 @@ class C:
848848
a: int
849849
b: int
850850

851-
def f(x):
851+
def f(x) -> None:
852852
match x:
853853
case C(1, 2) as y:
854854
print("matched")
@@ -885,7 +885,7 @@ def f(x):
885885
r22 :: object
886886
r23 :: object[1]
887887
r24 :: object_ptr
888-
r25, r26 :: object
888+
r25 :: object
889889
L0:
890890
r0 = __main__.C :: type
891891
r1 = PyObject_IsInstance(x, r0)
@@ -925,8 +925,8 @@ L4:
925925
goto L6
926926
L5:
927927
L6:
928-
r26 = box(None, 1)
929-
return r26
928+
return 1
929+
930930
[case testMatchClassPatternPositionalCapture_python3_10]
931931
class C:
932932
__match_args__ = ("x",)
@@ -953,15 +953,14 @@ def f(x):
953953
r2 :: bit
954954
r3 :: bool
955955
r4 :: str
956-
r5 :: object
957-
r6, num :: int
958-
r7 :: str
959-
r8 :: object
960-
r9 :: str
961-
r10 :: object
962-
r11 :: object[1]
963-
r12 :: object_ptr
964-
r13, r14 :: object
956+
r5, num :: object
957+
r6 :: str
958+
r7 :: object
959+
r8 :: str
960+
r9 :: object
961+
r10 :: object[1]
962+
r11 :: object_ptr
963+
r12, r13 :: object
965964
L0:
966965
r0 = __main__.C :: type
967966
r1 = PyObject_IsInstance(x, r0)
@@ -971,22 +970,22 @@ L0:
971970
L1:
972971
r4 = 'x'
973972
r5 = CPyObject_GetAttr(x, r4)
974-
r6 = unbox(int, r5)
975-
num = r6
973+
num = r5
976974
L2:
977-
r7 = 'matched'
978-
r8 = builtins :: module
979-
r9 = 'print'
980-
r10 = CPyObject_GetAttr(r8, r9)
981-
r11 = [r7]
982-
r12 = load_address r11
983-
r13 = PyObject_Vectorcall(r10, r12, 1, 0)
984-
keep_alive r7
975+
r6 = 'matched'
976+
r7 = builtins :: module
977+
r8 = 'print'
978+
r9 = CPyObject_GetAttr(r7, r8)
979+
r10 = [r6]
980+
r11 = load_address r10
981+
r12 = PyObject_Vectorcall(r9, r11, 1, 0)
982+
keep_alive r6
985983
goto L4
986984
L3:
987985
L4:
988-
r14 = box(None, 1)
989-
return r14
986+
r13 = box(None, 1)
987+
return r13
988+
990989
[case testMatchMappingEmpty_python3_10]
991990
def f(x):
992991
match x:
@@ -1601,35 +1600,35 @@ def f(x):
16011600
def f(x):
16021601
x, r0 :: object
16031602
r1 :: bool
1604-
r2, y :: int
1605-
r3 :: str
1606-
r4 :: object
1607-
r5 :: str
1608-
r6 :: object
1609-
r7 :: object[1]
1610-
r8 :: object_ptr
1611-
r9, r10 :: object
1603+
y :: object
1604+
r2 :: str
1605+
r3 :: object
1606+
r4 :: str
1607+
r5 :: object
1608+
r6 :: object[1]
1609+
r7 :: object_ptr
1610+
r8, r9 :: object
16121611
L0:
16131612
r0 = load_address PyLong_Type
16141613
r1 = CPy_TypeCheck(x, r0)
16151614
if r1 goto L1 else goto L3 :: bool
16161615
L1:
1617-
r2 = unbox(int, x)
1618-
y = r2
1616+
y = x
16191617
L2:
1620-
r3 = 'matched'
1621-
r4 = builtins :: module
1622-
r5 = 'print'
1623-
r6 = CPyObject_GetAttr(r4, r5)
1624-
r7 = [r3]
1625-
r8 = load_address r7
1626-
r9 = PyObject_Vectorcall(r6, r8, 1, 0)
1627-
keep_alive r3
1618+
r2 = 'matched'
1619+
r3 = builtins :: module
1620+
r4 = 'print'
1621+
r5 = CPyObject_GetAttr(r3, r4)
1622+
r6 = [r2]
1623+
r7 = load_address r6
1624+
r8 = PyObject_Vectorcall(r5, r7, 1, 0)
1625+
keep_alive r2
16281626
goto L4
16291627
L3:
16301628
L4:
1631-
r10 = box(None, 1)
1632-
return r10
1629+
r9 = box(None, 1)
1630+
return r9
1631+
16331632
[case testMatchSequenceCaptureAll_python3_10]
16341633
def f(x):
16351634
match x:

test-data/unit/check-python310.test

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2990,3 +2990,19 @@ def foo(e: Literal[0, 1]) -> None:
29902990
...
29912991

29922992
defer = unknown_module.foo
2993+
2994+
[case testMatchErrorsIncorrectName]
2995+
class A:
2996+
pass
2997+
2998+
match 5:
2999+
case A.blah(): # E: "type[A]" has no attribute "blah"
3000+
pass
3001+
3002+
[case testMatchAllowsAnyClassArgsForAny]
3003+
match 5:
3004+
case BlahBlah(a, b): # E: Name "BlahBlah" is not defined
3005+
reveal_type(a) # N: Revealed type is "Any"
3006+
reveal_type(b) # N: Revealed type is "Any"
3007+
case BlahBlah(c=c): # E: Name "BlahBlah" is not defined
3008+
reveal_type(c) # N: Revealed type is "Any"

0 commit comments

Comments
 (0)