Commit 8efafe5

Merge remote-tracking branch 'origin/main' into stdout_comparison_
2 parents 18af38e + 69e43dd commit 8efafe5

File tree

10 files changed: +239 -52 lines changed


codeflash/LICENSE

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@ Business Source License 1.1
 Parameters
 
 Licensor: CodeFlash Inc.
-Licensed Work: Codeflash Client version 0.9.x
+Licensed Work: Codeflash Client version 0.10.x
 The Licensed Work is (c) 2024 CodeFlash Inc.
 
 Additional Use Grant: None. Production use of the Licensed Work is only permitted
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
 Platform. Please visit codeflash.ai for further
 information.
 
-Change Date: 2029-01-06
+Change Date: 2029-02-25
 
 Change License: MIT

codeflash/code_utils/code_replacer.py

Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
 import ast
+import re
 from collections import defaultdict
 from functools import lru_cache
-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, Optional, TypeVar
 
 import libcst as cst

codeflash/code_utils/concolic_utils.py

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+from __future__ import annotations
+
+import ast
+import re
+from typing import Optional
+
+
+class AssertCleanup:
+    def transform_asserts(self, code: str) -> str:
+        lines = code.splitlines()
+        result_lines = []
+
+        for line in lines:
+            transformed = self._transform_assert_line(line)
+            result_lines.append(transformed if transformed is not None else line)
+
+        return "\n".join(result_lines)
+
+    def _transform_assert_line(self, line: str) -> Optional[str]:
+        indent = line[: len(line) - len(line.lstrip())]
+
+        assert_match = self.assert_re.match(line)
+        if assert_match:
+            expression = assert_match.group(1).strip()
+            if expression.startswith("not "):
+                return f"{indent}{expression}"
+
+            expression = expression.rstrip(",;")
+            return f"{indent}{expression}"
+
+        unittest_match = self.unittest_re.match(line)
+        if unittest_match:
+            indent, assert_method, args = unittest_match.groups()
+
+            if args:
+                arg_parts = self._split_top_level_args(args)
+                if arg_parts and arg_parts[0]:
+                    return f"{indent}{arg_parts[0]}"
+
+        return None
+
+    def _split_top_level_args(self, args_str: str) -> list[str]:
+        result = []
+        current = []
+        depth = 0
+
+        for char in args_str:
+            if char in "([{":
+                depth += 1
+                current.append(char)
+            elif char in ")]}":
+                depth -= 1
+                current.append(char)
+            elif char == "," and depth == 0:
+                result.append("".join(current).strip())
+                current = []
+            else:
+                current.append(char)
+
+        if current:
+            result.append("".join(current).strip())
+
+        return result
+
+    def __init__(self):
+        # Pre-compiling regular expressions for faster execution
+        self.assert_re = re.compile(r"\s*assert\s+(.*?)(?:\s*==\s*.*)?$")
+        self.unittest_re = re.compile(r"(\s*)self\.assert([A-Za-z]+)\((.*)\)$")
+
+
+def clean_concolic_tests(test_suite_code: str) -> str:
+    try:
+        can_parse = True
+        tree = ast.parse(test_suite_code)
+    except SyntaxError:
+        can_parse = False
+
+    if not can_parse:
+        return AssertCleanup().transform_asserts(test_suite_code)
+
+    for node in ast.walk(tree):
+        if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
+            new_body = []
+            for stmt in node.body:
+                if isinstance(stmt, ast.Assert):
+                    if isinstance(stmt.test, ast.Compare) and isinstance(stmt.test.left, ast.Call):
+                        new_body.append(ast.Expr(value=stmt.test.left))
+                    else:
+                        new_body.append(stmt)
+
+                else:
+                    new_body.append(stmt)
+            node.body = new_body
+
+    return ast.unparse(tree).strip()
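
The new clean_concolic_tests helper rewrites generated test functions so that an asserted comparison on a call is reduced to the bare call, and falls back to the regex-based AssertCleanup when the code does not parse. A rough usage sketch, assuming the codeflash package is importable; the add function and the test body below are made up for illustration:

from codeflash.code_utils.concolic_utils import clean_concolic_tests

generated = """
def test_add():
    assert add(2, 3) == 5
    assert add(-1, 1) == 0
"""

print(clean_concolic_tests(generated))
# Each `assert <call> == <value>` inside a test_* function becomes a bare `<call>`:
# def test_add():
#     add(2, 3)
#     add(-1, 1)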

codeflash/discovery/functions_to_optimize.py

Lines changed: 17 additions & 15 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import ast
+import json
 import os
 import random
 import warnings
@@ -156,9 +157,9 @@ def get_functions_to_optimize(
     project_root: Path,
     module_root: Path,
 ) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
-    assert (
-        sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1
-    ), "Only one of optimize_all, replay_test, or file should be provided"
+    assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
+        "Only one of optimize_all, replay_test, or file should be provided"
+    )
     functions: dict[str, list[FunctionToOptimize]]
     with warnings.catch_warnings():
         warnings.simplefilter(action="ignore", category=SyntaxWarning)
@@ -434,9 +435,7 @@ def filter_functions(
             test_functions_removed_count += len(functions)
             continue
         if file_path in ignore_paths or any(
-            # file_path.startswith(ignore_path + os.sep) for ignore_path in ignore_paths if ignore_path
-            file_path.startswith(str(ignore_path) + os.sep)
-            for ignore_path in ignore_paths
+            file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
         ):
             ignore_paths_removed_count += 1
             continue
@@ -457,15 +456,17 @@ def filter_functions(
             malformed_paths_count += 1
             continue
         if blocklist_funcs:
-            for function in functions.copy():
-                path = Path(function.file_path).name
-                if path in blocklist_funcs and function.function_name in blocklist_funcs[path]:
-                    functions.remove(function)
-                    logger.debug(f"Skipping {function.function_name} in {path} as it has already been optimized")
-                    continue
-
+            functions = [
+                function
+                for function in functions
+                if not (
+                    function.file_path.name in blocklist_funcs
+                    and function.qualified_name in blocklist_funcs[function.file_path.name]
+                )
+            ]
         filtered_modified_functions[file_path] = functions
         functions_count += len(functions)
+
     if not disable_logs:
         log_info = {
             f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count,
@@ -475,10 +476,11 @@ def filter_functions(
             f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
             f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
         }
-        log_string: str
-        if log_string := "\n".join([k for k, v in log_info.items() if v > 0]):
+        log_string = "\n".join([k for k, v in log_info.items() if v > 0])
+        if log_string:
             logger.info(f"Ignoring: {log_string}")
             console.rule()
+
     return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
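
In filter_functions, the blocklist check is now a list comprehension keyed on the file's base name and the function's qualified name, replacing the earlier loop that removed items while iterating over a copy and matched on the bare function name. A standalone sketch of the new matching rule, using a simplified stand-in for FunctionToOptimize; the class, paths, and names below are illustrative only:

from dataclasses import dataclass
from pathlib import Path


@dataclass
class Fn:  # simplified stand-in for FunctionToOptimize
    file_path: Path
    qualified_name: str


blocklist_funcs = {"utils.py": {"Helper.slow"}}  # already-optimized functions, keyed by file name
functions = [Fn(Path("src/utils.py"), "Helper.slow"), Fn(Path("src/utils.py"), "fast")]

functions = [
    f
    for f in functions
    if not (f.file_path.name in blocklist_funcs and f.qualified_name in blocklist_funcs[f.file_path.name])
]
print([f.qualified_name for f in functions])  # ['fast']: the blocklisted entry is dropped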

codeflash/github/PrComment.py

Lines changed: 8 additions & 4 deletions
@@ -20,6 +20,13 @@ class PrComment:
     winning_benchmarking_test_results: TestResults
 
     def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]:
+
+        report_table = {
+            test_type.to_name(): result
+            for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items()
+            if test_type.to_name()
+        }
+
         return {
             "optimization_explanation": self.optimization_explanation,
             "best_runtime": humanize_runtime(self.best_runtime),
@@ -29,10 +36,7 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str]]:
             "speedup_x": self.speedup_x,
             "speedup_pct": self.speedup_pct,
             "loop_count": self.winning_benchmarking_test_results.number_of_loops(),
-            "report_table": {
-                test_type.to_name(): result
-                for test_type, result in self.winning_behavioral_test_results.get_test_pass_fail_report_by_type().items()
-            },
+            "report_table": report_table
         }
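
to_json now builds report_table ahead of the return dict and keeps only test types whose to_name() returns a non-empty string. A small sketch of that filtering with stand-in data; the type keys, display names, and counts below are invented for illustration:

results_by_type = {"unit": {"passed": 12, "failed": 0}, "internal": {"passed": 3, "failed": 0}}
display_name = {"unit": "Generated Unit Tests", "internal": ""}  # "" marks a type without a display name

report_table = {
    display_name[t]: result
    for t, result in results_by_type.items()
    if display_name[t]  # entries with an empty display name are dropped
}
print(report_table)  # {'Generated Unit Tests': {'passed': 12, 'failed': 0}}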

codeflash/models/models.py

Lines changed: 40 additions & 24 deletions
@@ -7,12 +7,13 @@
 from enum import Enum, IntEnum
 from pathlib import Path
 from re import Pattern
-from typing import Any, Optional, Union
+from typing import Annotated, Any, Optional, Union
 
+import sentry_sdk
+from coverage.exceptions import NoDataError
 from jedi.api.classes import Name
 from pydantic import AfterValidator, BaseModel, ConfigDict, Field
 from pydantic.dataclasses import dataclass
-from typing_extensions import Annotated
 
 from codeflash.cli_cmds.console import console, logger
 from codeflash.code_utils.code_utils import validate_python_code
@@ -217,7 +218,7 @@ class CoverageData:
     graph: dict[str, dict[str, Collection[object]]]
     code_context: CodeOptimizationContext
     main_func_coverage: FunctionCoverage
-    dependent_func_coverage: Union[FunctionCoverage, None]
+    dependent_func_coverage: Optional[FunctionCoverage]
     status: CoverageStatus
     blank_re: Pattern[str] = re.compile(r"\s*(#|$)")
     else_re: Pattern[str] = re.compile(r"\s*else\s*:\s*(#|$)")
@@ -231,34 +232,21 @@ def load_from_sqlite_database(
         from coverage.jsonreport import JsonReporter
 
         cov = Coverage(data_file=database_path, data_suffix=True, auto_data=True, branch=True)
+
         if not database_path.stat().st_size or not database_path.exists():
             logger.debug(f"Coverage database {database_path} is empty or does not exist")
-            return CoverageData(
-                file_path=source_code_path,
-                coverage=0.0,
-                function_name=function_name,
-                functions_being_tested=[],
-                graph={},
-                code_context=code_context,
-                main_func_coverage=FunctionCoverage(
-                    name=function_name,
-                    coverage=0.0,
-                    executed_lines=[],
-                    unexecuted_lines=[],
-                    executed_branches=[],
-                    unexecuted_branches=[],
-                ),
-                dependent_func_coverage=None,
-                status=CoverageStatus.NOT_FOUND,
-            )
-
+            sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
+            return CoverageData.create_empty(source_code_path, function_name, code_context)
         cov.load()
 
         reporter = JsonReporter(cov)
         temp_json_file = database_path.with_suffix(".report.json")
         with temp_json_file.open("w") as f:
-            reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
-
+            try:
+                reporter.report(morfs=[source_code_path.as_posix()], outfile=f)
+            except NoDataError:
+                sentry_sdk.capture_message(f"No coverage data found for {function_name} in {source_code_path}")
+                return CoverageData.create_empty(source_code_path, function_name, code_context)
        with temp_json_file.open() as f:
             original_coverage_data = json.load(f)
 
@@ -461,6 +449,34 @@ def log_coverage(self) -> None:
         if is_end_to_end():
             console.print(self)
 
+    @classmethod
+    def create_empty(cls, file_path: Path, function_name: str, code_context: CodeOptimizationContext) -> CoverageData:
+        return cls(
+            file_path=file_path,
+            coverage=0.0,
+            function_name=function_name,
+            functions_being_tested=[function_name],
+            graph={
+                function_name: {
+                    "executed_lines": set(),
+                    "unexecuted_lines": set(),
+                    "executed_branches": [],
+                    "unexecuted_branches": [],
+                }
+            },
+            code_context=code_context,
+            main_func_coverage=FunctionCoverage(
+                name=function_name,
+                coverage=0.0,
+                executed_lines=[],
+                unexecuted_lines=[],
+                executed_branches=[],
+                unexecuted_branches=[],
+            ),
+            dependent_func_coverage=None,
+            status=CoverageStatus.NOT_FOUND,
+        )
+
 
 @dataclass
 class FunctionCoverage:
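
CoverageData.load_from_sqlite_database now routes both failure paths, an empty or missing database and a NoDataError raised by coverage's JsonReporter, through the new create_empty classmethod instead of building the zero-coverage object inline, and reports each case to Sentry. The shape of that fallback as a rough standalone sketch, with a heavily simplified stand-in for CoverageData (the real class has many more fields):

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional


@dataclass
class Result:  # simplified stand-in for CoverageData
    file_path: Path
    function_name: str
    coverage: float
    status: str

    @classmethod
    def create_empty(cls, file_path: Path, function_name: str) -> Result:
        # One place to build the "no coverage found" object, so every failure path
        # returns an identically shaped result instead of duplicating this literal.
        return cls(file_path=file_path, function_name=function_name, coverage=0.0, status="NOT_FOUND")


def load(file_path: Path, function_name: str, report: Optional[dict]) -> Result:
    if report is None:  # analogous to the empty-database and NoDataError branches
        return Result.create_empty(file_path, function_name)
    return Result(file_path, function_name, report["coverage"], "PARSED")


print(load(Path("mod.py"), "f", None).status)  # NOT_FOUND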

codeflash/verification/concolic_testing.py

Lines changed: 8 additions & 2 deletions
@@ -7,6 +7,7 @@
 from pathlib import Path
 
 from codeflash.cli_cmds.console import console, logger
+from codeflash.code_utils.concolic_utils import clean_concolic_tests
 from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
 from codeflash.code_utils.static_analysis import has_typed_parameters
 from codeflash.discovery.discover_unit_tests import discover_unit_tests
@@ -21,7 +22,11 @@ def generate_concolic_tests(
 ) -> tuple[dict[str, list[FunctionCalledInTest]], str]:
     function_to_concolic_tests = {}
     concolic_test_suite_code = ""
-    if test_cfg.concolic_test_root_dir and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents):
+    if (
+        test_cfg.concolic_test_root_dir
+        and isinstance(function_to_optimize_ast, (ast.FunctionDef, ast.AsyncFunctionDef))
+        and has_typed_parameters(function_to_optimize_ast, function_to_optimize.parents)
+    ):
         logger.info("Generating concolic opcode coverage tests for the original code…")
         console.rule()
         try:
@@ -54,7 +59,8 @@ def generate_concolic_tests(
         return function_to_concolic_tests, concolic_test_suite_code
 
     if cover_result.returncode == 0:
-        concolic_test_suite_code: str = cover_result.stdout
+        generated_concolic_test: str = cover_result.stdout
+        concolic_test_suite_code: str = clean_concolic_tests(generated_concolic_test)
         concolic_test_suite_dir = Path(tempfile.mkdtemp(dir=test_cfg.concolic_test_root_dir))
         concolic_test_suite_path = concolic_test_suite_dir / "test_concolic_coverage.py"
         concolic_test_suite_path.write_text(concolic_test_suite_code, encoding="utf8")
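
The widened guard in generate_concolic_tests also requires the resolved AST node to be an ast.FunctionDef or ast.AsyncFunctionDef before calling has_typed_parameters on it. A small, self-contained illustration of the node mismatch such a check guards against; the module source and names below are made up:

import ast

module = ast.parse(
    "class Svc:\n"
    "    def run(self, n: int) -> int:\n"
    "        return n * 2\n"
    "\n"
    "handler = lambda n: n * 2\n"
)

for node in ast.walk(module):
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        # Only (async) function definitions carry an .args block with annotations to inspect.
        annotated = [a.arg for a in node.args.args if a.annotation is not None]
        print(node.name, annotated)  # run ['n']
# The lambda bound to `handler` is an ast.Lambda inside an ast.Assign, so it is
# skipped here rather than being passed to a check that expects a FunctionDef.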

codeflash/version.py

Lines changed: 2 additions & 2 deletions
@@ -1,3 +1,3 @@
 # These version placeholders will be replaced by poetry-dynamic-versioning during `poetry build`.
-__version__ = "0.9.2"
-__version_tuple__ = (0, 9, 2)
+__version__ = "0.10.0"
+__version_tuple__ = (0, 10, 0)

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -67,7 +67,7 @@ exclude = [
 
 # Versions here the minimum required versions for the project. These should be as loose as possible.
 [tool.poetry.dependencies]
-python = "^3.9"
+python = ">=3.9"
 unidiff = ">=0.7.4"
 pytest = ">=7.0.0"
 gitpython = ">=3.1.31"
@@ -176,7 +176,7 @@ ignore = [
     "TD003",
     "TD004",
     "PLR2004",
-    "UP007" # remove once we drop 3.9 support.
+    "UP007"
 ]
 
 [tool.ruff.lint.flake8-type-checking]
