Skip to content

Commit 0be74c4

Browse files
committed
Ready to review
2 parents c93b80e + b4ab00b commit 0be74c4

26 files changed

+2696
-77
lines changed

codeflash/benchmarking/codeflash_trace.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def setup(self, trace_path: str) -> None:
2525
"""Set up the database connection for direct writing.
2626
2727
Args:
28+
----
2829
trace_path: Path to the trace database file
2930
3031
"""
@@ -52,6 +53,7 @@ def write_function_timings(self) -> None:
5253
"""Write function call data directly to the database.
5354
5455
Args:
56+
----
5557
data: List of function call data tuples to write
5658
5759
"""
@@ -94,9 +96,11 @@ def __call__(self, func: Callable) -> Callable:
9496
"""Use as a decorator to trace function execution.
9597
9698
Args:
99+
----
97100
func: The function to be decorated
98101
99102
Returns:
103+
-------
100104
The wrapped function
101105
102106
"""

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,12 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
7676
"""Add codeflash_trace to a function.
7777
7878
Args:
79+
----
7980
code: The source code as a string
8081
functions_to_optimize: List of FunctionToOptimize instances containing function details
8182
8283
Returns:
84+
-------
8385
The modified source code as a string
8486
8587
"""

codeflash/benchmarking/plugin/plugin.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,11 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
7474
"""Process the trace file and extract timing data for all functions.
7575
7676
Args:
77+
----
7778
trace_path: Path to the trace file
7879
7980
Returns:
81+
-------
8082
A nested dictionary where:
8183
- Outer keys are module_name.qualified_name (module.class.function)
8284
- Inner keys are of type BenchmarkKey
@@ -132,9 +134,11 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
132134
"""Extract total benchmark timings from trace files.
133135
134136
Args:
137+
----
135138
trace_path: Path to the trace file
136139
137140
Returns:
141+
-------
138142
A dictionary mapping where:
139143
- Keys are of type BenchmarkKey
140144
- Values are total benchmark timing in milliseconds (with overhead subtracted)

codeflash/benchmarking/replay_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,14 @@ def create_trace_replay_test_code(
5555
"""Create a replay test for functions based on trace data.
5656
5757
Args:
58+
----
5859
trace_file: Path to the SQLite database file
5960
functions_data: List of dictionaries with function info extracted from DB
6061
test_framework: 'pytest' or 'unittest'
6162
max_run_count: Maximum number of runs to include in the test
6263
6364
Returns:
65+
-------
6466
A string containing the test code
6567
6668
"""
@@ -218,12 +220,14 @@ def generate_replay_test(
218220
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.
219221
220222
Args:
223+
----
221224
trace_file_path: Path to the SQLite database file
222225
output_dir: Directory to write the generated tests (if None, only returns the code)
223226
test_framework: 'pytest' or 'unittest'
224227
max_run_count: Maximum number of runs to include per function
225228
226229
Returns:
230+
-------
227231
Dictionary mapping benchmark names to generated test code
228232
229233
"""

codeflash/benchmarking/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,13 @@ def process_benchmark_data(
8383
"""Process benchmark data and generate detailed benchmark information.
8484
8585
Args:
86+
----
8687
replay_performance_gain: The performance gain from replay
8788
fto_benchmark_timings: Function to optimize benchmark timings
8889
total_benchmark_timings: Total benchmark timings
8990
9091
Returns:
92+
-------
9193
ProcessedBenchmarkInfo containing processed benchmark details
9294
9395
"""

codeflash/cli_cmds/cmd_init.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def collect_setup_info() -> SetupInfo:
211211
# Discover test directory
212212
default_tests_subdir = "tests"
213213
create_for_me_option = f"okay, create a tests{os.pathsep} directory for me!"
214-
test_subdir_options = valid_subdirs
214+
test_subdir_options = [sub_dir for sub_dir in valid_subdirs if sub_dir != module_root]
215215
if "tests" not in valid_subdirs:
216216
test_subdir_options.append(create_for_me_option)
217217
custom_dir_option = "enter a custom directory…"
@@ -240,7 +240,16 @@ def collect_setup_info() -> SetupInfo:
240240
apologize_and_exit()
241241
else:
242242
tests_root = Path(curdir) / Path(cast("str", tests_root_answer))
243+
243244
tests_root = tests_root.relative_to(curdir)
245+
246+
resolved_module_root = (Path(curdir) / Path(module_root)).resolve()
247+
resolved_tests_root = (Path(curdir) / Path(tests_root)).resolve()
248+
if resolved_module_root == resolved_tests_root:
249+
logger.warning(
250+
"It looks like your tests root is the same as your module root. This is not recommended and can lead to unexpected behavior."
251+
)
252+
244253
ph("cli-tests-root-provided")
245254

246255
# Autodiscover test framework

codeflash/cli_cmds/logging_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
2727
],
2828
force=True,
2929
)
30-
logging.info("Verbose DEBUG logging enabled") # noqa: LOG015
30+
logging.info("Verbose DEBUG logging enabled")
3131
else:
32-
logging.info("Logging level set to INFO") # noqa: LOG015
32+
logging.info("Logging level set to INFO")
3333
console.rule()

codeflash/code_utils/checkpoint.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def add_function_to_checkpoint(
4747
"""Add a function to the checkpoint after it has been processed.
4848
4949
Args:
50+
----
5051
function_fully_qualified_name: The fully qualified name of the function
5152
status: Status of optimization (e.g., "optimized", "failed", "skipped")
5253
additional_info: Any additional information to store about the function
@@ -104,7 +105,8 @@ def cleanup(self) -> None:
104105
def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
105106
"""Get information about all processed functions, regardless of status.
106107
107-
Returns:
108+
Returns
109+
-------
108110
Dictionary mapping function names to their processing information
109111
110112
"""

codeflash/code_utils/compat.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,47 @@
11
import os
22
import sys
3+
import tempfile
34
from pathlib import Path
5+
from typing import TYPE_CHECKING
46

57
from platformdirs import user_config_dir
68

7-
# os-independent newline
8-
# important for any user-facing output or files we write
9-
# make sure to use this in f-strings e.g. f"some string{LF}"
10-
# you can use "[^f]\".*\{LF\}\" to find any lines in your code that use this without the f-string
11-
LF: str = os.linesep
9+
if TYPE_CHECKING:
10+
codeflash_temp_dir: Path
11+
codeflash_cache_dir: Path
12+
codeflash_cache_db: Path
1213

1314

14-
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
15+
class Compat:
16+
# os-independent newline
17+
LF: str = os.linesep
1518

16-
IS_POSIX = os.name != "nt"
19+
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
1720

21+
IS_POSIX: bool = os.name != "nt"
1822

19-
codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
23+
@property
24+
def codeflash_cache_dir(self) -> Path:
25+
return Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
2026

21-
codeflash_cache_db = codeflash_cache_dir / "codeflash_cache.db"
27+
@property
28+
def codeflash_temp_dir(self) -> Path:
29+
temp_dir = Path(tempfile.gettempdir()) / "codeflash"
30+
if not temp_dir.exists():
31+
temp_dir.mkdir(parents=True, exist_ok=True)
32+
return temp_dir
33+
34+
@property
35+
def codeflash_cache_db(self) -> Path:
36+
return self.codeflash_cache_dir / "codeflash_cache.db"
37+
38+
39+
_compat = Compat()
40+
41+
42+
codeflash_temp_dir = _compat.codeflash_temp_dir
43+
codeflash_cache_dir = _compat.codeflash_cache_dir
44+
codeflash_cache_db = _compat.codeflash_cache_db
45+
LF = _compat.LF
46+
SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE
47+
IS_POSIX = _compat.IS_POSIX
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import re
2+
3+
import libcst as cst
4+
5+
from codeflash.cli_cmds.console import logger
6+
from codeflash.code_utils.time_utils import format_time
7+
from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults
8+
9+
10+
def remove_functions_from_generated_tests(
11+
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
12+
) -> GeneratedTestsList:
13+
new_generated_tests = []
14+
for generated_test in generated_tests.generated_tests:
15+
for test_function in test_functions_to_remove:
16+
function_pattern = re.compile(
17+
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
18+
re.DOTALL,
19+
)
20+
21+
match = function_pattern.search(generated_test.generated_original_test_source)
22+
23+
if match is None or "@pytest.mark.parametrize" in match.group(0):
24+
continue
25+
26+
generated_test.generated_original_test_source = function_pattern.sub(
27+
"", generated_test.generated_original_test_source
28+
)
29+
30+
new_generated_tests.append(generated_test)
31+
32+
return GeneratedTestsList(generated_tests=new_generated_tests)
33+
34+
35+
def add_runtime_comments_to_generated_tests(
36+
generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
37+
) -> GeneratedTestsList:
38+
"""Add runtime performance comments to function calls in generated tests."""
39+
# Create dictionaries for fast lookup of runtime data
40+
original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
41+
optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()
42+
43+
class RuntimeCommentTransformer(cst.CSTTransformer):
44+
def __init__(self) -> None:
45+
self.in_test_function = False
46+
self.current_test_name: str | None = None
47+
48+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
49+
if node.name.value.startswith("test_"):
50+
self.in_test_function = True
51+
self.current_test_name = node.name.value
52+
else:
53+
self.in_test_function = False
54+
self.current_test_name = None
55+
56+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
57+
if original_node.name.value.startswith("test_"):
58+
self.in_test_function = False
59+
self.current_test_name = None
60+
return updated_node
61+
62+
def leave_SimpleStatementLine(
63+
self,
64+
original_node: cst.SimpleStatementLine, # noqa: ARG002
65+
updated_node: cst.SimpleStatementLine,
66+
) -> cst.SimpleStatementLine:
67+
if not self.in_test_function or not self.current_test_name:
68+
return updated_node
69+
70+
# Look for assignment statements that assign to codeflash_output
71+
# Handle both single statements and multiple statements on one line
72+
codeflash_assignment_found = False
73+
for stmt in updated_node.body:
74+
if isinstance(stmt, cst.Assign) and (
75+
len(stmt.targets) == 1
76+
and isinstance(stmt.targets[0].target, cst.Name)
77+
and stmt.targets[0].target.value == "codeflash_output"
78+
):
79+
codeflash_assignment_found = True
80+
break
81+
82+
if codeflash_assignment_found:
83+
# Find matching test cases by looking for this test function name in the test results
84+
matching_original_times = []
85+
matching_optimized_times = []
86+
87+
for invocation_id, runtimes in original_runtime_by_test.items():
88+
if invocation_id.test_function_name == self.current_test_name:
89+
matching_original_times.extend(runtimes)
90+
91+
for invocation_id, runtimes in optimized_runtime_by_test.items():
92+
if invocation_id.test_function_name == self.current_test_name:
93+
matching_optimized_times.extend(runtimes)
94+
95+
if matching_original_times and matching_optimized_times:
96+
original_time = min(matching_original_times)
97+
optimized_time = min(matching_optimized_times)
98+
99+
# Create the runtime comment
100+
comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"
101+
102+
# Add comment to the trailing whitespace
103+
new_trailing_whitespace = cst.TrailingWhitespace(
104+
whitespace=cst.SimpleWhitespace(" "),
105+
comment=cst.Comment(comment_text),
106+
newline=updated_node.trailing_whitespace.newline,
107+
)
108+
109+
return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)
110+
111+
return updated_node
112+
113+
# Process each generated test
114+
modified_tests = []
115+
for test in generated_tests.generated_tests:
116+
try:
117+
# Parse the test source code
118+
tree = cst.parse_module(test.generated_original_test_source)
119+
120+
# Transform the tree to add runtime comments
121+
transformer = RuntimeCommentTransformer()
122+
modified_tree = tree.visit(transformer)
123+
124+
# Convert back to source code
125+
modified_source = modified_tree.code
126+
127+
# Create a new GeneratedTests object with the modified source
128+
modified_test = GeneratedTests(
129+
generated_original_test_source=modified_source,
130+
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
131+
instrumented_perf_test_source=test.instrumented_perf_test_source,
132+
behavior_file_path=test.behavior_file_path,
133+
perf_file_path=test.perf_file_path,
134+
)
135+
modified_tests.append(modified_test)
136+
except Exception as e:
137+
# If parsing fails, keep the original test
138+
logger.debug(f"Failed to add runtime comments to test: {e}")
139+
modified_tests.append(test)
140+
141+
return GeneratedTestsList(generated_tests=modified_tests)

0 commit comments

Comments
 (0)