|
1 | 1 | import argparse |
2 | 2 | import re |
3 | 3 | import sys |
4 | | -from typing import Any, Iterator, List, Union |
| 4 | +from ast import Import, ImportFrom |
| 5 | +from typing import Any, Dict, Iterator, List, Union |
| 6 | + |
| 7 | +from .ast_data import python_ast_objects_of_type |
5 | 8 |
|
6 | 9 |
|
7 | 10 | # @decorators.map_firstp_arg |
@@ -359,3 +362,44 @@ def python_type_name(python_type: type) -> str: |
359 | 362 | def python_object_type_to_word(python_object: Any) -> str: |
360 | 363 | """Convert the given python type to a string.""" |
361 | 364 | return python_type_name(type(python_object)) |
| 365 | + |
| 366 | + |
| 367 | +def _get_importfrom_module_name(node: ImportFrom) -> str: |
| 368 | + """Extract the module name from an ast.ImportFrom node. |
| 369 | +
|
| 370 | + The module name on the ast.ImportFrom node can be None for relative imports |
| 371 | + In this case, this function will return the name as the dots from the import statement. |
| 372 | + A few examples: |
| 373 | + "from requests import get" -> "requests" |
| 374 | + "from . import *" -> "." |
| 375 | + "from .. import *" -> ".." |
| 376 | + "from .foo import bar" -> "foo" |
| 377 | + """ |
| 378 | + if node.module is None: |
| 379 | + module_name = "." * node.level |
| 380 | + else: |
| 381 | + module_name = node.module |
| 382 | + |
| 383 | + return module_name |
| 384 | + |
| 385 | + |
| 386 | +def python_package_imports(code: str) -> Dict[str, List[str]]: |
| 387 | + """Return a dictionary containing the names of all imported modules.""" |
| 388 | + # Start with the Import nodes. |
| 389 | + # These will always have an empty list of submodules |
| 390 | + # so we can just overwrite them without losing any data |
| 391 | + modules = dict() |
| 392 | + nodes = python_ast_objects_of_type(code, Import) |
| 393 | + for node in nodes: |
| 394 | + for alias in node.names: |
| 395 | + modules[alias.name] = [] |
| 396 | + |
| 397 | + # Now for the ImportFrom nodes |
| 398 | + importfrom_nodes = python_ast_objects_of_type(code, ImportFrom) |
| 399 | + for node in importfrom_nodes: |
| 400 | + module_name = _get_importfrom_module_name(node) |
| 401 | + |
| 402 | + for alias in node.names: |
| 403 | + modules.setdefault(module_name, []).append(alias.name) |
| 404 | + |
| 405 | + return modules |
0 commit comments