Skip to content

Commit 0383c36

Browse files
committed
refactor: rename write_patched_source to persist_patched_source and update related usages
1 parent 03edeab commit 0383c36

File tree

3 files changed

+37
-40
lines changed

3 files changed

+37
-40
lines changed

src/awepatch/function.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from copy import deepcopy
99
from functools import partial
1010
from types import CodeType, TracebackType
11-
from typing import TYPE_CHECKING, Any, Self
11+
from typing import Any, Self
1212

1313
from awepatch.utils import (
1414
AWEPATCH_DEBUG,
@@ -22,9 +22,11 @@
2222
compile_idents,
2323
find_matched_node,
2424
load_stmts,
25-
write_patched_source,
25+
persist_patched_source,
2626
)
2727

28+
TYPE_CHECKING = False
29+
2830
if TYPE_CHECKING:
2931
from collections.abc import Callable, Sequence
3032

@@ -88,7 +90,7 @@ def load_function_code(
8890
source = ast.unparse(func)
8991

9092
if AWEPATCH_DEBUG:
91-
file_path, source = write_patched_source(
93+
file_path, source = persist_patched_source(
9294
source,
9395
name=func.name,
9496
type="function",
@@ -107,13 +109,13 @@ def load_function_code(
107109

108110

109111
def _get_function_def(
110-
func: CodeType, source: list[str]
112+
func: CodeType, slines: list[str]
111113
) -> ast.FunctionDef | ast.AsyncFunctionDef:
112114
"""Get the AST function definition from a code object.
113115
114116
Args:
115117
func: The code object
116-
source: The source code lines of the function
118+
slines: The source code lines of the function
117119
118120
Returns:
119121
The AST function definition
@@ -123,7 +125,7 @@ def _get_function_def(
123125
124126
"""
125127

126-
for node in ast.walk(ast.parse("".join(source))):
128+
for node in ast.walk(ast.parse("".join(slines))):
127129
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
128130
continue
129131
if node.name != func.co_name:
@@ -161,8 +163,8 @@ class _SingleFunctionPatcher:
161163
def __init__(self, func: Callable[..., Any]) -> None:
162164
self._func = func
163165
self._orig_code = func.__code__
164-
self._source, _ = inspect.findsource(func)
165-
self._func_def = _get_function_def(func.__code__, self._source)
166+
self._slines, _ = inspect.findsource(func)
167+
self._func_def = _get_function_def(func.__code__, self._slines)
166168
self._func_def.decorator_list.clear()
167169
self._patches: list[CompiledPatch] = []
168170

@@ -183,11 +185,9 @@ def add_patch(
183185
def apply(self) -> Callable[..., Any]:
184186
"""Apply the patches to the function."""
185187
func_def = deepcopy(self._func_def)
186-
compiled: CompiledPatches = defaultdict(
187-
lambda: defaultdict(lambda: defaultdict(list))
188-
)
188+
compiled: CompiledPatches = defaultdict(lambda: defaultdict(dict))
189189
for patch in self._patches:
190-
target = find_matched_node(func_def, self._source, patch.target)
190+
target = find_matched_node(func_def, self._slines, patch.target)
191191
if target is None:
192192
raise ValueError(f"Patch target {patch.target} not found")
193193
append_patch(

src/awepatch/module.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
compile_idents,
2020
find_matched_node,
2121
load_stmts,
22-
write_patched_source,
22+
persist_patched_source,
2323
)
2424

2525
TYPE_CHECKING = False
@@ -51,9 +51,7 @@ def get_data(self, path: str) -> bytes:
5151
source = f.read()
5252
tree = ast.parse(source, filename=self._origin)
5353
slines = source.splitlines(keepends=True)
54-
compiled: CompiledPatches = defaultdict(
55-
lambda: defaultdict(lambda: defaultdict(list))
56-
)
54+
compiled: CompiledPatches = defaultdict(lambda: defaultdict(dict))
5755
for patch in self._patches:
5856
target = find_matched_node(tree, slines, patch.target)
5957
if target is None:
@@ -69,7 +67,7 @@ def get_data(self, path: str) -> bytes:
6967
apply_compiled_patches(compiled)
7068
source = ast.unparse(tree)
7169
if AWEPATCH_DEBUG:
72-
self._path, source = write_patched_source(
70+
self._path, source = persist_patched_source(
7371
source,
7472
self._fullname.rsplit(".", 1)[-1],
7573
"module",

src/awepatch/utils.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,15 @@
55
import re
66
import threading
77
from abc import ABC, abstractmethod
8-
from binascii import crc32
98
from collections import defaultdict
9+
from collections.abc import Sequence
1010
from dataclasses import KW_ONLY, dataclass
1111
from typing import Any, Literal, TypeAlias, cast
1212

13-
from filelock import FileLock
14-
from platformdirs import user_cache_dir
15-
1613
TYPE_CHECKING = False
1714

1815
if TYPE_CHECKING:
19-
from collections.abc import Mapping, Sequence
16+
from collections.abc import Mapping
2017
from types import TracebackType
2118

2219

@@ -101,7 +98,7 @@ def __str__(self) -> str:
10198

10299

103100
CompiledPatches: TypeAlias = defaultdict[
104-
Location, defaultdict[int, defaultdict[Mode, list[ast.stmt]]]
101+
Location, defaultdict[int, dict[Mode, Sequence[ast.stmt]]]
105102
]
106103

107104

@@ -112,19 +109,16 @@ def append_patch(
112109
mode: Mode,
113110
) -> None:
114111
patches = complied[target[0]][target[1]]
115-
# Check for conflicting patches,
116-
if patches and mode == "replace" and "replace" in patches:
117-
raise ValueError(f"Cannot have multiple 'replace' patches on target {target!r}")
118112
if mode == "replace":
119113
if "replace" in patches:
120114
raise ValueError(
121115
f"Multiple 'replace' patches on the same target {target!r}"
122116
)
123-
patches[mode].extend(stmts)
117+
patches[mode] = stmts
124118
elif mode == "before":
125-
patches[mode].extend(stmts)
119+
patches[mode] = [*stmts, *patches[mode]] if mode in patches else stmts
126120
elif mode == "after":
127-
patches[mode] = [*patches[mode], *stmts]
121+
patches[mode] = [*patches[mode], *stmts] if mode in patches else stmts
128122
else:
129123
raise ValueError(f"Unknown patch mode: {mode!r}")
130124

@@ -350,8 +344,8 @@ def __exit__(
350344
_cache_dir_lock = threading.Lock()
351345

352346

353-
def get_cache_dir() -> str:
354-
"""Get or create the temporary directory for awepatch."""
347+
def _user_cache_dir() -> str:
348+
"""Get or create the cache directory for awepatch."""
355349
global _cache_dir
356350

357351
if _cache_dir is not None:
@@ -360,36 +354,41 @@ def get_cache_dir() -> str:
360354
with _cache_dir_lock:
361355
if _cache_dir is not None:
362356
return _cache_dir
357+
from platformdirs import user_cache_dir
358+
363359
_cache_dir = user_cache_dir("awepatch", appauthor=False, ensure_exists=True)
364360
return _cache_dir
365361

366362

367-
def write_patched_source(
363+
def persist_patched_source(
368364
source: str,
369365
name: str,
370366
type: str,
371367
origin: str = "",
372368
) -> tuple[str, str]:
373-
"""Load a function's code object from its AST module.
369+
"""Persist the patched source code to a file and return the file path and source.
374370
375371
Args:
376-
source (str): The source code of the function.
377-
name (str): The name of the function.
378-
type (str): The type of the function (e.g., "module", "function").
379-
origin (str, optional): The origin location for the function. Defaults to "".
372+
source: The source code of the function.
373+
name: The name of the function.
374+
type: The type of the function (e.g., "module", "function").
375+
origin: The origin location for the function. Defaults to "".
380376
381377
Returns:
382-
CodeType: The code object of the function.
378+
tuple[str, str]: The file path and the source code of the function.
383379
384380
"""
381+
from binascii import crc32
382+
383+
from filelock import FileLock
385384

386385
origin = f" (patched from {origin})" if origin else ""
387386
source = f"# generated by awepatch{origin}\n{source}"
388-
389387
bsource = source.encode("utf-8")
390388

391-
cache_dir = get_cache_dir()
389+
cache_dir = _user_cache_dir()
392390
file_path = os.path.join(cache_dir, f"{type}_{name}_{crc32(bsource):010x}.py")
391+
393392
with FileLock(f"{file_path}.lock"):
394393
if not os.path.exists(file_path):
395394
with open(file_path, "wb") as f:

0 commit comments

Comments
 (0)