55import re
66import threading
77from abc import ABC , abstractmethod
8- from binascii import crc32
98from collections import defaultdict
9+ from collections .abc import Sequence
1010from dataclasses import KW_ONLY , dataclass
1111from typing import Any , Literal , TypeAlias , cast
1212
13- from filelock import FileLock
14- from platformdirs import user_cache_dir
15-
1613TYPE_CHECKING = False
1714
1815if TYPE_CHECKING :
19- from collections .abc import Mapping , Sequence
16+ from collections .abc import Mapping
2017 from types import TracebackType
2118
2219
@@ -101,7 +98,7 @@ def __str__(self) -> str:
10198
10299
103100CompiledPatches : TypeAlias = defaultdict [
104- Location , defaultdict [int , defaultdict [Mode , list [ast .stmt ]]]
101+ Location , defaultdict [int , dict [Mode , Sequence [ast .stmt ]]]
105102]
106103
107104
@@ -112,19 +109,16 @@ def append_patch(
112109 mode : Mode ,
113110) -> None :
114111 patches = complied [target [0 ]][target [1 ]]
115- # Check for conflicting patches,
116- if patches and mode == "replace" and "replace" in patches :
117- raise ValueError (f"Cannot have multiple 'replace' patches on target { target !r} " )
118112 if mode == "replace" :
119113 if "replace" in patches :
120114 raise ValueError (
121115 f"Multiple 'replace' patches on the same target { target !r} "
122116 )
123- patches [mode ]. extend ( stmts )
117+ patches [mode ] = stmts
124118 elif mode == "before" :
125- patches [mode ]. extend ( stmts )
119+ patches [mode ] = [ * stmts , * patches [ mode ]] if mode in patches else stmts
126120 elif mode == "after" :
127- patches [mode ] = [* patches [mode ], * stmts ]
121+ patches [mode ] = [* patches [mode ], * stmts ] if mode in patches else stmts
128122 else :
129123 raise ValueError (f"Unknown patch mode: { mode !r} " )
130124
@@ -350,8 +344,8 @@ def __exit__(
350344_cache_dir_lock = threading .Lock ()
351345
352346
353- def get_cache_dir () -> str :
354- """Get or create the temporary directory for awepatch."""
347+ def _user_cache_dir () -> str :
348+ """Get or create the cache directory for awepatch."""
355349 global _cache_dir
356350
357351 if _cache_dir is not None :
@@ -360,36 +354,41 @@ def get_cache_dir() -> str:
360354 with _cache_dir_lock :
361355 if _cache_dir is not None :
362356 return _cache_dir
357+ from platformdirs import user_cache_dir
358+
363359 _cache_dir = user_cache_dir ("awepatch" , appauthor = False , ensure_exists = True )
364360 return _cache_dir
365361
366362
367- def write_patched_source (
363+ def persist_patched_source (
368364 source : str ,
369365 name : str ,
370366 type : str ,
371367 origin : str = "" ,
372368) -> tuple [str , str ]:
373- """Load a function's code object from its AST module .
369+ """Persist the patched source code to a file and return the file path and source .
374370
375371 Args:
376- source (str) : The source code of the function.
377- name (str) : The name of the function.
378- type (str) : The type of the function (e.g., "module", "function").
379- origin (str, optional) : The origin location for the function. Defaults to "".
372+ source: The source code of the function.
373+ name: The name of the function.
374+ type: The type of the function (e.g., "module", "function").
375+ origin: The origin location for the function. Defaults to "".
380376
381377 Returns:
382- CodeType : The code object of the function.
378+ tuple[str, str] : The file path and the source code of the function.
383379
384380 """
381+ from binascii import crc32
382+
383+ from filelock import FileLock
385384
386385 origin = f" (patched from { origin } )" if origin else ""
387386 source = f"# generated by awepatch{ origin } \n { source } "
388-
389387 bsource = source .encode ("utf-8" )
390388
391- cache_dir = get_cache_dir ()
389+ cache_dir = _user_cache_dir ()
392390 file_path = os .path .join (cache_dir , f"{ type } _{ name } _{ crc32 (bsource ):010x} .py" )
391+
393392 with FileLock (f"{ file_path } .lock" ):
394393 if not os .path .exists (file_path ):
395394 with open (file_path , "wb" ) as f :
0 commit comments