|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import ast |
| 4 | +import importlib.abc |
| 5 | +import importlib.machinery |
| 6 | +import importlib.util |
| 7 | +import logging |
| 8 | +import os |
| 9 | +import sys |
| 10 | +from typing import TYPE_CHECKING |
| 11 | + |
| 12 | +if TYPE_CHECKING: |
| 13 | + import types |
| 14 | + from collections.abc import Sequence |
| 15 | + |
| 16 | + |
| 17 | +class ImportTransformer(ast.NodeTransformer): |
| 18 | + def __init__(self, import_map: dict[str, str] | None = None) -> None: |
| 19 | + self.import_map = import_map or {} |
| 20 | + self.modified = False |
| 21 | + |
| 22 | + def visit_Import(self, node: ast.Import) -> ast.AST: # noqa: N802 |
| 23 | + new_names = [] |
| 24 | + for name in node.names: |
| 25 | + if name.name in self.import_map: |
| 26 | + self.modified = True |
| 27 | + new_name = self.import_map[name.name] |
| 28 | + new_names.append(ast.alias(name=new_name, asname=name.asname or name.name)) |
| 29 | + logging.getLogger(__name__).debug("Rewriting import: %s → %s", name.name, new_name) |
| 30 | + else: |
| 31 | + new_names.append(name) |
| 32 | + |
| 33 | + node.names = new_names |
| 34 | + return node |
| 35 | + |
| 36 | + def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST: # noqa: N802 |
| 37 | + if node.module in self.import_map: |
| 38 | + self.modified = True |
| 39 | + new_module = self.import_map[node.module] |
| 40 | + logging.getLogger(__name__).debug("Rewriting from import: %s → %s", node.module, new_module) |
| 41 | + node.module = new_module |
| 42 | + |
| 43 | + return node |
| 44 | + |
| 45 | + |
| 46 | +class ImportRewritingFinder(importlib.abc.MetaPathFinder): |
| 47 | + def __init__(self, import_map: dict[str, str] | None = None) -> None: |
| 48 | + self.import_map = import_map or {} |
| 49 | + self.processed_modules: set[str] = set() |
| 50 | + |
| 51 | + def find_spec( |
| 52 | + self, |
| 53 | + fullname: str, |
| 54 | + path: Sequence[str | bytes] | None, |
| 55 | + target: types.ModuleType | None = None, # noqa: ARG002 |
| 56 | + ) -> importlib.machinery.ModuleSpec | None: |
| 57 | + if fullname in self.processed_modules or fullname.startswith("_"): |
| 58 | + return None |
| 59 | + |
| 60 | + if path is None: |
| 61 | + path = sys.path |
| 62 | + |
| 63 | + for entry in path: |
| 64 | + if not isinstance(entry, str) or not os.path.isdir(entry): |
| 65 | + continue |
| 66 | + |
| 67 | + for suffix in importlib.machinery.SOURCE_SUFFIXES: |
| 68 | + filename = os.path.join(entry, fullname.split(".")[-1] + suffix) |
| 69 | + if not os.path.exists(filename): |
| 70 | + continue |
| 71 | + |
| 72 | + loader = ImportRewritingLoader(fullname, filename, self.import_map) |
| 73 | + spec = importlib.machinery.ModuleSpec(name=fullname, loader=loader, origin=filename, is_package=False) |
| 74 | + |
| 75 | + self.processed_modules.add(fullname) |
| 76 | + return spec |
| 77 | + |
| 78 | + return None |
| 79 | + |
| 80 | + |
| 81 | +class ImportRewritingLoader(importlib.abc.SourceLoader): |
| 82 | + def __init__(self, fullname: str, path: str, import_map: dict[str, str] | None = None): |
| 83 | + self.fullname = fullname |
| 84 | + self.path = path |
| 85 | + self.import_map = import_map or {} |
| 86 | + |
| 87 | + def get_filename(self, fullname: str) -> str: # noqa: ARG002 |
| 88 | + return self.path |
| 89 | + |
| 90 | + def get_data(self, path: str | bytes) -> bytes: |
| 91 | + with open(path, "rb") as f: |
| 92 | + return f.read() |
| 93 | + |
| 94 | + def exec_module(self, module: types.ModuleType) -> None: |
| 95 | + source_bytes = self.get_data(self.get_filename(self.fullname)) |
| 96 | + source = source_bytes.decode("utf-8") |
| 97 | + |
| 98 | + tree = ast.parse(source) |
| 99 | + transformer = ImportTransformer(self.import_map) |
| 100 | + transformed_tree = transformer.visit(tree) |
| 101 | + |
| 102 | + if transformer.modified: |
| 103 | + ast.fix_missing_locations(transformed_tree) |
| 104 | + code = compile(transformed_tree, self.get_filename(self.fullname), "exec") |
| 105 | + exec(code, module.__dict__) # noqa: S102 |
| 106 | + else: |
| 107 | + code = compile(source, self.get_filename(self.fullname), "exec") |
| 108 | + exec(code, module.__dict__) # noqa: S102 |
| 109 | + |
| 110 | + |
| 111 | +def install_import_rewriter( |
| 112 | + import_map: dict[str, str] | None = None, |
| 113 | +) -> ImportRewritingFinder: |
| 114 | + """Install the import rewriting hook with the specified mapping. |
| 115 | +
|
| 116 | + :param import_map: A dictionary mapping original import names to replacement names. |
| 117 | + For example: {'requests': 'my_requests'} |
| 118 | + :returns: The finder instance that was installed. |
| 119 | + """ |
| 120 | + import_map = import_map or {} |
| 121 | + finder = ImportRewritingFinder(import_map) |
| 122 | + sys.meta_path.insert(0, finder) |
| 123 | + return finder |
0 commit comments