Skip to content

Commit 7bc22c7

Browse files
authored
Merge branch 'main' into opt-impact-aseem
2 parents 2cedeaf + 3b522fa commit 7bc22c7

File tree

11 files changed

+2260
-399
lines changed

11 files changed

+2260
-399
lines changed

.github/workflows/e2e-init-optimization.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
COLUMNS: 110
2020
MAX_RETRIES: 3
2121
RETRY_DELAY: 5
22-
EXPECTED_IMPROVEMENT_PCT: 30
22+
EXPECTED_IMPROVEMENT_PCT: 10
2323
CODEFLASH_END_TO_END: 1
2424
steps:
2525
- name: 🛎️ Checkout

.github/workflows/unit-tests.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
strategy:
1212
fail-fast: false
1313
matrix:
14-
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
14+
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"]
1515
continue-on-error: true
1616
runs-on: ubuntu-latest
1717
steps:
@@ -28,5 +28,9 @@ jobs:
2828
- name: install dependencies
2929
run: uv sync
3030

31+
- name: Install test-only dependencies (Python 3.13)
32+
if: matrix.python-version == '3.13'
33+
run: uv sync --group tests
34+
3135
- name: Unit tests
32-
run: uv run pytest tests/
36+
run: uv run pytest tests/

codeflash/code_utils/code_extractor.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,16 +528,29 @@ def add_needed_imports_from_module(
528528

529529
try:
530530
for mod in gatherer.module_imports:
531+
# Skip __future__ imports as they cannot be imported directly
532+
# __future__ imports should only be imported with specific objects i.e from __future__ import annotations
533+
if mod == "__future__":
534+
continue
531535
if mod not in dotted_import_collector.imports:
532536
AddImportsVisitor.add_needed_import(dst_context, mod)
533537
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
538+
aliased_objects = set()
539+
for mod, alias_pairs in gatherer.alias_mapping.items():
540+
for alias_pair in alias_pairs:
541+
if alias_pair[0] and alias_pair[1]: # Both name and alias exist
542+
aliased_objects.add(f"{mod}.{alias_pair[0]}")
543+
534544
for mod, obj_seq in gatherer.object_mapping.items():
535545
for obj in obj_seq:
536546
if (
537547
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
538548
):
539549
continue # Skip adding imports for helper functions already in the context
540550

551+
if f"{mod}.{obj}" in aliased_objects:
552+
continue
553+
541554
# Handle star imports by resolving them to actual symbol names
542555
if obj == "*":
543556
resolved_symbols = resolve_star_import(mod, project_root)
@@ -559,6 +572,8 @@ def add_needed_imports_from_module(
559572
return dst_module_code
560573

561574
for mod, asname in gatherer.module_aliases.items():
575+
if not asname:
576+
continue
562577
if f"{mod}.{asname}" not in dotted_import_collector.imports:
563578
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
564579
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
@@ -568,12 +583,16 @@ def add_needed_imports_from_module(
568583
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
569584
continue
570585

586+
if not alias_pair[0] or not alias_pair[1]:
587+
continue
588+
571589
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
572590
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
573591
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
574592

575593
try:
576-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
594+
add_imports_visitor = AddImportsVisitor(dst_context)
595+
transformed_module = add_imports_visitor.transform_module(parsed_dst_module)
577596
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
578597
return transformed_module.code.lstrip("\n")
579598
except Exception as e:

codeflash/code_utils/code_utils.py

Lines changed: 95 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import configparser
45
import difflib
56
import os
67
import re
@@ -15,10 +16,12 @@
1516
import tomlkit
1617

1718
from codeflash.cli_cmds.console import logger, paneled_text
18-
from codeflash.code_utils.config_parser import find_pyproject_toml
19+
from codeflash.code_utils.config_parser import find_pyproject_toml, get_all_closest_config_files
1920

2021
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2122

23+
BLACKLIST_ADDOPTS = ("--benchmark", "--sugar", "--codespeed", "--cov", "--profile", "--junitxml", "-n")
24+
2225

2326
def unified_diff_strings(code1: str, code2: str, fromfile: str = "original", tofile: str = "modified") -> str:
2427
"""Return the unified diff between two code strings as a single string.
@@ -81,42 +84,105 @@ def create_rank_dictionary_compact(int_array: list[int]) -> dict[int, int]:
8184
return {original_index: rank for rank, original_index in enumerate(sorted_indices)}
8285

8386

84-
@contextmanager
85-
def custom_addopts() -> None:
86-
pyproject_file = find_pyproject_toml()
87-
original_content = None
88-
non_blacklist_plugin_args = ""
89-
87+
def filter_args(addopts_args: list[str]) -> list[str]:
88+
# Convert BLACKLIST_ADDOPTS to a set for faster lookup of simple matches
89+
# But keep tuple for startswith
90+
blacklist = BLACKLIST_ADDOPTS
91+
# Precompute the length for re-use
92+
n = len(addopts_args)
93+
filtered_args = []
94+
i = 0
95+
while i < n:
96+
current_arg = addopts_args[i]
97+
if current_arg.startswith(blacklist):
98+
i += 1
99+
if i < n and not addopts_args[i].startswith("-"):
100+
i += 1
101+
else:
102+
filtered_args.append(current_arg)
103+
i += 1
104+
return filtered_args
105+
106+
107+
def modify_addopts(config_file: Path) -> tuple[str, bool]: # noqa : PLR0911
108+
file_type = config_file.suffix.lower()
109+
filename = config_file.name
110+
config = None
111+
if file_type not in {".toml", ".ini", ".cfg"} or not config_file.exists():
112+
return "", False
113+
# Read original file
114+
with Path.open(config_file, encoding="utf-8") as f:
115+
content = f.read()
90116
try:
91-
# Read original file
92-
if pyproject_file.exists():
93-
with Path.open(pyproject_file, encoding="utf-8") as f:
94-
original_content = f.read()
95-
data = tomlkit.parse(original_content)
96-
# Backup original addopts
117+
if filename == "pyproject.toml":
118+
# use tomlkit
119+
data = tomlkit.parse(content)
97120
original_addopts = data.get("tool", {}).get("pytest", {}).get("ini_options", {}).get("addopts", "")
98121
# nothing to do if no addopts present
99-
if original_addopts != "" and isinstance(original_addopts, list):
100-
original_addopts = [x.strip() for x in original_addopts]
101-
non_blacklist_plugin_args = re.sub(r"-n(?: +|=)\S+", "", " ".join(original_addopts)).split(" ")
102-
non_blacklist_plugin_args = [x for x in non_blacklist_plugin_args if x != ""]
103-
if non_blacklist_plugin_args != original_addopts:
104-
data["tool"]["pytest"]["ini_options"]["addopts"] = non_blacklist_plugin_args
105-
# Write modified file
106-
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
107-
f.write(tomlkit.dumps(data))
122+
if original_addopts == "":
123+
return content, False
124+
if isinstance(original_addopts, list):
125+
original_addopts = " ".join(original_addopts)
126+
original_addopts = original_addopts.replace("=", " ")
127+
addopts_args = (
128+
original_addopts.split()
129+
) # any number of space characters as delimiter, doesn't look at = which is fine
130+
else:
131+
# use configparser
132+
config = configparser.ConfigParser()
133+
config.read_string(content)
134+
data = {section: dict(config[section]) for section in config.sections()}
135+
if config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}:
136+
original_addopts = data.get("pytest", {}).get("addopts", "") # should only be a string
137+
else:
138+
original_addopts = data.get("tool:pytest", {}).get("addopts", "") # should only be a string
139+
original_addopts = original_addopts.replace("=", " ")
140+
addopts_args = original_addopts.split()
141+
new_addopts_args = filter_args(addopts_args)
142+
if new_addopts_args == addopts_args:
143+
return content, False
144+
# change addopts now
145+
if file_type == ".toml":
146+
data["tool"]["pytest"]["ini_options"]["addopts"] = " ".join(new_addopts_args)
147+
# Write modified file
148+
with Path.open(config_file, "w", encoding="utf-8") as f:
149+
f.write(tomlkit.dumps(data))
150+
return content, True
151+
elif config_file.name in {"pytest.ini", ".pytest.ini", "tox.ini"}:
152+
config.set("pytest", "addopts", " ".join(new_addopts_args))
153+
# Write modified file
154+
with Path.open(config_file, "w", encoding="utf-8") as f:
155+
config.write(f)
156+
return content, True
157+
else:
158+
config.set("tool:pytest", "addopts", " ".join(new_addopts_args))
159+
# Write modified file
160+
with Path.open(config_file, "w", encoding="utf-8") as f:
161+
config.write(f)
162+
return content, True
163+
164+
except Exception:
165+
logger.debug("Trouble parsing")
166+
return content, False # not modified
167+
168+
169+
@contextmanager
170+
def custom_addopts() -> None:
171+
closest_config_files = get_all_closest_config_files()
172+
173+
original_content = {}
108174

175+
try:
176+
for config_file in closest_config_files:
177+
original_content[config_file] = modify_addopts(config_file)
109178
yield
110179

111180
finally:
112181
# Restore original file
113-
if (
114-
original_content
115-
and pyproject_file.exists()
116-
and tuple(original_addopts) not in {(), tuple(non_blacklist_plugin_args)}
117-
):
118-
with Path.open(pyproject_file, "w", encoding="utf-8") as f:
119-
f.write(original_content)
182+
for file, (content, was_modified) in original_content.items():
183+
if was_modified:
184+
with Path.open(file, "w", encoding="utf-8") as f:
185+
f.write(content)
120186

121187

122188
@contextmanager

codeflash/code_utils/config_parser.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
import tomlkit
77

8+
PYPROJECT_TOML_CACHE = {}
9+
ALL_CONFIG_FILES = {} # map path to closest config file
10+
811

912
def find_pyproject_toml(config_file: Path | None = None) -> Path:
1013
# Find the pyproject.toml file on the root of the project
@@ -19,10 +22,15 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
1922
raise ValueError(msg)
2023
return config_file
2124
dir_path = Path.cwd()
22-
25+
cur_path = dir_path
26+
# see if it was encountered before in search
27+
if cur_path in PYPROJECT_TOML_CACHE:
28+
return PYPROJECT_TOML_CACHE[cur_path]
29+
# map current path to closest file
2330
while dir_path != dir_path.parent:
2431
config_file = dir_path / "pyproject.toml"
2532
if config_file.exists():
33+
PYPROJECT_TOML_CACHE[cur_path] = config_file
2634
return config_file
2735
# Search for pyproject.toml in the parent directories
2836
dir_path = dir_path.parent
@@ -31,6 +39,33 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
3139
raise ValueError(msg)
3240

3341

42+
def get_all_closest_config_files() -> list[Path]:
43+
all_closest_config_files = []
44+
for file_type in ["pyproject.toml", "pytest.ini", ".pytest.ini", "tox.ini", "setup.cfg"]:
45+
closest_config_file = find_closest_config_file(file_type)
46+
if closest_config_file:
47+
all_closest_config_files.append(closest_config_file)
48+
return all_closest_config_files
49+
50+
51+
def find_closest_config_file(file_type: str) -> Path | None:
52+
# Find the closest pyproject.toml, pytest.ini, tox.ini, or setup.cfg file on the root of the project
53+
dir_path = Path.cwd()
54+
cur_path = dir_path
55+
if cur_path in ALL_CONFIG_FILES and file_type in ALL_CONFIG_FILES[cur_path]:
56+
return ALL_CONFIG_FILES[cur_path][file_type]
57+
while dir_path != dir_path.parent:
58+
config_file = dir_path / file_type
59+
if config_file.exists():
60+
if cur_path not in ALL_CONFIG_FILES:
61+
ALL_CONFIG_FILES[cur_path] = {}
62+
ALL_CONFIG_FILES[cur_path][file_type] = config_file
63+
return config_file
64+
# Search for pyproject.toml in the parent directories
65+
dir_path = dir_path.parent
66+
return None
67+
68+
3469
def find_conftest_files(test_paths: list[Path]) -> list[Path]:
3570
list_of_conftest_files = set()
3671
for test_path in test_paths:

codeflash/verification/parse_test_output.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def calculate_function_throughput_from_test_results(test_results: TestResults, f
6767
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
6868
test_results = TestResults()
6969
if not file_location.exists():
70-
logger.warning(f"No test results for {file_location} found.")
70+
logger.debug(f"No test results for {file_location} found.")
7171
console.rule()
7272
return test_results
7373

@@ -237,6 +237,11 @@ def parse_test_xml(
237237

238238
test_class_path = testcase.classname
239239
try:
240+
if testcase.name is None:
241+
logger.debug(
242+
f"testcase.name is None for testcase {testcase!r} in file {test_xml_file_path}, skipping"
243+
)
244+
continue
240245
test_function = testcase.name.split("[", 1)[0] if "[" in testcase.name else testcase.name
241246
except (AttributeError, TypeError) as e:
242247
msg = (
@@ -273,16 +278,16 @@ def parse_test_xml(
273278

274279
timed_out = False
275280
if test_config.test_framework == "pytest":
276-
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if "[" in testcase.name else 1
281+
loop_index = int(testcase.name.split("[ ")[-1][:-2]) if testcase.name and "[" in testcase.name else 1
277282
if len(testcase.result) > 1:
278-
logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
283+
logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
279284
if len(testcase.result) == 1:
280285
message = testcase.result[0].message.lower()
281286
if "failed: timeout >" in message:
282287
timed_out = True
283288
else:
284289
if len(testcase.result) > 1:
285-
logger.warning(f"!!!!!Multiple results for {testcase.name} in {test_xml_file_path}!!!")
290+
logger.debug(f"!!!!!Multiple results for {testcase.name or '<None>'} in {test_xml_file_path}!!!")
286291
if len(testcase.result) == 1:
287292
message = testcase.result[0].message.lower()
288293
if "timed out" in message:

pyproject.toml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ keywords = [
1717
]
1818
dependencies = [
1919
"unidiff>=0.7.4",
20-
"pytest>=7.0.0,!=8.3.4",
20+
"pytest>=7.0.0",
2121
"gitpython>=3.1.31",
2222
"libcst>=1.0.1",
2323
"jedi>=0.19.1",
@@ -85,6 +85,16 @@ dev = [
8585
asyncio = [
8686
"pytest-asyncio>=1.2.0",
8787
]
88+
tests = [
89+
"black>=25.9.0",
90+
"jax>=0.4.30",
91+
"numpy>=2.0.2",
92+
"pandas>=2.3.3",
93+
"pyrsistent>=0.20.0",
94+
"scipy>=1.13.1",
95+
"torch>=2.8.0",
96+
"xarray>=2024.7.0",
97+
]
8898

8999
[tool.hatch.build.targets.sdist]
90100
include = ["codeflash"]

0 commit comments

Comments
 (0)