Skip to content

Commit 0917d24

Browse files
committed
Remove attribute completion, never import modules
1 parent 48ee6ad commit 0917d24

File tree

2 files changed

+97
-101
lines changed

2 files changed

+97
-101
lines changed

Lib/_pyrepl/readline.py

Lines changed: 77 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from io import StringIO
3636
from contextlib import contextmanager
3737
from dataclasses import dataclass, field
38+
from itertools import chain
3839
from tokenize import TokenInfo
3940

4041
import os
@@ -635,71 +636,69 @@ def __init__(self, namespace: Mapping[str, Any] | None = None):
635636

636637
def get_completions(self, line: str) -> list[str]:
637638
"""Return the next possible import completions for 'line'."""
638-
639-
parser = ImportParser(line)
640-
if not (result := parser.parse()):
639+
result = ImportParser(line).parse()
640+
if not result:
641641
return []
642642
return self.complete(*result)
643643

644644
def complete(self, from_name: str | None, name: str | None) -> list[str]:
645-
# import x.y.z<tab>
646645
if from_name is None:
647-
if not name:
648-
return []
649-
return self.complete_import(name)
646+
# import x.y.z<tab>
647+
path, prefix = self.get_path_and_prefix(name)
648+
modules = self.find_modules(path, prefix)
649+
return [self.format_completion(path, module) for module in modules]
650650

651-
# from x.y.z<tab>
652651
if name is None:
653-
if not from_name:
654-
return []
655-
return self.complete_import(from_name)
652+
# from x.y.z<tab>
653+
path, prefix = self.get_path_and_prefix(from_name)
654+
modules = self.find_modules(path, prefix)
655+
return [self.format_completion(path, module) for module in modules]
656656

657657
# from x.y import z<tab>
658-
if not (module := self.import_module(from_name)):
659-
return []
660-
661-
submodules = self.filter_submodules(module, name)
662-
attributes = self.filter_attributes(module, name)
663-
return list(set(submodules + attributes))
664-
665-
def complete_import(self, name: str) -> list[str]:
666-
is_relative = name.startswith('.')
667-
path, prefix = self.get_path_and_prefix(name)
668-
669-
if not is_relative and not path:
670-
return [name for name in self.global_cache if name.startswith(prefix)]
671-
672-
if not (module := self.import_module(path)):
673-
return []
674-
675-
submodules = self.filter_submodules(module, prefix)
676-
if not is_relative:
677-
return [f'{path}.{name}' for name in submodules]
678-
return [f'.{name}' for name in submodules]
679-
680-
def import_module(self, path: str) -> ModuleType | None:
681-
package = self.namespace.get('__package__')
682-
is_relative = path.startswith('.')
683-
if is_relative and not package:
684-
return None
685-
try:
686-
module = importlib.import_module(
687-
path,
688-
package=package if is_relative else None)
689-
except ImportError:
690-
return None
691-
return module
692-
693-
def filter_submodules(self, module: ModuleType, prefix: str) -> list[str]:
694-
if not hasattr(module, '__path__'):
695-
return []
696-
return [name for _, name, _ in pkgutil.iter_modules(module.__path__)
697-
if name.startswith(prefix)]
658+
return self.find_modules(from_name, name)
659+
660+
def find_modules(self, path: str, prefix: str) -> list[str]:
661+
"""Find all modules under 'path' that start with 'prefix'."""
662+
if not path:
663+
# Top-level import (e.g. `import foo<tab>`` or `from foo<tab>`)`
664+
return [name for _, name, _ in self.global_cache
665+
if name.startswith(prefix)]
666+
667+
if path.startswith('.'):
668+
# Convert relative path to absolute path
669+
package = self.namespace.get('__package__')
670+
path = self.resolve_relative_name(path, package)
671+
if path is None:
672+
return []
698673

699-
def filter_attributes(self, module: ModuleType, prefix: str) -> list[str]:
700-
return [attr for attr in module.__dict__ if attr.startswith(prefix)]
674+
modules = self.global_cache
675+
for segment in path.split('.'):
676+
modules = [mod_info for mod_info in modules
677+
if mod_info.ispkg and mod_info.name == segment]
678+
modules = self.iter_submodules(modules)
679+
return [module.name for module in modules
680+
if module.name.startswith(prefix)]
681+
682+
def iter_submodules(self, parent_modules):
683+
"""Iterate over all submodules of the given parent modules."""
684+
specs = [info.module_finder.find_spec(info.name)
685+
for info in parent_modules if info.ispkg]
686+
search_locations = set(chain.from_iterable(
687+
getattr(spec, 'submodule_search_locations', [])
688+
for spec in specs if spec
689+
))
690+
return pkgutil.iter_modules(search_locations)
701691

702692
def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]:
693+
"""
694+
Split a dotted name into an import path and a
695+
final prefix that is to be completed.
696+
697+
Examples:
698+
'foo.bar' -> 'foo', 'bar'
699+
'foo.' -> 'foo', ''
700+
'.foo' -> '.', 'foo'
701+
"""
703702
if '.' not in dotted_name:
704703
return '', dotted_name
705704
if dotted_name.startswith('.'):
@@ -712,12 +711,35 @@ def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]:
712711
path, prefix = dotted_name.rsplit('.', 1)
713712
return path, prefix
714713

714+
def format_completion(self, path: str, module: str) -> str:
715+
if path == '' or path.endswith('.'):
716+
return f'{path}{module}'
717+
return f'{path}.{module}'
718+
719+
def resolve_relative_name(self, name, package):
720+
"""Resolve a relative module name to an absolute name.
721+
722+
Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo'
723+
"""
724+
# taken from importlib._bootstrap
725+
level = 0
726+
for character in name:
727+
if character != '.':
728+
break
729+
level += 1
730+
bits = package.rsplit('.', level - 1)
731+
if len(bits) < level:
732+
return None
733+
base = bits[0]
734+
name = name[level:]
735+
return f'{base}.{name}' if name else base
736+
715737
@property
716738
def global_cache(self) -> list[str]:
739+
"""Global module cache"""
717740
if not self._global_cache or self._curr_sys_path != sys.path:
718741
self._curr_sys_path = sys.path[:]
719-
self._global_cache = [
720-
name for _, name, _ in pkgutil.iter_modules()]
742+
self._global_cache = list(pkgutil.iter_modules())
721743
return self._global_cache
722744

723745

Lib/test/test_pyrepl/test_pyrepl.py

Lines changed: 20 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -904,62 +904,33 @@ def prepare_reader(self, events, namespace):
904904
reader = ReadlineAlikeReader(console=console, config=config)
905905
return reader
906906

907-
def test_import(self):
908-
cases = [
907+
def test_import_completions(self):
908+
cases = (
909909
("import path\t\n", "import pathlib"),
910910
("import importlib.\t\tres\t\n", "import importlib.resources"),
911911
("import importlib.resources.\t\ta\t\n", "import importlib.resources.abc"),
912912
("import foo, impo\t\n", "import foo, importlib"),
913913
("import foo as bar, impo\t\n", "import foo as bar, importlib"),
914-
]
915-
916-
for code, expected in cases:
917-
with self.subTest(code=code):
918-
events = code_to_events(code)
919-
reader = self.prepare_reader(events, namespace={})
920-
output = reader.readline()
921-
self.assertEqual(output, expected)
922-
923-
def test_from_import(self):
924-
cases = [
925914
("from impo\t\n", "from importlib"),
926915
("from importlib.res\t\n", "from importlib.resources"),
927916
("from importlib.\t\tres\t\n", "from importlib.resources"),
928917
("from importlib.resources.ab\t\n", "from importlib.resources.abc"),
929-
]
930-
931-
for code, expected in cases:
932-
with self.subTest(code=code):
933-
events = code_to_events(code)
934-
reader = self.prepare_reader(events, namespace={})
935-
output = reader.readline()
936-
self.assertEqual(output, expected)
937-
938-
def test_from_import_attributes(self):
939-
cases = [
940918
("from importlib import mac\t\n", "from importlib import machinery"),
941919
("from importlib import res\t\n", "from importlib import resources"),
942-
("from importlib import invalidate_\t\n", "from importlib import invalidate_caches"),
943-
("from importlib import (inval\t\n", "from importlib import (invalidate_caches"),
944-
("from importlib import foo, invalidate_\t\n", "from importlib import foo, invalidate_caches"),
945-
("from importlib import (foo, invalidate_\t\n", "from importlib import (foo, invalidate_caches"),
946-
("from importlib import foo as bar, invalidate_\t\n", "from importlib import foo as bar, invalidate_caches"),
947-
("from importlib import (foo as bar, invalidate_\t\n", "from importlib import (foo as bar, invalidate_caches"),
948-
]
949-
920+
("from importlib.res\t import a\t\n", "from importlib.resources import abc"),
921+
)
950922
for code, expected in cases:
951923
with self.subTest(code=code):
952924
events = code_to_events(code)
953925
reader = self.prepare_reader(events, namespace={})
954926
output = reader.readline()
955927
self.assertEqual(output, expected)
956928

957-
def test_relative_from_import(self):
958-
cases = [
929+
def test_relative_import_completions(self):
930+
cases = (
959931
("from .readl\t\n", "from .readline"),
960-
("from .readline import Mod\t\n", "from .readline import ModuleCompleter"),
961-
]
962-
932+
("from . import readl\t\n", "from . import readline"),
933+
)
963934
for code, expected in cases:
964935
with self.subTest(code=code):
965936
events = code_to_events(code)
@@ -968,7 +939,7 @@ def test_relative_from_import(self):
968939
self.assertEqual(output, expected)
969940

970941
def test_get_path_and_prefix(self):
971-
cases = [
942+
cases = (
972943
('', ('', '')),
973944
('.', ('.', '')),
974945
('..', ('..', '')),
@@ -983,15 +954,14 @@ def test_get_path_and_prefix(self):
983954
('foo.bar', ('foo', 'bar')),
984955
('foo.bar.', ('foo.bar', '')),
985956
('foo.bar.baz', ('foo.bar', 'baz')),
986-
]
987-
957+
)
988958
completer = ModuleCompleter()
989959
for name, expected in cases:
990960
with self.subTest(name=name):
991961
self.assertEqual(completer.get_path_and_prefix(name), expected)
992962

993963
def test_parse(self):
994-
cases = [
964+
cases = (
995965
('import ', (None, '')),
996966
('import foo', (None, 'foo')),
997967
('import foo,', (None, '')),
@@ -1027,8 +997,7 @@ def test_parse(self):
1027997
('from foo import (a, ', ('foo', '')),
1028998
('from foo import (a, c', ('foo', 'c')),
1029999
('from foo import (a as b, c', ('foo', 'c')),
1030-
]
1031-
1000+
)
10321001
for code, parsed in cases:
10331002
parser = ImportParser(code)
10341003
actual = parser.parse()
@@ -1039,9 +1008,12 @@ def test_parse(self):
10391008
code = f'import xyz\n{code}'
10401009
with self.subTest(code=code):
10411010
self.assertEqual(actual, parsed)
1011+
code = f'import xyz;{code}'
1012+
with self.subTest(code=code):
1013+
self.assertEqual(actual, parsed)
10421014

10431015
def test_parse_error(self):
1044-
cases = [
1016+
cases = (
10451017
'',
10461018
'import foo ',
10471019
'from foo ',
@@ -1060,6 +1032,9 @@ def test_parse_error(self):
10601032
'import a.b.c as',
10611033
'import (foo',
10621034
'import (',
1035+
'import .foo',
1036+
'import ..foo',
1037+
'import .foo.bar',
10631038
'import foo; x = 1',
10641039
'import a.; x = 1',
10651040
'import a.b; x = 1',
@@ -1080,8 +1055,7 @@ def test_parse_error(self):
10801055
'from foo import import',
10811056
'from foo import from',
10821057
'from foo import as',
1083-
]
1084-
1058+
)
10851059
for code in cases:
10861060
parser = ImportParser(code)
10871061
actual = parser.parse()

0 commit comments

Comments
 (0)