Skip to content

Commit 414ac44

Browse files
committed
second pass
1 parent 3f524c2 commit 414ac44

File tree

5 files changed

+200
-18
lines changed

5 files changed

+200
-18
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,8 @@
33
import ast
44
import concurrent.futures
55
import os
6-
import re
7-
import shlex
86
import shutil
97
import subprocess
10-
import tempfile
118
import time
129
import uuid
1310
from collections import defaultdict
@@ -16,7 +13,6 @@
1613

1714
import isort
1815
import libcst as cst
19-
from crosshair.auditwall import SideEffectDetected
2016
from rich.console import Group
2117
from rich.panel import Panel
2218
from rich.syntax import Syntax
@@ -33,14 +29,12 @@
3329
get_run_tmp_file,
3430
module_name_from_file_path,
3531
)
36-
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
3732
from codeflash.code_utils.config_consts import (
3833
INDIVIDUAL_TESTCASE_TIMEOUT,
3934
N_CANDIDATES,
4035
N_TESTS_TO_GENERATE,
4136
TOTAL_LOOPING_TIME,
4237
)
43-
from codeflash.code_utils.coverage_utils import prepare_coverage_files
4438
from codeflash.code_utils.formatter import format_code, sort_imports
4539
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
4640
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
@@ -69,13 +63,13 @@
6963
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
7064
from codeflash.result.explanation import Explanation
7165
from codeflash.telemetry.posthog_cf import ph
72-
from codeflash.verification.codeflash_auditwall import transform_code
66+
from codeflash.verification._auditwall import SideEffectDetectedError
7367
from codeflash.verification.concolic_testing import generate_concolic_tests
7468
from codeflash.verification.equivalence import compare_test_results
7569
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
7670
from codeflash.verification.parse_test_output import parse_test_results
7771
from codeflash.verification.test_results import TestResults, TestType
78-
from codeflash.verification.test_runner import execute_test_subprocess, run_behavioral_tests, run_benchmarking_tests
72+
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
7973
from codeflash.verification.verification_utils import get_test_file_path
8074
from codeflash.verification.verifier import generate_tests
8175

@@ -853,8 +847,8 @@ def establish_original_code_baseline(
853847
enable_coverage=test_framework == "pytest",
854848
code_context=code_context,
855849
)
856-
except SideEffectDetected as e:
857-
return Failure(f"Side effect detected in original code: {e}")
850+
except SideEffectDetectedError as e:
851+
return Failure(f"Side effect detected in original code: {e}, skipping optimization.")
858852
finally:
859853
# Remove codeflash capture
860854
self.write_code_and_helpers(

codeflash/verification/_auditwall.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import importlib
2+
import os
3+
import sys
4+
import traceback
5+
from collections.abc import Generator, Iterable
6+
from contextlib import contextmanager, suppress
7+
from types import ModuleType
8+
from typing import Callable, Optional
9+
10+
11+
class SideEffectDetectedError(Exception):
12+
pass
13+
14+
15+
_BLOCKED_OPEN_FLAGS = os.O_WRONLY | os.O_RDWR | os.O_APPEND | os.O_CREAT | os.O_EXCL | os.O_TRUNC
16+
17+
18+
def accept(event: str, args: tuple) -> None:
19+
pass
20+
21+
22+
def reject(event: str, args: tuple) -> None:
23+
msg = f'codeflash has detected: {event}{args}".'
24+
raise SideEffectDetectedError(msg)
25+
26+
27+
def inside_module(modules: Iterable[ModuleType]) -> bool:
28+
files = {m.__file__ for m in modules}
29+
return any(frame.f_code.co_filename in files for frame, lineno in traceback.walk_stack(None))
30+
31+
32+
def check_open(event: str, args: tuple) -> None:
33+
(filename_or_descriptor, mode, flags) = args
34+
if filename_or_descriptor in ("/dev/null", "nul"):
35+
# (no-op writes on unix/windows)
36+
return
37+
if flags & _BLOCKED_OPEN_FLAGS:
38+
msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})."
39+
raise SideEffectDetectedError(msg)
40+
41+
42+
def check_msvcrt_open(event: str, args: tuple) -> None:
43+
print(args)
44+
(handle, flags) = args
45+
if flags & _BLOCKED_OPEN_FLAGS:
46+
msg = f"codeflash has detected: {event}({', '.join(map(repr, args))})."
47+
raise SideEffectDetectedError(msg)
48+
49+
50+
_MODULES_THAT_CAN_POPEN: Optional[set[ModuleType]] = None
51+
52+
53+
def modules_with_allowed_popen():
54+
global _MODULES_THAT_CAN_POPEN
55+
if _MODULES_THAT_CAN_POPEN is None:
56+
allowed_module_names = ("_aix_support", "ctypes", "platform", "uuid")
57+
_MODULES_THAT_CAN_POPEN = set()
58+
for module_name in allowed_module_names:
59+
with suppress(ImportError):
60+
_MODULES_THAT_CAN_POPEN.add(importlib.import_module(module_name))
61+
return _MODULES_THAT_CAN_POPEN
62+
63+
64+
def check_subprocess(event: str, args: tuple) -> None:
65+
if not inside_module(modules_with_allowed_popen()):
66+
reject(event, args)
67+
68+
69+
def check_sqlite_connect(event: str, args: tuple) -> None:
70+
if "codeflash_" in args[0]:
71+
accept(event, args)
72+
else:
73+
reject(event, args)
74+
75+
76+
_SPECIAL_HANDLERS = {
77+
"open": check_open,
78+
"subprocess.Popen": check_subprocess,
79+
"msvcrt.open_osfhandle": check_msvcrt_open,
80+
"sqlite3.connect": check_sqlite_connect,
81+
}
82+
83+
84+
def make_handler(event: str) -> Callable[[str, tuple], None]:
85+
special_handler = _SPECIAL_HANDLERS.get(event)
86+
if special_handler:
87+
return special_handler
88+
# Block certain events
89+
if event in (
90+
"winreg.CreateKey",
91+
"winreg.DeleteKey",
92+
"winreg.DeleteValue",
93+
"winreg.SaveKey",
94+
"winreg.SetValue",
95+
"winreg.DisableReflectionKey",
96+
"winreg.EnableReflectionKey",
97+
):
98+
return reject
99+
# Allow certain events.
100+
if event in (
101+
# These seem not terribly dangerous to allow:
102+
"os.putenv",
103+
"os.unsetenv",
104+
"msvcrt.heapmin",
105+
"msvcrt.kbhit",
106+
# These involve I/O, but are hopefully non-destructive:
107+
"glob.glob",
108+
"msvcrt.get_osfhandle",
109+
"msvcrt.setmode",
110+
"os.listdir", # (important for Python's importer)
111+
"os.scandir", # (important for Python's importer)
112+
"os.chdir",
113+
"os.fwalk",
114+
"os.getxattr",
115+
"os.listxattr",
116+
"os.walk",
117+
"pathlib.Path.glob",
118+
"socket.gethostbyname", # (FastAPI TestClient uses this)
119+
"socket.__new__", # (FastAPI TestClient uses this)
120+
"socket.bind", # pygls's asyncio needs this on windows
121+
"socket.connect", # pygls's asyncio needs this on windows
122+
):
123+
return accept
124+
# Block groups of events.
125+
event_prefix = event.split(".", 1)[0]
126+
if event_prefix in (
127+
"os",
128+
"fcntl",
129+
"ftplib",
130+
"glob",
131+
"imaplib",
132+
"msvcrt",
133+
"nntplib",
134+
"os",
135+
"pathlib",
136+
"poplib",
137+
"shutil",
138+
"smtplib",
139+
"socket",
140+
"sqlite3",
141+
"subprocess",
142+
"telnetlib",
143+
"urllib",
144+
"webbrowser",
145+
):
146+
return reject
147+
# Allow other events.
148+
return accept
149+
150+
151+
_HANDLERS: dict[str, Callable[[str, tuple], None]] = {}
152+
_ENABLED = True
153+
154+
155+
def audithook(event: str, args: tuple) -> None:
156+
if not _ENABLED:
157+
return
158+
handler = _HANDLERS.get(event)
159+
if handler is None:
160+
handler = make_handler(event)
161+
_HANDLERS[event] = handler
162+
handler(event, args)
163+
164+
165+
@contextmanager
166+
def opened_auditwall() -> Generator:
167+
global _ENABLED
168+
assert _ENABLED
169+
_ENABLED = False
170+
try:
171+
yield
172+
finally:
173+
_ENABLED = True
174+
175+
176+
def engage_auditwall() -> None:
177+
sys.dont_write_bytecode = True # disable .pyc file writing
178+
sys.addaudithook(audithook)
179+
180+
181+
def disable_auditwall() -> None:
182+
global _ENABLED
183+
assert _ENABLED
184+
_ENABLED = False

codeflash/verification/codeflash_auditwall.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ def visit_Module(self, node):
88
if isinstance(body_node, (ast.Import, ast.ImportFrom)):
99
last_import_index = i
1010

11-
new_import = ast.ImportFrom(module="crosshair.auditwall", names=[ast.alias(name="engage_auditwall")], level=0)
11+
new_import = ast.ImportFrom(
12+
module="codeflash.verification._auditwall", names=[ast.alias(name="engage_auditwall")], level=0
13+
)
1214
function_call = ast.Expr(
1315
value=ast.Call(func=ast.Name(id="engage_auditwall", ctx=ast.Load()), args=[], keywords=[])
1416
)

codeflash/verification/test_runner.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
from pathlib import Path
88
from typing import TYPE_CHECKING
99

10-
from crosshair.auditwall import SideEffectDetected
11-
1210
from codeflash.cli_cmds.console import logger
1311
from codeflash.code_utils.code_utils import get_run_tmp_file
1412
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
1513
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
1614
from codeflash.code_utils.coverage_utils import prepare_coverage_files
1715
from codeflash.models.models import TestFiles
16+
from codeflash.verification._auditwall import SideEffectDetectedError
1817
from codeflash.verification.codeflash_auditwall import transform_code
1918
from codeflash.verification.test_results import TestType
2019

@@ -90,22 +89,20 @@ def run_behavioral_tests(
9089
env=pytest_test_env,
9190
timeout=600,
9291
)
93-
9492
if auditing_res.returncode != 0:
9593
line_co = next(
9694
(
9795
line
9896
for line in auditing_res.stderr.splitlines() + auditing_res.stdout.splitlines()
99-
if "crosshair.auditwall.SideEffectDetected" in line
97+
if "codeflash.verification._auditwall.SideEffectDetectedError" in line
10098
),
10199
None,
102100
)
103-
104101
if line_co:
105-
match = re.search(r"crosshair\.auditwall\.SideEffectDetected: A(.*) operation was detected\.", line_co)
102+
match = re.search(r"codeflash has detected: (.+).", line_co)
106103
if match:
107104
msg = match.group(1)
108-
raise SideEffectDetected(msg)
105+
raise SideEffectDetectedError(msg)
109106
logger.debug(auditing_res.stderr)
110107
logger.debug(auditing_res.stdout)
111108

tests/test_codeflash_capture.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def __init__(self, x=2):
457457
test_type=test_type,
458458
original_file_path=test_path,
459459
benchmarking_file_path=test_path_perf,
460+
original_source=test_code,
460461
)
461462
]
462463
)
@@ -568,6 +569,7 @@ def __init__(self, *args, **kwargs):
568569
test_type=test_type,
569570
original_file_path=test_path,
570571
benchmarking_file_path=test_path_perf,
572+
original_source=test_code,
571573
)
572574
]
573575
)
@@ -681,6 +683,7 @@ def __init__(self, x=2):
681683
test_type=test_type,
682684
original_file_path=test_path,
683685
benchmarking_file_path=test_path_perf,
686+
original_source=test_code,
684687
)
685688
]
686689
)
@@ -831,6 +834,7 @@ def another_helper(self):
831834
test_type=test_type,
832835
original_file_path=test_path,
833836
benchmarking_file_path=test_path_perf,
837+
original_source=test_code,
834838
)
835839
]
836840
)
@@ -967,6 +971,7 @@ def another_helper(self):
967971
test_type=test_type,
968972
original_file_path=test_path,
969973
benchmarking_file_path=test_path_perf,
974+
original_source=test_code,
970975
)
971976
]
972977
)

0 commit comments

Comments
 (0)