
Commit 5c11124

Move to a separate file and cache ModuleCompleter
1 parent: 7a2fde0

File tree

2 files changed: +371, -361 lines changed


Lib/_pyrepl/_module_completer.py

Lines changed: 366 additions & 0 deletions
@@ -0,0 +1,366 @@
from __future__ import annotations

import pkgutil
import sys
import tokenize
from io import StringIO
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import chain
from tokenize import TokenInfo

TYPE_CHECKING = False

if TYPE_CHECKING:
    from typing import Any, Iterable, Iterator, Mapping


class ModuleCompleter:
    """A completer for Python import statements.

    Examples:
        - import <tab>
        - import foo<tab>
        - import foo.<tab>
        - import foo as bar, baz<tab>

        - from <tab>
        - from foo<tab>
        - from foo import <tab>
        - from foo import bar<tab>
        - from foo import (bar as baz, qux<tab>
    """

    def __init__(self, namespace: Mapping[str, Any] | None = None) -> None:
        self.namespace = namespace or {}
        self._global_cache: list[pkgutil.ModuleInfo] = []
        self._curr_sys_path: list[str] = sys.path[:]

    def get_completions(self, line: str) -> list[str]:
        """Return the next possible import completions for 'line'."""
        result = ImportParser(line).parse()
        if not result:
            return []
        return self.complete(*result)

    def complete(self, from_name: str | None, name: str | None) -> list[str]:
        if from_name is None:
            # import x.y.z<tab>
            assert name is not None
            path, prefix = self.get_path_and_prefix(name)
            modules = self.find_modules(path, prefix)
            return [self.format_completion(path, module) for module in modules]

        if name is None:
            # from x.y.z<tab>
            path, prefix = self.get_path_and_prefix(from_name)
            modules = self.find_modules(path, prefix)
            return [self.format_completion(path, module) for module in modules]

        # from x.y import z<tab>
        return self.find_modules(from_name, name)

    def find_modules(self, path: str, prefix: str) -> list[str]:
        """Find all modules under 'path' that start with 'prefix'."""
        modules = self._find_modules(path, prefix)
        # Filter out invalid module names
        # (for example those containing dashes that cannot be imported with 'import')
        return [mod for mod in modules if mod.isidentifier()]

    def _find_modules(self, path: str, prefix: str) -> list[str]:
        if not path:
            # Top-level import (e.g. `import foo<tab>` or `from foo<tab>`)
            return [name for _, name, _ in self.global_cache
                    if name.startswith(prefix)]

        if path.startswith('.'):
            # Convert relative path to absolute path
            package = self.namespace.get('__package__', '')
            path = self.resolve_relative_name(path, package)  # type: ignore[assignment]
            if path is None:
                return []

        modules: Iterable[pkgutil.ModuleInfo] = self.global_cache
        for segment in path.split('.'):
            modules = [mod_info for mod_info in modules
                       if mod_info.ispkg and mod_info.name == segment]
            modules = self.iter_submodules(modules)
        return [module.name for module in modules
                if module.name.startswith(prefix)]

    def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]:
        """Iterate over all submodules of the given parent modules."""
        specs = [info.module_finder.find_spec(info.name, None)
                 for info in parent_modules if info.ispkg]
        search_locations = set(chain.from_iterable(
            getattr(spec, 'submodule_search_locations', [])
            for spec in specs if spec
        ))
        return pkgutil.iter_modules(search_locations)

    def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]:
        """
        Split a dotted name into an import path and a
        final prefix that is to be completed.

        Examples:
            'foo.bar' -> 'foo', 'bar'
            'foo.' -> 'foo', ''
            '.foo' -> '.', 'foo'
        """
        if '.' not in dotted_name:
            return '', dotted_name
        if dotted_name.startswith('.'):
            stripped = dotted_name.lstrip('.')
            dots = '.' * (len(dotted_name) - len(stripped))
            if '.' not in stripped:
                return dots, stripped
            path, prefix = stripped.rsplit('.', 1)
            return dots + path, prefix
        path, prefix = dotted_name.rsplit('.', 1)
        return path, prefix

    def format_completion(self, path: str, module: str) -> str:
        if path == '' or path.endswith('.'):
            return f'{path}{module}'
        return f'{path}.{module}'

    def resolve_relative_name(self, name: str, package: str) -> str | None:
        """Resolve a relative module name to an absolute name.

        Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo'
        """
        # taken from importlib._bootstrap
        level = 0
        for character in name:
            if character != '.':
                break
            level += 1
        bits = package.rsplit('.', level - 1)
        if len(bits) < level:
            return None
        base = bits[0]
        name = name[level:]
        return f'{base}.{name}' if name else base

    @property
    def global_cache(self) -> list[pkgutil.ModuleInfo]:
        """Global module cache"""
        if not self._global_cache or self._curr_sys_path != sys.path:
            self._curr_sys_path = sys.path[:]
            # print('getting packages')
            self._global_cache = list(pkgutil.iter_modules())
        return self._global_cache


class ImportParser:
    """
    Parses incomplete import statements that are
    suitable for autocomplete suggestions.

    Examples:
        - import foo          -> Result(from_name=None, name='foo')
        - import foo.         -> Result(from_name=None, name='foo.')
        - from foo            -> Result(from_name='foo', name=None)
        - from foo import bar -> Result(from_name='foo', name='bar')
        - from .foo import (  -> Result(from_name='.foo', name='')

    Note that the parser works in reverse order, starting from the
    last token in the input string. This makes the parser more robust
    when parsing multiple statements.
    """
    _ignored_tokens = {
        tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT,
        tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER
    }
    _keywords = {'import', 'from', 'as'}

    def __init__(self, code: str) -> None:
        self.code = code
        tokens = []
        try:
            for t in tokenize.generate_tokens(StringIO(code).readline):
                if t.type not in self._ignored_tokens:
                    tokens.append(t)
        except tokenize.TokenError as e:
            if 'unexpected EOF' not in str(e):
                # unexpected EOF is fine, since we're parsing an
                # incomplete statement, but other errors are not
                # because we may not have all the tokens so it's
                # safer to bail out
                tokens = []
        except SyntaxError:
            tokens = []
        self.tokens = TokenQueue(tokens[::-1])

    def parse(self) -> tuple[str | None, str | None] | None:
        if not (res := self._parse()):
            return None
        return res.from_name, res.name

    def _parse(self) -> Result | None:
        with self.tokens.save_state():
            return self.parse_from_import()
        with self.tokens.save_state():
            return self.parse_import()

    def parse_import(self) -> Result:
        if self.code.rstrip().endswith('import') and self.code.endswith(' '):
            return Result(name='')
        if self.tokens.peek_string(','):
            name = ''
        else:
            if self.code.endswith(' '):
                raise ParseError('parse_import')
            name = self.parse_dotted_name()
            if name.startswith('.'):
                raise ParseError('parse_import')
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_dotted_as_name()
        if self.tokens.peek_string('import'):
            return Result(name=name)
        raise ParseError('parse_import')

    def parse_from_import(self) -> Result:
        if self.code.rstrip().endswith('import') and self.code.endswith(' '):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self.code.rstrip().endswith('from') and self.code.endswith(' '):
            return Result(from_name='')
        if self.tokens.peek_string('(') or self.tokens.peek_string(','):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self.code.endswith(' '):
            raise ParseError('parse_from_import')
        name = self.parse_dotted_name()
        if '.' in name:
            self.tokens.pop_string('from')
            return Result(from_name=name)
        if self.tokens.peek_string('from'):
            return Result(from_name=name)
        from_name = self.parse_empty_from_import()
        return Result(from_name=from_name, name=name)

    def parse_empty_from_import(self) -> str:
        if self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_names()
        if self.tokens.peek_string('('):
            self.tokens.pop()
        self.tokens.pop_string('import')
        return self.parse_from()

    def parse_from(self) -> str:
        from_name = self.parse_dotted_name()
        self.tokens.pop_string('from')
        return from_name

    def parse_dotted_as_name(self) -> str:
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
        with self.tokens.save_state():
            return self.parse_dotted_name()

    def parse_dotted_name(self) -> str:
        name = []
        if self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        if (self.tokens.peek_name()
            and (tok := self.tokens.peek())
            and tok.string not in self._keywords):
            name.append(self.tokens.pop_name())
        if not name:
            raise ParseError('parse_dotted_name')
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
            if (self.tokens.peek_name()
                and (tok := self.tokens.peek())
                and tok.string not in self._keywords):
                name.append(self.tokens.pop_name())
            else:
                break

        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        return ''.join(name[::-1])

    def parse_as_names(self) -> None:
        self.parse_as_name()
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_name()

    def parse_as_name(self) -> None:
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
            self.tokens.pop_name()


class ParseError(Exception):
    pass


@dataclass(frozen=True)
class Result:
    from_name: str | None = None
    name: str | None = None


class TokenQueue:
    """Provides helper functions for working with a sequence of tokens."""

    def __init__(self, tokens: list[TokenInfo]) -> None:
        self.tokens: list[TokenInfo] = tokens
        self.index: int = 0
        self.stack: list[int] = []

    @contextmanager
    def save_state(self) -> Any:
        try:
            self.stack.append(self.index)
            yield
        except ParseError:
            self.index = self.stack.pop()
        else:
            self.stack.pop()

    def __bool__(self) -> bool:
        return self.index < len(self.tokens)

    def peek(self) -> TokenInfo | None:
        if not self:
            return None
        return self.tokens[self.index]

    def peek_name(self) -> bool:
        if not (tok := self.peek()):
            return False
        return tok.type == tokenize.NAME

    def pop_name(self) -> str:
        tok = self.pop()
        if tok.type != tokenize.NAME:
            raise ParseError('pop_name')
        return tok.string

    def peek_string(self, string: str) -> bool:
        if not (tok := self.peek()):
            return False
        return tok.string == string

    def pop_string(self, string: str) -> str:
        tok = self.pop()
        if tok.string != string:
            raise ParseError('pop_string')
        return tok.string

    def pop(self) -> TokenInfo:
        if not self:
            raise ParseError('pop')
        tok = self.tokens[self.index]
        self.index += 1
        return tok
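
For reference, a minimal usage sketch (not part of this commit) of the completer and parser added above, assuming the file is importable as `_pyrepl._module_completer` on this branch; the outputs in the comments are illustrative and depend on the installed modules:

    # Hedged sketch: exercises ModuleCompleter and ImportParser directly.
    from _pyrepl._module_completer import ImportParser, ModuleCompleter

    completer = ModuleCompleter()

    # Top-level completion: names from pkgutil.iter_modules() starting with 'csv'.
    print(completer.get_completions('import csv'))                # e.g. ['csv']

    # Submodule completion after a trailing dot.
    print(completer.get_completions('import collections.'))       # e.g. ['collections.abc']

    # 'from <pkg> import <tab>': importable submodules of the package.
    print(completer.get_completions('from collections import '))  # e.g. ['abc']

    # The parser alone yields (from_name, name) for an incomplete statement.
    print(ImportParser('from foo.bar import baz').parse())        # ('foo.bar', 'baz')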

0 commit comments
