Skip to content

Commit 275ad4e

Browse files
committed
create import rewrite hook
1 parent 3963ea7 commit 275ad4e

File tree

6 files changed

+170
-3
lines changed

6 files changed

+170
-3
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ dist/
66
wheels/
77
*.egg-info
88

9+
# JJ VCS
10+
.jj
11+
912
# Virtual environments
1013
.venv
1114

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,6 @@ authors = [
1111
{ name = "Daniel Copley", email = "[email protected]" }
1212
]
1313
requires-python = ">=3.9"
14-
dependencies = []
14+
dependencies = [
15+
"astroid"
16+
]

src/import_rewriter/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1-
def main() -> None:
2-
print("Hello from import-rewrite!")
1+
from import_rewriter.rewrite import ImportRewritingFinder, install_import_rewriter
2+
3+
__all__ = [
4+
"install_import_rewriter",
5+
"ImportRewritingFinder",
6+
]

src/import_rewriter/py.typed

Whitespace-only changes.

src/import_rewriter/rewrite.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import importlib.abc
5+
import importlib.machinery
6+
import importlib.util
7+
import logging
8+
import os
9+
import sys
10+
from typing import TYPE_CHECKING
11+
12+
if TYPE_CHECKING:
13+
import types
14+
from collections.abc import Sequence
15+
16+
17+
class ImportTransformer(ast.NodeTransformer):
18+
def __init__(self, import_map: dict[str, str] | None = None) -> None:
19+
self.import_map = import_map or {}
20+
self.modified = False
21+
22+
def visit_Import(self, node: ast.Import) -> ast.AST: # noqa: N802
23+
new_names = []
24+
for name in node.names:
25+
if name.name in self.import_map:
26+
self.modified = True
27+
new_name = self.import_map[name.name]
28+
new_names.append(ast.alias(name=new_name, asname=name.asname or name.name))
29+
logging.getLogger(__name__).debug("Rewriting import: %s → %s", name.name, new_name)
30+
else:
31+
new_names.append(name)
32+
33+
node.names = new_names
34+
return node
35+
36+
def visit_ImportFrom(self, node: ast.ImportFrom) -> ast.AST: # noqa: N802
37+
if node.module in self.import_map:
38+
self.modified = True
39+
new_module = self.import_map[node.module]
40+
logging.getLogger(__name__).debug("Rewriting from import: %s → %s", node.module, new_module)
41+
node.module = new_module
42+
43+
return node
44+
45+
46+
class ImportRewritingFinder(importlib.abc.MetaPathFinder):
47+
def __init__(self, import_map: dict[str, str] | None = None) -> None:
48+
self.import_map = import_map or {}
49+
self.processed_modules: set[str] = set()
50+
51+
def find_spec(
52+
self,
53+
fullname: str,
54+
path: Sequence[str | bytes] | None,
55+
target: types.ModuleType | None = None, # noqa: ARG002
56+
) -> importlib.machinery.ModuleSpec | None:
57+
if fullname in self.processed_modules or fullname.startswith("_"):
58+
return None
59+
60+
if path is None:
61+
path = sys.path
62+
63+
for entry in path:
64+
if not isinstance(entry, str) or not os.path.isdir(entry):
65+
continue
66+
67+
for suffix in importlib.machinery.SOURCE_SUFFIXES:
68+
filename = os.path.join(entry, fullname.split(".")[-1] + suffix)
69+
if not os.path.exists(filename):
70+
continue
71+
72+
loader = ImportRewritingLoader(fullname, filename, self.import_map)
73+
spec = importlib.machinery.ModuleSpec(name=fullname, loader=loader, origin=filename, is_package=False)
74+
75+
self.processed_modules.add(fullname)
76+
return spec
77+
78+
return None
79+
80+
81+
class ImportRewritingLoader(importlib.abc.SourceLoader):
82+
def __init__(self, fullname: str, path: str, import_map: dict[str, str] | None = None):
83+
self.fullname = fullname
84+
self.path = path
85+
self.import_map = import_map or {}
86+
87+
def get_filename(self, fullname: str) -> str: # noqa: ARG002
88+
return self.path
89+
90+
def get_data(self, path: str | bytes) -> bytes:
91+
with open(path, "rb") as f:
92+
return f.read()
93+
94+
def exec_module(self, module: types.ModuleType) -> None:
95+
source_bytes = self.get_data(self.get_filename(self.fullname))
96+
source = source_bytes.decode("utf-8")
97+
98+
tree = ast.parse(source)
99+
transformer = ImportTransformer(self.import_map)
100+
transformed_tree = transformer.visit(tree)
101+
102+
if transformer.modified:
103+
ast.fix_missing_locations(transformed_tree)
104+
code = compile(transformed_tree, self.get_filename(self.fullname), "exec")
105+
exec(code, module.__dict__) # noqa: S102
106+
else:
107+
code = compile(source, self.get_filename(self.fullname), "exec")
108+
exec(code, module.__dict__) # noqa: S102
109+
110+
111+
def install_import_rewriter(
112+
import_map: dict[str, str] | None = None,
113+
) -> ImportRewritingFinder:
114+
"""Install the import rewriting hook with the specified mapping.
115+
116+
:param import_map: A dictionary mapping original import names to replacement names.
117+
For example: {'requests': 'my_requests'}
118+
:returns: The finder instance that was installed.
119+
"""
120+
import_map = import_map or {}
121+
finder = ImportRewritingFinder(import_map)
122+
sys.meta_path.insert(0, finder)
123+
return finder

uv.lock

Lines changed: 35 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)