Skip to content

Commit 248646a

Browse files
committed
Add initial mypy plugin for Record
1 parent c21ccd7 commit 248646a

File tree

6 files changed

+421
-3
lines changed

6 files changed

+421
-3
lines changed

asyncpg/mypy/__init__.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import typing
2+
import mypy.nodes
3+
import mypy.plugin
4+
import mypy.types
5+
6+
from . import common
7+
from . import hooks
8+
from . import utils
9+
10+
11+
class AsyncpgPlugin(mypy.plugin.Plugin):
12+
def get_method_hook(self, fullname: str) \
13+
-> typing.Optional[common.MethodHook]:
14+
class_name, _, method_name = fullname.rpartition('.')
15+
symbol = self.lookup_fully_qualified(class_name)
16+
17+
if symbol and isinstance(symbol.node, mypy.nodes.TypeInfo) and \
18+
utils.is_record(symbol.node):
19+
if method_name == '__getitem__':
20+
return hooks.record_getitem
21+
if method_name == 'get':
22+
return hooks.record_get
23+
24+
return None
25+
26+
def get_attribute_hook(self, fullname: str) \
27+
-> typing.Optional[common.AttributeHook]:
28+
class_name, _, _ = fullname.rpartition('.')
29+
symbol = self.lookup_fully_qualified(class_name)
30+
31+
if symbol and isinstance(symbol.node, mypy.nodes.TypeInfo) and \
32+
symbol.node.has_base(common.RECORD_NAME):
33+
return hooks.record_attribute
34+
35+
return None
36+
37+
def get_customize_class_mro_hook(self, fullname: str) \
38+
-> typing.Optional[common.ClassDefHook]:
39+
return hooks.mark_record
40+
41+
def get_base_class_hook(self, fullname: str) \
42+
-> typing.Optional[common.ClassDefHook]:
43+
if fullname == common.RECORD_NAME:
44+
return hooks.record_final_attributes
45+
return None
46+
47+
48+
def plugin(version: str) -> typing.Type[mypy.plugin.Plugin]:
49+
return AsyncpgPlugin

asyncpg/mypy/common.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import mypy.nodes
2+
import mypy.plugin
3+
import mypy.types
4+
import typing
5+
import typing_extensions
6+
7+
RECORD_NAME = 'asyncpg.protocol.protocol.Record' # type: typing_extensions.Final # noqa: E501
8+
MethodPairType = typing.Tuple[typing.List[mypy.nodes.Argument],
9+
mypy.types.Type]
10+
FieldPairType = typing.Tuple[str, mypy.types.Type]
11+
12+
13+
class MethodHook(typing_extensions.Protocol):
14+
def __call__(self, __ctx: mypy.plugin.MethodContext) -> mypy.types.Type:
15+
...
16+
17+
18+
class AttributeHook(typing_extensions.Protocol):
19+
def __call__(self, __ctx: mypy.plugin.AttributeContext) -> mypy.types.Type:
20+
...
21+
22+
23+
class ClassDefHook(typing_extensions.Protocol):
24+
def __call__(self, __ctx: mypy.plugin.ClassDefContext) -> None:
25+
...

asyncpg/mypy/hooks.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import mypy.nodes
2+
import mypy.plugin
3+
import mypy.types
4+
import typing
5+
6+
from . import common
7+
from . import utils
8+
9+
10+
def mark_record(ctx: mypy.plugin.ClassDefContext) -> None:
11+
if ctx.cls.info.fullname == common.RECORD_NAME:
12+
return
13+
14+
if ctx.cls.info.has_base(common.RECORD_NAME):
15+
utils.mark_record(ctx.cls.info)
16+
17+
18+
def record_final_attributes(ctx: mypy.plugin.ClassDefContext) -> None:
19+
if utils.is_record(ctx.cls.info):
20+
for name, value in ctx.cls.info.names.items():
21+
# set all properties as Final so they can't be set
22+
if isinstance(value.node, mypy.nodes.Var):
23+
value.node.is_final = True
24+
value.node.final_set_in_init = True
25+
value.node.final_unset_in_class = True
26+
27+
28+
def record_attribute(ctx: mypy.plugin.AttributeContext) \
29+
-> mypy.types.Type:
30+
if isinstance(ctx.type, mypy.types.Instance) and \
31+
utils.is_record(ctx.type.type):
32+
assert isinstance(ctx.context, mypy.nodes.MemberExpr)
33+
symbol = ctx.type.type.get(ctx.context.name)
34+
35+
if symbol is None:
36+
ctx.api.fail('"{}" has no attribute "{}"'
37+
.format(ctx.type.type.name, ctx.context.name),
38+
ctx.context)
39+
40+
return ctx.default_attr_type
41+
42+
43+
def record_getitem(ctx: mypy.plugin.MethodContext) \
44+
-> mypy.types.Type:
45+
if isinstance(ctx.type, mypy.types.Instance):
46+
arg = ctx.arg_types[0][0]
47+
48+
if arg is not None and isinstance(arg, mypy.types.Instance) and \
49+
isinstance(arg.last_known_value, mypy.types.LiteralType):
50+
value = arg.last_known_value.value
51+
names = utils.get_record_field_names(ctx.type.type.defn)
52+
name = None # type: typing.Optional[str]
53+
54+
if isinstance(value, int) and value < len(names):
55+
name = names[value]
56+
elif isinstance(value, str) and value in names:
57+
name = value
58+
59+
if name is None:
60+
ctx.api.fail('Unexpected key "{}" for record "{}'
61+
.format(value, ctx.type.type.name),
62+
ctx.context)
63+
else:
64+
node = ctx.type.type.get(name)
65+
66+
if node is not None and node.type is not None:
67+
return node.type
68+
69+
return ctx.default_return_type
70+
71+
72+
def record_get(ctx: mypy.plugin.MethodContext) \
73+
-> mypy.types.Type:
74+
if ctx.arg_names[0][0] is not None or \
75+
len(ctx.arg_names) > 1 and \
76+
ctx.arg_names[1][0] is not None:
77+
ctx.api.fail('get() takes no keyword arguments', ctx.context)
78+
elif isinstance(ctx.type, mypy.types.Instance):
79+
arg = utils.get_argument_type_by_name(ctx, 'key')
80+
default_arg = utils.get_argument_type_by_name(ctx, 'default')
81+
82+
if arg and isinstance(arg, mypy.types.Instance) and \
83+
isinstance(arg.last_known_value, mypy.types.LiteralType):
84+
value = arg.last_known_value.value
85+
names = utils.get_record_field_names(ctx.type.type.defn)
86+
name = None # type: typing.Optional[str]
87+
88+
if isinstance(value, str) and value in names:
89+
name = value
90+
91+
if name is None:
92+
ctx.api.fail('Unexpected key "{}" for record "{}'
93+
.format(value, ctx.type.type.name),
94+
ctx.context)
95+
else:
96+
node = ctx.type.type.get(name)
97+
98+
assert node is not None
99+
100+
if node.type is not None:
101+
if default_arg is not None:
102+
return mypy.types.UnionType([node.type, default_arg],
103+
ctx.context.line,
104+
ctx.context.column)
105+
else:
106+
return node.type
107+
108+
return ctx.default_return_type

0 commit comments

Comments
 (0)