Skip to content

Commit c118a16

Browse files
committed
new way
1 parent 617a888 commit c118a16

File tree

2 files changed

+20
-186
lines changed

2 files changed

+20
-186
lines changed

codeflash/verification/_auditwall.py

Lines changed: 18 additions & 183 deletions
Original file line numberDiff line numberDiff line change
@@ -13,211 +13,46 @@
1313
# https://docs.python.org/3/license.html#python-software-foundation-license-version-2
1414
#
1515
#
16-
import importlib
17-
import os
18-
import sys
19-
import traceback
20-
from collections.abc import Generator, Iterable
21-
from contextlib import contextmanager, suppress
22-
from types import ModuleType
23-
from typing import Callable, Optional
2416

17+
from auditwall.core import AuditWallConfig, _default_audit_wall, accept, reject
2518

26-
class SideEffectDetectedError(Exception):
27-
pass
2819

29-
30-
_BLOCKED_OPEN_FLAGS = os.O_WRONLY | os.O_RDWR | os.O_APPEND | os.O_CREAT | os.O_EXCL | os.O_TRUNC
31-
32-
33-
def accept(event: str, args: tuple) -> None:
34-
pass
35-
36-
37-
args_allow_list = {".coverage", "matplotlib.rc", "codeflash"}
38-
39-
40-
def reject(event: str, args: tuple) -> None:
41-
msg = f'codeflash has detected: {event}{args}".'
42-
raise SideEffectDetectedError(msg)
43-
44-
45-
def inside_module(modules: Iterable[ModuleType]) -> bool:
46-
files = {m.__file__ for m in modules}
47-
return any(frame.f_code.co_filename in files for frame, lineno in traceback.walk_stack(None))
48-
49-
50-
def check_open(event: str, args: tuple) -> None:
51-
(filename_or_descriptor, mode, flags) = args
52-
if filename_or_descriptor in ("/dev/null", "nul"):
53-
# (no-op writes on unix/windows)
54-
return
55-
if flags & _BLOCKED_OPEN_FLAGS:
56-
msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})."
57-
raise SideEffectDetectedError(msg)
58-
59-
60-
def check_msvcrt_open(event: str, args: tuple) -> None:
61-
(handle, flags) = args
62-
if flags & _BLOCKED_OPEN_FLAGS:
63-
msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})."
64-
raise SideEffectDetectedError(msg)
65-
66-
67-
_MODULES_THAT_CAN_POPEN: Optional[set[ModuleType]] = None
68-
69-
70-
def modules_with_allowed_popen():
71-
global _MODULES_THAT_CAN_POPEN
72-
if _MODULES_THAT_CAN_POPEN is None:
73-
allowed_module_names = ("_aix_support", "ctypes", "platform", "uuid")
74-
_MODULES_THAT_CAN_POPEN = set()
75-
for module_name in allowed_module_names:
76-
with suppress(ImportError):
77-
_MODULES_THAT_CAN_POPEN.add(importlib.import_module(module_name))
78-
return _MODULES_THAT_CAN_POPEN
79-
80-
81-
def check_subprocess(event: str, args: tuple) -> None:
82-
if not inside_module(modules_with_allowed_popen()):
83-
reject(event, args)
20+
class CodeflashAuditWallConfig(AuditWallConfig):
21+
def __init__(self) -> None:
22+
super().__init__()
23+
self.allowed_write_paths = {".coverage", "matplotlib.rc", "codeflash"}
8424

8525

8626
def handle_os_remove(event: str, args: tuple) -> None:
8727
filename = str(args[0])
88-
if any(pattern in filename for pattern in args_allow_list):
28+
if any(
29+
pattern in filename
30+
for pattern in _default_audit_wall.config.allowed_write_paths
31+
):
8932
accept(event, args)
9033
else:
9134
reject(event, args)
9235

9336

9437
def check_sqlite_connect(event: str, args: tuple) -> None:
9538
if (
96-
event == "sqlite3.connect" and any(pattern in str(args[0]) for pattern in args_allow_list)
39+
event == "sqlite3.connect"
40+
and any(
41+
pattern in str(args[0])
42+
for pattern in _default_audit_wall.config.allowed_write_paths
43+
)
9744
) or event == "sqlite3.connect/handle":
9845
accept(event, args)
9946
else:
10047
reject(event, args)
10148

10249

103-
_SPECIAL_HANDLERS = {
104-
"open": check_open,
105-
"subprocess.Popen": check_subprocess,
106-
"msvcrt.open_osfhandle": check_msvcrt_open,
50+
custom_handlers = {
51+
"os.remove": handle_os_remove,
10752
"sqlite3.connect": check_sqlite_connect,
10853
"sqlite3.connect/handle": check_sqlite_connect,
109-
"os.remove": handle_os_remove,
11054
}
11155

11256

113-
def make_handler(event: str) -> Callable[[str, tuple], None]:
114-
special_handler = _SPECIAL_HANDLERS.get(event)
115-
if special_handler:
116-
return special_handler
117-
# Block certain events
118-
if event in (
119-
"winreg.CreateKey",
120-
"winreg.DeleteKey",
121-
"winreg.DeleteValue",
122-
"winreg.SaveKey",
123-
"winreg.SetValue",
124-
"winreg.DisableReflectionKey",
125-
"winreg.EnableReflectionKey",
126-
):
127-
return reject
128-
# Allow certain events.
129-
if event in (
130-
# These seem not terribly dangerous to allow:
131-
"os.putenv",
132-
"os.unsetenv",
133-
"msvcrt.heapmin",
134-
"msvcrt.kbhit",
135-
# These involve I/O, but are hopefully non-destructive:
136-
"glob.glob",
137-
"msvcrt.get_osfhandle",
138-
"msvcrt.setmode",
139-
"os.listdir", # (important for Python's importer)
140-
"os.scandir", # (important for Python's importer)
141-
"os.chdir",
142-
"os.fwalk",
143-
"os.getxattr",
144-
"os.listxattr",
145-
"os.walk",
146-
"pathlib.Path.glob",
147-
"socket.gethostbyname", # (FastAPI TestClient uses this)
148-
"socket.__new__", # (FastAPI TestClient uses this)
149-
"socket.bind", # pygls's asyncio needs this on windows
150-
"socket.connect", # pygls's asyncio needs this on windows
151-
):
152-
return accept
153-
# Block groups of events.
154-
event_prefix = event.split(".", 1)[0]
155-
if event_prefix in (
156-
"fcntl",
157-
"ftplib",
158-
"glob",
159-
"imaplib",
160-
"msvcrt",
161-
"nntplib",
162-
"pathlib",
163-
"poplib",
164-
"shutil",
165-
"smtplib",
166-
"socket",
167-
"sqlite3",
168-
"subprocess",
169-
"telnetlib",
170-
"urllib",
171-
"webbrowser",
172-
):
173-
return reject
174-
if event_prefix == "os" and event not in [
175-
"os.putenv",
176-
"os.unsetenv",
177-
"os.listdir",
178-
"os.scandir",
179-
"os.chdir",
180-
"os.fwalk",
181-
"os.getxattr",
182-
"os.listxattr",
183-
"os.walk",
184-
]:
185-
return reject
186-
# Allow other events.
187-
return accept
188-
189-
190-
_HANDLERS: dict[str, Callable[[str, tuple], None]] = {}
191-
_ENABLED = True
192-
193-
194-
def audithook(event: str, args: tuple) -> None:
195-
if not _ENABLED:
196-
return
197-
handler = _HANDLERS.get(event)
198-
if handler is None:
199-
handler = make_handler(event)
200-
_HANDLERS[event] = handler
201-
handler(event, args)
202-
203-
204-
@contextmanager
205-
def opened_auditwall() -> Generator:
206-
global _ENABLED
207-
assert _ENABLED
208-
_ENABLED = False
209-
try:
210-
yield
211-
finally:
212-
_ENABLED = True
213-
214-
215-
def engage_auditwall() -> None:
216-
sys.dont_write_bytecode = True # disable .pyc file writing
217-
sys.addaudithook(audithook)
218-
219-
220-
def disable_auditwall() -> None:
221-
global _ENABLED
222-
assert _ENABLED
223-
_ENABLED = False
57+
_default_audit_wall.config = CodeflashAuditWallConfig()
58+
_default_audit_wall.config.special_handlers = custom_handlers

codeflash/verification/codeflash_auditwall.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33

44
class AuditWallTransformer(ast.NodeTransformer):
5-
def visit_Module(self, node):
5+
def visit_Module(self, node: ast.Module) -> ast.Module: # noqa: N802
66
last_import_index = -1
77
for i, body_node in enumerate(node.body):
88
if isinstance(body_node, (ast.Import, ast.ImportFrom)):
99
last_import_index = i
1010

1111
new_import = ast.ImportFrom(
12-
module="codeflash.verification._auditwall", names=[ast.alias(name="engage_auditwall")], level=0
12+
module="auditwall.core", names=[ast.alias(name="engage_auditwall")], level=0
1313
)
1414
function_call = ast.Expr(
1515
value=ast.Call(func=ast.Name(id="engage_auditwall", ctx=ast.Load()), args=[], keywords=[])
@@ -20,7 +20,6 @@ def visit_Module(self, node):
2020

2121
return node
2222

23-
2423
def transform_code(source_code: str) -> str:
2524
tree = ast.parse(source_code)
2625
transformer = AuditWallTransformer()

0 commit comments

Comments
 (0)