11
11
import re
12
12
import sys
13
13
import traceback
14
+ import types
15
+ import typing
14
16
15
17
import isort
16
18
import black
17
19
20
+ import circuitpython_typing
21
+ import circuitpython_typing .socket
22
+
18
23
19
24
IMPORTS_IGNORE = frozenset (
20
25
{
21
- "int" ,
22
- "float" ,
26
+ "array" ,
23
27
"bool" ,
24
- "str" ,
28
+ "buffer" ,
29
+ "bytearray" ,
25
30
"bytes" ,
26
- "tuple" ,
27
- "list" ,
28
- "set" ,
29
31
"dict" ,
30
- "bytearray" ,
31
- "slice" ,
32
32
"file" ,
33
- "buffer" ,
33
+ "float" ,
34
+ "int" ,
35
+ "list" ,
34
36
"range" ,
35
- "array" ,
37
+ "set" ,
38
+ "slice" ,
39
+ "str" ,
36
40
"struct_time" ,
41
+ "tuple" ,
37
42
}
38
43
)
39
- IMPORTS_TYPING = frozenset (
40
- {
41
- "Any" ,
42
- "Dict" ,
43
- "Optional" ,
44
- "Union" ,
45
- "Tuple" ,
46
- "List" ,
47
- "Sequence" ,
48
- "NamedTuple" ,
49
- "Iterable" ,
50
- "Iterator" ,
51
- "Callable" ,
52
- "AnyStr" ,
53
- "overload" ,
54
- "Type" ,
55
- }
56
- )
57
- IMPORTS_TYPES = frozenset ({"TracebackType" })
58
- CPY_TYPING = frozenset (
59
- {"ReadableBuffer" , "WriteableBuffer" , "AudioSample" , "FrameBuffer" , "Alarm" }
60
- )
44
+
45
+ # Include all definitions in these type modules, minus some name conflicts.
46
+ AVAILABLE_TYPE_MODULE_IMPORTS = {
47
+ "types" : frozenset (types .__all__ ),
48
+ # Conflicts: countio.Counter, canio.Match
49
+ "typing" : frozenset (typing .__all__ ) - set (["Counter" , "Match" ]),
50
+ "circuitpython_typing" : frozenset (circuitpython_typing .__all__ ),
51
+ "circuitpython_typing.socket" : frozenset (circuitpython_typing .socket .__all__ ),
52
+ }
61
53
62
54
63
55
def is_typed (node , allow_any = False ):
@@ -116,9 +108,7 @@ def find_stub_issues(tree):
116
108
117
109
def extract_imports (tree ):
118
110
modules = set ()
119
- typing = set ()
120
- types = set ()
121
- cpy_typing = set ()
111
+ used_type_module_imports = {module : set () for module in AVAILABLE_TYPE_MODULE_IMPORTS .keys ()}
122
112
123
113
def collect_annotations (anno_tree ):
124
114
if anno_tree is None :
@@ -127,12 +117,9 @@ def collect_annotations(anno_tree):
127
117
if isinstance (node , ast .Name ):
128
118
if node .id in IMPORTS_IGNORE :
129
119
continue
130
- elif node .id in IMPORTS_TYPING :
131
- typing .add (node .id )
132
- elif node .id in IMPORTS_TYPES :
133
- types .add (node .id )
134
- elif node .id in CPY_TYPING :
135
- cpy_typing .add (node .id )
120
+ for module , imports in AVAILABLE_TYPE_MODULE_IMPORTS .items ():
121
+ if node .id in imports :
122
+ used_type_module_imports [module ].add (node .id )
136
123
elif isinstance (node , ast .Attribute ):
137
124
if isinstance (node .value , ast .Name ):
138
125
modules .add (node .value .id )
@@ -145,15 +132,12 @@ def collect_annotations(anno_tree):
145
132
elif isinstance (node , ast .FunctionDef ):
146
133
collect_annotations (node .returns )
147
134
for deco in node .decorator_list :
148
- if isinstance (deco , ast .Name ) and (deco .id in IMPORTS_TYPING ):
149
- typing .add (deco .id )
150
-
151
- return {
152
- "modules" : sorted (modules ),
153
- "typing" : sorted (typing ),
154
- "types" : sorted (types ),
155
- "cpy_typing" : sorted (cpy_typing ),
156
- }
135
+ if isinstance (deco , ast .Name ) and (
136
+ deco .id in AVAILABLE_TYPE_MODULE_IMPORTS ["typing" ]
137
+ ):
138
+ used_type_module_imports ["typing" ].add (deco .id )
139
+
140
+ return (modules , used_type_module_imports )
157
141
158
142
159
143
def find_references (tree ):
@@ -237,15 +221,11 @@ def convert_folder(top_level, stub_directory):
237
221
ok += 1
238
222
239
223
# Add import statements
240
- imports = extract_imports (tree )
224
+ imports , type_imports = extract_imports (tree )
241
225
import_lines = ["from __future__ import annotations" ]
242
- if imports ["types" ]:
243
- import_lines .append ("from types import " + ", " .join (imports ["types" ]))
244
- if imports ["typing" ]:
245
- import_lines .append ("from typing import " + ", " .join (imports ["typing" ]))
246
- if imports ["cpy_typing" ]:
247
- import_lines .append ("from circuitpython_typing import " + ", " .join (imports ["cpy_typing" ]))
248
- import_lines .extend (f"import { m } " for m in imports ["modules" ])
226
+ for type_module , used_types in type_imports .items ():
227
+ import_lines .append (f"from { type_module } import { ', ' .join (sorted (used_types ))} " )
228
+ import_lines .extend (f"import { m } " for m in sorted (imports ))
249
229
import_body = "\n " .join (import_lines )
250
230
m = re .match (r'(\s*""".*?""")' , stub_contents , flags = re .DOTALL )
251
231
if m :
0 commit comments