Skip to content

Commit cc2074e

Browse files
authored
Merge branch 'main' into saga4/micro_fix_12
2 parents f7aaaab + 2ed293e commit cc2074e

File tree

6 files changed

+119
-35
lines changed

6 files changed

+119
-35
lines changed

codeflash/code_utils/code_utils.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,87 @@
22

33
import ast
44
import os
5+
import re
56
import shutil
67
import site
8+
from contextlib import contextmanager
79
from functools import lru_cache
810
from pathlib import Path
911
from tempfile import TemporaryDirectory
1012

13+
import tomlkit
14+
1115
from codeflash.cli_cmds.console import logger
16+
from codeflash.code_utils.config_parser import find_pyproject_toml
17+
18+
19+
@contextmanager
20+
def custom_addopts() -> None:
21+
pyproject_file = find_pyproject_toml()
22+
original_content = None
23+
non_blacklist_plugin_args = ""
24+
25+
try:
26+
# Read original file
27+
if pyproject_file.exists():
28+
with Path.open(pyproject_file, encoding="utf-8") as f:
29+
original_content = f.read()
30+
data = tomlkit.parse(original_content)
31+
# Backup original addopts
32+
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
33+
# nothing to do if no addopts present
34+
if original_addopts != "":
35+
original_addopts = [x.strip() for x in original_addopts]
36+
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
37+
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
38+
if non_blacklist_plugin_args != original_addopts:
39+
data["tool"]["pytest"]["ini_options"]["addopts"] = non_blacklist_plugin_args
40+
# Write modified file
41+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
42+
f.write(tomlkit.dumps(data))
43+
44+
yield
45+
46+
finally:
47+
# Restore original file
48+
if (
49+
original_content
50+
and pyproject_file.exists()
51+
and tuple(original_addopts) not in {(), tuple(non_blacklist_plugin_args)}
52+
):
53+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
54+
f.write(original_content)
55+
56+
57+
@contextmanager
58+
def add_addopts_to_pyproject() -> None:
59+
pyproject_file = find_pyproject_toml()
60+
original_content = None
61+
try:
62+
# Read original file
63+
if pyproject_file.exists():
64+
with Path.open(pyproject_file, encoding="utf-8") as f:
65+
original_content = f.read()
66+
data = tomlkit.parse(original_content)
67+
data["tool"]["pytest"] = {}
68+
data["tool"]["pytest"]["ini_options"] = {}
69+
data["tool"]["pytest"]["ini_options"]["addopts"] = [
70+
"-n=auto",
71+
"-n",
72+
"1",
73+
"-n 1",
74+
"-n 1",
75+
"-n auto",
76+
]
77+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
78+
f.write(tomlkit.dumps(data))
79+
80+
yield
81+
82+
finally:
83+
# Restore original file
84+
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
85+
f.write(original_content)
1286

1387

1488
def encoded_tokens_len(s: str) -> int:

codeflash/discovery/discover_unit_tests.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from pydantic.dataclasses import dataclass
1818

1919
from codeflash.cli_cmds.console import console, logger, test_files_progress_bar
20-
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
20+
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file, module_name_from_file_path
2121
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_cache_db
2222
from codeflash.models.models import CodePosition, FunctionCalledInTest, TestsInFile, TestType
2323

@@ -150,19 +150,20 @@ def discover_tests_pytest(
150150
project_root = cfg.project_root_path
151151

152152
tmp_pickle_path = get_run_tmp_file("collected_tests.pkl")
153-
result = subprocess.run(
154-
[
155-
SAFE_SYS_EXECUTABLE,
156-
Path(__file__).parent / "pytest_new_process_discovery.py",
157-
str(project_root),
158-
str(tests_root),
159-
str(tmp_pickle_path),
160-
],
161-
cwd=project_root,
162-
check=False,
163-
capture_output=True,
164-
text=True,
165-
)
153+
with custom_addopts():
154+
result = subprocess.run(
155+
[
156+
SAFE_SYS_EXECUTABLE,
157+
Path(__file__).parent / "pytest_new_process_discovery.py",
158+
str(project_root),
159+
str(tests_root),
160+
str(tmp_pickle_path),
161+
],
162+
cwd=project_root,
163+
check=False,
164+
capture_output=True,
165+
text=True,
166+
)
166167
try:
167168
with tmp_pickle_path.open(mode="rb") as f:
168169
exitcode, tests, pytest_rootdir = pickle.load(f)

codeflash/verification/test_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import TYPE_CHECKING
77

88
from codeflash.cli_cmds.console import logger
9-
from codeflash.code_utils.code_utils import get_run_tmp_file
9+
from codeflash.code_utils.code_utils import custom_addopts, get_run_tmp_file
1010
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
1111
from codeflash.code_utils.config_consts import TOTAL_LOOPING_TIME
1212
from codeflash.code_utils.coverage_utils import prepare_coverage_files
@@ -23,8 +23,9 @@ def execute_test_subprocess(
2323
cmd_list: list[str], cwd: Path, env: dict[str, str] | None, timeout: int = 600
2424
) -> subprocess.CompletedProcess:
2525
"""Execute a subprocess with the given command list, working directory, environment variables, and timeout."""
26-
logger.debug(f"executing test run with command: {' '.join(cmd_list)}")
27-
return subprocess.run(cmd_list, capture_output=True, cwd=cwd, env=env, text=True, timeout=timeout, check=False)
26+
with custom_addopts():
27+
logger.debug(f"executing test run with command: {' '.join(cmd_list)}")
28+
return subprocess.run(cmd_list, capture_output=True, cwd=cwd, env=env, text=True, timeout=timeout, check=False)
2829

2930

3031
def run_behavioral_tests(
@@ -97,6 +98,7 @@ def run_behavioral_tests(
9798
coverage_cmd.extend(shlex.split(pytest_cmd, posix=IS_POSIX)[1:])
9899

99100
blocklist_args = [f"-p no:{plugin}" for plugin in BEHAVIORAL_BLOCKLISTED_PLUGINS if plugin != "cov"]
101+
100102
results = execute_test_subprocess(
101103
coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
102104
cwd=cwd,
@@ -252,7 +254,6 @@ def run_benchmarking_tests(
252254
pytest_test_env = test_env.copy()
253255
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
254256
blocklist_args = [f"-p no:{plugin}" for plugin in BENCHMARKING_BLOCKLISTED_PLUGINS]
255-
256257
results = execute_test_subprocess(
257258
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
258259
cwd=cwd,
@@ -278,7 +279,6 @@ def run_unittest_tests(
278279
log_level = ["-v"] if verbose else []
279280
files = [str(file) for file in test_file_paths]
280281
output_file = ["--output-file", str(result_file_path)]
281-
282282
results = execute_test_subprocess(
283283
unittest_cmd_list + log_level + files + output_file, cwd=cwd, env=test_env, timeout=600
284284
)

tests/scripts/end_to_end_test_topological_sort.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,29 @@
11
import os
22
import pathlib
3+
import tomlkit
34

5+
from codeflash.code_utils.code_utils import add_addopts_to_pyproject
46
from end_to_end_test_utilities import CoverageExpectation, TestConfig, run_codeflash_command, run_with_retries
57

68

79
def run_test(expected_improvement_pct: int) -> bool:
8-
config = TestConfig(
9-
file_path="topological_sort.py",
10-
function_name="Graph.topologicalSort",
11-
test_framework="pytest",
12-
min_improvement_x=0.05,
13-
coverage_expectations=[
14-
CoverageExpectation(
15-
function_name="Graph.topologicalSort", expected_coverage=100.0, expected_lines=[24, 25, 26, 27, 28, 29]
16-
)
17-
],
18-
)
19-
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
20-
return run_codeflash_command(cwd, config, expected_improvement_pct)
10+
with add_addopts_to_pyproject():
11+
config = TestConfig(
12+
file_path="topological_sort.py",
13+
function_name="Graph.topologicalSort",
14+
test_framework="pytest",
15+
min_improvement_x=0.05,
16+
coverage_expectations=[
17+
CoverageExpectation(
18+
function_name="Graph.topologicalSort",
19+
expected_coverage=100.0,
20+
expected_lines=[24, 25, 26, 27, 28, 29],
21+
)
22+
],
23+
)
24+
cwd = (pathlib.Path(__file__).parent.parent.parent / "code_to_optimize").resolve()
25+
return_var = run_codeflash_command(cwd, config, expected_improvement_pct)
26+
return return_var
2127

2228

2329
if __name__ == "__main__":

tests/scripts/end_to_end_test_tracer_replay.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ def run_test(expected_improvement_pct: int) -> bool:
1010
min_improvement_x=0.1,
1111
expected_unit_tests=1,
1212
coverage_expectations=[
13-
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 8, 10, 13]),
13+
CoverageExpectation(function_name="funcA", expected_coverage=100.0, expected_lines=[5, 6, 7, 8, 10, 13])
1414
],
1515
)
1616
cwd = (
1717
pathlib.Path(__file__).parent.parent.parent / "code_to_optimize" / "code_directories" / "simple_tracer_e2e"
1818
).resolve()
1919
return run_codeflash_command(cwd, config, expected_improvement_pct)
2020

21+
2122
if __name__ == "__main__":
2223
exit(run_with_retries(run_test, int(os.getenv("EXPECTED_IMPROVEMENT_PCT", 10))))

tests/scripts/end_to_end_test_utilities.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def run_codeflash_command(
117117
return validated
118118

119119

120-
def build_command(cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path, benchmarks_root:pathlib.Path|None = None) -> list[str]:
120+
def build_command(
121+
cwd: pathlib.Path, config: TestConfig, test_root: pathlib.Path, benchmarks_root: pathlib.Path | None = None
122+
) -> list[str]:
121123
python_path = "../../../codeflash/main.py" if "code_directories" in str(cwd) else "../codeflash/main.py"
122124

123125
base_command = ["python", python_path, "--file", config.file_path, "--no-pr"]
@@ -251,4 +253,4 @@ def run_with_retries(test_func, *args, **kwargs) -> bool:
251253
logging.error("Test failed after all retries")
252254
return 1
253255

254-
return 1
256+
return 1

0 commit comments

Comments
 (0)