Skip to content

Commit 3f524c2

Browse files
committed
first pass
1 parent 3510312 commit 3f524c2

File tree

2 files changed

+103
-72
lines changed

2 files changed

+103
-72
lines changed

codeflash/optimization/function_optimizer.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,11 @@
33
import ast
44
import concurrent.futures
55
import os
6+
import re
7+
import shlex
68
import shutil
79
import subprocess
10+
import tempfile
811
import time
912
import uuid
1013
from collections import defaultdict
@@ -13,6 +16,7 @@
1316

1417
import isort
1518
import libcst as cst
19+
from crosshair.auditwall import SideEffectDetected
1620
from rich.console import Group
1721
from rich.panel import Panel
1822
from rich.syntax import Syntax
@@ -29,12 +33,14 @@
2933
get_run_tmp_file,
3034
module_name_from_file_path,
3135
)
36+
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
3237
from codeflash.code_utils.config_consts import (
3338
INDIVIDUAL_TESTCASE_TIMEOUT,
3439
N_CANDIDATES,
3540
N_TESTS_TO_GENERATE,
3641
TOTAL_LOOPING_TIME,
3742
)
43+
from codeflash.code_utils.coverage_utils import prepare_coverage_files
3844
from codeflash.code_utils.formatter import format_code, sort_imports
3945
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
4046
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
@@ -63,12 +69,13 @@
6369
from codeflash.result.critic import coverage_critic, performance_gain, quantity_of_tests_critic, speedup_critic
6470
from codeflash.result.explanation import Explanation
6571
from codeflash.telemetry.posthog_cf import ph
72+
from codeflash.verification.codeflash_auditwall import transform_code
6673
from codeflash.verification.concolic_testing import generate_concolic_tests
6774
from codeflash.verification.equivalence import compare_test_results
6875
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
6976
from codeflash.verification.parse_test_output import parse_test_results
7077
from codeflash.verification.test_results import TestResults, TestType
71-
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
78+
from codeflash.verification.test_runner import execute_test_subprocess, run_behavioral_tests, run_benchmarking_tests
7279
from codeflash.verification.verification_utils import get_test_file_path
7380
from codeflash.verification.verifier import generate_tests
7481

@@ -149,12 +156,14 @@ def optimize_function(self) -> Result[BestOptimization, str]:
149156
self.args.project_root,
150157
)
151158

159+
152160
generated_test_paths = [
153161
get_test_file_path(
154162
self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="unit"
155163
)
156164
for test_index in range(N_TESTS_TO_GENERATE)
157165
]
166+
158167
generated_perf_test_paths = [
159168
get_test_file_path(
160169
self.test_cfg.tests_root, self.function_to_optimize.function_name, test_index, test_type="perf"
@@ -844,6 +853,8 @@ def establish_original_code_baseline(
844853
enable_coverage=test_framework == "pytest",
845854
code_context=code_context,
846855
)
856+
except SideEffectDetected as e:
857+
return Failure(f"Side effect detected in original code: {e}")
847858
finally:
848859
# Remove codeflash capture
849860
self.write_code_and_helpers(
@@ -855,9 +866,7 @@ def establish_original_code_baseline(
855866
)
856867
console.rule()
857868
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
858-
if not coverage_critic(
859-
coverage_results, self.args.test_framework
860-
):
869+
if not coverage_critic(coverage_results, self.args.test_framework):
861870
return Failure("The threshold for test coverage was not met.")
862871
if test_framework == "pytest":
863872
benchmarking_results, _ = self.run_and_parse_tests(
@@ -898,7 +907,6 @@ def establish_original_code_baseline(
898907
)
899908
console.rule()
900909

901-
902910
total_timing = benchmarking_results.total_passed_runtime() # caution: doesn't handle the loop index
903911
functions_to_remove = [
904912
result.id.test_function_name
@@ -1097,13 +1105,13 @@ def run_and_parse_tests(
10971105
raise ValueError(f"Unexpected testing type: {testing_type}")
10981106
except subprocess.TimeoutExpired:
10991107
logger.exception(
1100-
f'Error running tests in {", ".join(str(f) for f in test_files.test_files)}.\nTimeout Error'
1108+
f"Error running tests in {', '.join(str(f) for f in test_files.test_files)}.\nTimeout Error"
11011109
)
11021110
return TestResults(), None
11031111
if run_result.returncode != 0 and testing_type == TestingMode.BEHAVIOR:
11041112
logger.debug(
1105-
f'Nonzero return code {run_result.returncode} when running tests in '
1106-
f'{", ".join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n'
1113+
f"Nonzero return code {run_result.returncode} when running tests in "
1114+
f"{', '.join([str(f.instrumented_behavior_file_path) for f in test_files.test_files])}.\n"
11071115
f"stdout: {run_result.stdout}\n"
11081116
f"stderr: {run_result.stderr}\n"
11091117
)
@@ -1149,4 +1157,3 @@ def generate_and_instrument_tests(
11491157
zip(generated_test_paths, generated_perf_test_paths)
11501158
)
11511159
]
1152-

codeflash/verification/test_runner.py

Lines changed: 87 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,21 @@
11
from __future__ import annotations
22

3+
import re
34
import shlex
45
import subprocess
6+
import tempfile
57
from pathlib import Path
68
from typing import TYPE_CHECKING
79

10+
from crosshair.auditwall import SideEffectDetected
11+
812
from codeflash.cli_cmds.console import logger
913
from codeflash.code_utils.code_utils import get_run_tmp_file
1014
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
1115
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
1216
from codeflash.code_utils.coverage_utils import prepare_coverage_files
1317
from codeflash.models.models import TestFiles
18+
from codeflash.verification.codeflash_auditwall import transform_code
1419
from codeflash.verification.test_results import TestType
1520

1621
if TYPE_CHECKING:
@@ -36,78 +41,97 @@ def run_behavioral_tests(
3641
pytest_target_runtime_seconds: int = TOTAL_LOOPING_TIME,
3742
enable_coverage: bool = False,
3843
) -> tuple[Path, subprocess.CompletedProcess, Path | None]:
39-
if test_framework == "pytest":
40-
test_files: list[str] = []
41-
for file in test_paths.test_files:
42-
if file.test_type == TestType.REPLAY_TEST:
43-
# TODO: Does this work for unittest framework?
44-
test_files.extend(
45-
[
46-
str(file.instrumented_behavior_file_path) + "::" + test.test_function
47-
for test in file.tests_in_file
48-
]
49-
)
50-
else:
51-
test_files.append(str(file.instrumented_behavior_file_path))
52-
test_files = list(set(test_files)) # remove multiple calls in the same test function
53-
pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX)
54-
55-
common_pytest_args = [
56-
"--capture=tee-sys",
57-
f"--timeout={pytest_timeout}",
58-
"-q",
59-
"--codeflash_loops_scope=session",
60-
"--codeflash_min_loops=1",
61-
"--codeflash_max_loops=1",
62-
f"--codeflash_seconds={pytest_target_runtime_seconds}", # TODO :This is unnecessary, update the plugin to not ask for this
63-
]
64-
65-
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
66-
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
67-
68-
pytest_test_env = test_env.copy()
69-
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
44+
if test_framework not in ["pytest", "unittest"]:
45+
raise ValueError(f"Unsupported test framework: {test_framework}")
7046

71-
if enable_coverage:
72-
coverage_database_file, coveragercfile = prepare_coverage_files()
73-
74-
cov_erase = execute_test_subprocess(
75-
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env
76-
) # this cleanup is necessary to avoid coverage data from previous runs, if there are any,
77-
# then the current run will be appended to the previous data, which skews the results
78-
logger.debug(cov_erase)
79-
80-
results = execute_test_subprocess(
81-
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m")
82-
+ pytest_cmd_list
83-
+ common_pytest_args
84-
+ result_args
85-
+ test_files,
86-
cwd=cwd,
87-
env=pytest_test_env,
88-
timeout=600,
47+
test_files: list[str] = []
48+
for file in test_paths.test_files:
49+
if file.test_type == TestType.REPLAY_TEST:
50+
# TODO: Does this work for unittest framework?
51+
test_files.extend(
52+
[str(file.instrumented_behavior_file_path) + "::" + test.test_function for test in file.tests_in_file]
8953
)
90-
logger.debug(
91-
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""")
9254
else:
93-
results = execute_test_subprocess(
94-
pytest_cmd_list + common_pytest_args + result_args + test_files,
95-
cwd=cwd,
96-
env=pytest_test_env,
97-
timeout=600, # TODO: Make this dynamic
55+
test_files.append(str(file.instrumented_behavior_file_path))
56+
57+
source_code = next((file.original_source for file in test_paths.test_files if file.original_source), None)
58+
if not source_code:
59+
raise ValueError("No source code found for auditing")
60+
61+
audit_code = transform_code(source_code)
62+
pytest_cmd_list = shlex.split(pytest_cmd, posix=IS_POSIX)
63+
common_pytest_args = [
64+
"--capture=tee-sys",
65+
f"--timeout={pytest_timeout}",
66+
"-q",
67+
"--codeflash_loops_scope=session",
68+
"--codeflash_min_loops=1",
69+
"--codeflash_max_loops=1",
70+
f"--codeflash_seconds={pytest_target_runtime_seconds}",
71+
"-p",
72+
"no:cacheprovider",
73+
]
74+
75+
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
76+
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
77+
78+
pytest_test_env = test_env.copy()
79+
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
80+
81+
with tempfile.TemporaryDirectory(
82+
dir=Path(test_paths.test_files[0].instrumented_behavior_file_path).parent
83+
) as temp_dir:
84+
audited_file_path = Path(temp_dir) / "audited_code.py"
85+
audited_file_path.write_text(audit_code, encoding="utf8")
86+
87+
auditing_res = execute_test_subprocess(
88+
pytest_cmd_list + common_pytest_args + [audited_file_path.as_posix()],
89+
cwd=cwd,
90+
env=pytest_test_env,
91+
timeout=600,
92+
)
93+
94+
if auditing_res.returncode != 0:
95+
line_co = next(
96+
(
97+
line
98+
for line in auditing_res.stderr.splitlines() + auditing_res.stdout.splitlines()
99+
if "crosshair.auditwall.SideEffectDetected" in line
100+
),
101+
None,
98102
)
99-
logger.debug(
100-
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""")
101-
elif test_framework == "unittest":
103+
104+
if line_co:
105+
match = re.search(r"crosshair\.auditwall\.SideEffectDetected: A(.*) operation was detected\.", line_co)
106+
if match:
107+
msg = match.group(1)
108+
raise SideEffectDetected(msg)
109+
logger.debug(auditing_res.stderr)
110+
logger.debug(auditing_res.stdout)
111+
112+
if test_framework == "pytest":
113+
coverage_database_file, coveragercfile = prepare_coverage_files()
114+
cov_erase = execute_test_subprocess(
115+
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env
116+
)
117+
logger.debug(cov_erase)
118+
119+
results = execute_test_subprocess(
120+
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage run --rcfile={coveragercfile.as_posix()} -m")
121+
+ pytest_cmd_list
122+
+ common_pytest_args
123+
+ result_args
124+
+ list(set(test_files)), # remove duplicates
125+
cwd=cwd,
126+
env=pytest_test_env,
127+
timeout=600,
128+
)
129+
else: # unittest
102130
if enable_coverage:
103131
raise ValueError("Coverage is not supported yet for unittest framework")
104132
test_env["CODEFLASH_LOOP_INDEX"] = "1"
105133
test_files = [file.instrumented_behavior_file_path for file in test_paths.test_files]
106134
result_file_path, results = run_unittest_tests(verbose, test_files, test_env, cwd)
107-
logger.debug(
108-
f"""Result return code: {results.returncode}, {"Result stderr:" + str(results.stderr) if results.stderr else ''}""")
109-
else:
110-
raise ValueError(f"Unsupported test framework: {test_framework}")
111135

112136
return result_file_path, results, coverage_database_file if enable_coverage else None
113137

0 commit comments

Comments
 (0)