Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 23 additions & 18 deletions mypyc/irbuild/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
ValuePattern,
)
from mypy.traverser import TraverserVisitor
from mypy.types import Instance, TupleType, get_proper_type
from mypy.types import Instance, LiteralType, TupleType, get_proper_type
from mypyc.ir.ops import BasicBlock, Value
from mypyc.ir.rtypes import object_rprimitive
from mypyc.irbuild.builder import IRBuilder
Expand Down Expand Up @@ -152,23 +152,7 @@ def visit_class_pattern(self, pattern: ClassPattern) -> None:

node = pattern.class_ref.node
assert isinstance(node, TypeInfo)

ty = node.names.get("__match_args__")
assert ty

match_args_type = get_proper_type(ty.type)
assert isinstance(match_args_type, TupleType)

match_args: list[str] = []

for item in match_args_type.items:
proper_item = get_proper_type(item)
assert isinstance(proper_item, Instance) and proper_item.last_known_value

match_arg = proper_item.last_known_value.value
assert isinstance(match_arg, str)

match_args.append(match_arg)
match_args = extract_dunder_match_args_names(node)

for i, expr in enumerate(pattern.positionals):
self.builder.activate_block(self.code_block)
Expand Down Expand Up @@ -355,3 +339,24 @@ def prep_sequence_pattern(
patterns.append(pattern)

return star_index, capture, patterns


def extract_dunder_match_args_names(info: TypeInfo) -> list[str]:
ty = info.names.get("__match_args__")
assert ty
match_args_type = get_proper_type(ty.type)
assert isinstance(match_args_type, TupleType)

match_args: list[str] = []
for item in match_args_type.items:
proper_item = get_proper_type(item)

match_arg = None
if isinstance(proper_item, Instance) and proper_item.last_known_value:
match_arg = proper_item.last_known_value.value
elif isinstance(proper_item, LiteralType):
match_arg = proper_item.value
assert isinstance(match_arg, str), f"Unrecognized __match_args__ item: {item}"

match_args.append(match_arg)
return match_args
78 changes: 78 additions & 0 deletions mypyc/test-data/irbuild-match.test
Original file line number Diff line number Diff line change
Expand Up @@ -1727,3 +1727,81 @@ L4:
L5:
L6:
unreachable

[case testMatchLiteralMatchArgs_python3_10]
from typing_extensions import Literal

class Foo:
__match_args__: tuple[Literal["foo"]] = ("foo",)
foo: str

def f(x: Foo) -> None:
match x:
case Foo(foo):
print("foo")
case _:
assert False, "Unreachable"
[out]
def Foo.__mypyc_defaults_setup(__mypyc_self__):
__mypyc_self__ :: __main__.Foo
r0 :: str
r1 :: tuple[str]
L0:
r0 = 'foo'
r1 = (r0)
__mypyc_self__.__match_args__ = r1
return 1
def f(x):
x :: __main__.Foo
r0 :: object
r1 :: i32
r2 :: bit
r3 :: bool
r4 :: str
r5 :: object
r6, foo, r7 :: str
r8 :: object
r9 :: str
r10 :: object
r11 :: object[1]
r12 :: object_ptr
r13, r14 :: object
r15 :: i32
r16 :: bit
r17, r18 :: bool
L0:
r0 = __main__.Foo :: type
r1 = PyObject_IsInstance(x, r0)
r2 = r1 >= 0 :: signed
r3 = truncate r1: i32 to builtins.bool
if r3 goto L1 else goto L3 :: bool
L1:
r4 = 'foo'
r5 = CPyObject_GetAttr(x, r4)
r6 = cast(str, r5)
foo = r6
L2:
r7 = 'foo'
r8 = builtins :: module
r9 = 'print'
r10 = CPyObject_GetAttr(r8, r9)
r11 = [r7]
r12 = load_address r11
r13 = PyObject_Vectorcall(r10, r12, 1, 0)
keep_alive r7
goto L8
L3:
L4:
r14 = box(bool, 0)
r15 = PyObject_IsTrue(r14)
r16 = r15 >= 0 :: signed
r17 = truncate r15: i32 to builtins.bool
if r17 goto L6 else goto L5 :: bool
L5:
r18 = raise AssertionError('Unreachable')
unreachable
L6:
goto L8
L7:
L8:
return 1