|
13 | 13 | # https://docs.python.org/3/license.html#python-software-foundation-license-version-2 |
14 | 14 | # |
15 | 15 | # |
16 | | -import importlib |
17 | | -import os |
18 | | -import sys |
19 | | -import traceback |
20 | | -from collections.abc import Generator, Iterable |
21 | | -from contextlib import contextmanager, suppress |
22 | | -from types import ModuleType |
23 | | -from typing import Callable, Optional |
24 | 16 |
|
| 17 | +from auditwall.core import AuditWallConfig, _default_audit_wall, accept, reject |
25 | 18 |
|
26 | | -class SideEffectDetectedError(Exception): |
27 | | - pass |
28 | 19 |
|
29 | | - |
30 | | -_BLOCKED_OPEN_FLAGS = os.O_WRONLY | os.O_RDWR | os.O_APPEND | os.O_CREAT | os.O_EXCL | os.O_TRUNC |
31 | | - |
32 | | - |
33 | | -def accept(event: str, args: tuple) -> None: |
34 | | - pass |
35 | | - |
36 | | - |
37 | | -args_allow_list = {".coverage", "matplotlib.rc", "codeflash"} |
38 | | - |
39 | | - |
40 | | -def reject(event: str, args: tuple) -> None: |
41 | | - msg = f'codeflash has detected: {event}{args}".' |
42 | | - raise SideEffectDetectedError(msg) |
43 | | - |
44 | | - |
45 | | -def inside_module(modules: Iterable[ModuleType]) -> bool: |
46 | | - files = {m.__file__ for m in modules} |
47 | | - return any(frame.f_code.co_filename in files for frame, lineno in traceback.walk_stack(None)) |
48 | | - |
49 | | - |
50 | | -def check_open(event: str, args: tuple) -> None: |
51 | | - (filename_or_descriptor, mode, flags) = args |
52 | | - if filename_or_descriptor in ("/dev/null", "nul"): |
53 | | - # (no-op writes on unix/windows) |
54 | | - return |
55 | | - if flags & _BLOCKED_OPEN_FLAGS: |
56 | | - msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})." |
57 | | - raise SideEffectDetectedError(msg) |
58 | | - |
59 | | - |
60 | | -def check_msvcrt_open(event: str, args: tuple) -> None: |
61 | | - (handle, flags) = args |
62 | | - if flags & _BLOCKED_OPEN_FLAGS: |
63 | | - msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})." |
64 | | - raise SideEffectDetectedError(msg) |
65 | | - |
66 | | - |
67 | | -_MODULES_THAT_CAN_POPEN: Optional[set[ModuleType]] = None |
68 | | - |
69 | | - |
70 | | -def modules_with_allowed_popen(): |
71 | | - global _MODULES_THAT_CAN_POPEN |
72 | | - if _MODULES_THAT_CAN_POPEN is None: |
73 | | - allowed_module_names = ("_aix_support", "ctypes", "platform", "uuid") |
74 | | - _MODULES_THAT_CAN_POPEN = set() |
75 | | - for module_name in allowed_module_names: |
76 | | - with suppress(ImportError): |
77 | | - _MODULES_THAT_CAN_POPEN.add(importlib.import_module(module_name)) |
78 | | - return _MODULES_THAT_CAN_POPEN |
79 | | - |
80 | | - |
81 | | -def check_subprocess(event: str, args: tuple) -> None: |
82 | | - if not inside_module(modules_with_allowed_popen()): |
83 | | - reject(event, args) |
| 20 | +class CodeflashAuditWallConfig(AuditWallConfig): |
| 21 | + def __init__(self) -> None: |
| 22 | + super().__init__() |
| 23 | + self.allowed_write_paths = {".coverage", "matplotlib.rc", "codeflash"} |
84 | 24 |
|
85 | 25 |
|
86 | 26 | def handle_os_remove(event: str, args: tuple) -> None: |
87 | 27 | filename = str(args[0]) |
88 | | - if any(pattern in filename for pattern in args_allow_list): |
| 28 | + if any( |
| 29 | + pattern in filename |
| 30 | + for pattern in _default_audit_wall.config.allowed_write_paths |
| 31 | + ): |
89 | 32 | accept(event, args) |
90 | 33 | else: |
91 | 34 | reject(event, args) |
92 | 35 |
|
93 | 36 |
|
94 | 37 | def check_sqlite_connect(event: str, args: tuple) -> None: |
95 | 38 | if ( |
96 | | - event == "sqlite3.connect" and any(pattern in str(args[0]) for pattern in args_allow_list) |
| 39 | + event == "sqlite3.connect" |
| 40 | + and any( |
| 41 | + pattern in str(args[0]) |
| 42 | + for pattern in _default_audit_wall.config.allowed_write_paths |
| 43 | + ) |
97 | 44 | ) or event == "sqlite3.connect/handle": |
98 | 45 | accept(event, args) |
99 | 46 | else: |
100 | 47 | reject(event, args) |
101 | 48 |
|
102 | 49 |
|
103 | | -_SPECIAL_HANDLERS = { |
104 | | - "open": check_open, |
105 | | - "subprocess.Popen": check_subprocess, |
106 | | - "msvcrt.open_osfhandle": check_msvcrt_open, |
| 50 | +custom_handlers = { |
| 51 | + "os.remove": handle_os_remove, |
107 | 52 | "sqlite3.connect": check_sqlite_connect, |
108 | 53 | "sqlite3.connect/handle": check_sqlite_connect, |
109 | | - "os.remove": handle_os_remove, |
110 | 54 | } |
111 | 55 |
|
112 | 56 |
|
113 | | -def make_handler(event: str) -> Callable[[str, tuple], None]: |
114 | | - special_handler = _SPECIAL_HANDLERS.get(event) |
115 | | - if special_handler: |
116 | | - return special_handler |
117 | | - # Block certain events |
118 | | - if event in ( |
119 | | - "winreg.CreateKey", |
120 | | - "winreg.DeleteKey", |
121 | | - "winreg.DeleteValue", |
122 | | - "winreg.SaveKey", |
123 | | - "winreg.SetValue", |
124 | | - "winreg.DisableReflectionKey", |
125 | | - "winreg.EnableReflectionKey", |
126 | | - ): |
127 | | - return reject |
128 | | - # Allow certain events. |
129 | | - if event in ( |
130 | | - # These seem not terribly dangerous to allow: |
131 | | - "os.putenv", |
132 | | - "os.unsetenv", |
133 | | - "msvcrt.heapmin", |
134 | | - "msvcrt.kbhit", |
135 | | - # These involve I/O, but are hopefully non-destructive: |
136 | | - "glob.glob", |
137 | | - "msvcrt.get_osfhandle", |
138 | | - "msvcrt.setmode", |
139 | | - "os.listdir", # (important for Python's importer) |
140 | | - "os.scandir", # (important for Python's importer) |
141 | | - "os.chdir", |
142 | | - "os.fwalk", |
143 | | - "os.getxattr", |
144 | | - "os.listxattr", |
145 | | - "os.walk", |
146 | | - "pathlib.Path.glob", |
147 | | - "socket.gethostbyname", # (FastAPI TestClient uses this) |
148 | | - "socket.__new__", # (FastAPI TestClient uses this) |
149 | | - "socket.bind", # pygls's asyncio needs this on windows |
150 | | - "socket.connect", # pygls's asyncio needs this on windows |
151 | | - ): |
152 | | - return accept |
153 | | - # Block groups of events. |
154 | | - event_prefix = event.split(".", 1)[0] |
155 | | - if event_prefix in ( |
156 | | - "fcntl", |
157 | | - "ftplib", |
158 | | - "glob", |
159 | | - "imaplib", |
160 | | - "msvcrt", |
161 | | - "nntplib", |
162 | | - "pathlib", |
163 | | - "poplib", |
164 | | - "shutil", |
165 | | - "smtplib", |
166 | | - "socket", |
167 | | - "sqlite3", |
168 | | - "subprocess", |
169 | | - "telnetlib", |
170 | | - "urllib", |
171 | | - "webbrowser", |
172 | | - ): |
173 | | - return reject |
174 | | - if event_prefix == "os" and event not in [ |
175 | | - "os.putenv", |
176 | | - "os.unsetenv", |
177 | | - "os.listdir", |
178 | | - "os.scandir", |
179 | | - "os.chdir", |
180 | | - "os.fwalk", |
181 | | - "os.getxattr", |
182 | | - "os.listxattr", |
183 | | - "os.walk", |
184 | | - ]: |
185 | | - return reject |
186 | | - # Allow other events. |
187 | | - return accept |
188 | | - |
189 | | - |
190 | | -_HANDLERS: dict[str, Callable[[str, tuple], None]] = {} |
191 | | -_ENABLED = True |
192 | | - |
193 | | - |
194 | | -def audithook(event: str, args: tuple) -> None: |
195 | | - if not _ENABLED: |
196 | | - return |
197 | | - handler = _HANDLERS.get(event) |
198 | | - if handler is None: |
199 | | - handler = make_handler(event) |
200 | | - _HANDLERS[event] = handler |
201 | | - handler(event, args) |
202 | | - |
203 | | - |
204 | | -@contextmanager |
205 | | -def opened_auditwall() -> Generator: |
206 | | - global _ENABLED |
207 | | - assert _ENABLED |
208 | | - _ENABLED = False |
209 | | - try: |
210 | | - yield |
211 | | - finally: |
212 | | - _ENABLED = True |
213 | | - |
214 | | - |
215 | | -def engage_auditwall() -> None: |
216 | | - sys.dont_write_bytecode = True # disable .pyc file writing |
217 | | - sys.addaudithook(audithook) |
218 | | - |
219 | | - |
220 | | -def disable_auditwall() -> None: |
221 | | - global _ENABLED |
222 | | - assert _ENABLED |
223 | | - _ENABLED = False |
| 57 | +_default_audit_wall.config = CodeflashAuditWallConfig() |
| 58 | +_default_audit_wall.config.special_handlers = custom_handlers |
0 commit comments