4 changes: 4 additions & 0 deletions codeflash/benchmarking/codeflash_trace.py
@@ -25,6 +25,7 @@ def setup(self, trace_path: str) -> None:
"""Set up the database connection for direct writing.
Args:
----
trace_path: Path to the trace database file
"""
@@ -52,6 +53,7 @@ def write_function_timings(self) -> None:
"""Write function call data directly to the database.
Args:
----
data: List of function call data tuples to write
"""
@@ -94,9 +96,11 @@ def __call__(self, func: Callable) -> Callable:
"""Use as a decorator to trace function execution.
Args:
----
func: The function to be decorated
Returns:
-------
The wrapped function
"""
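For context, a minimal usage sketch of the decorator this class provides, assuming the module exposes a ready-made codeflash_trace instance (the name that instrument_codeflash_trace.py injects); the trace path and the sample function are hypothetical:

from codeflash.benchmarking.codeflash_trace import codeflash_trace  # assumed module-level instance

codeflash_trace.setup("codeflash_trace.sqlite")  # hypothetical trace database path

@codeflash_trace
def scale(x: int) -> int:
    # Sample workload; each call is timed and recorded by the decorator
    return x * 2

scale(21)
codeflash_trace.write_function_timings()  # flush recorded timings to the trace database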
2 changes: 2 additions & 0 deletions codeflash/benchmarking/instrument_codeflash_trace.py
@@ -76,10 +76,12 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
"""Add codeflash_trace to a function.
Args:
----
code: The source code as a string
functions_to_optimize: List of FunctionToOptimize instances containing function details
Returns:
-------
The modified source code as a string
"""
4 changes: 4 additions & 0 deletions codeflash/benchmarking/plugin/plugin.py
@@ -74,9 +74,11 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
"""Process the trace file and extract timing data for all functions.
Args:
----
trace_path: Path to the trace file
Returns:
-------
A nested dictionary where:
- Outer keys are module_name.qualified_name (module.class.function)
- Inner keys are of type BenchmarkKey
@@ -132,9 +134,11 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
"""Extract total benchmark timings from trace files.
Args:
----
trace_path: Path to the trace file
Returns:
-------
A dictionary mapping where:
- Keys are of type BenchmarkKey
- Values are total benchmark timing in milliseconds (with overhead subtracted)
4 changes: 4 additions & 0 deletions codeflash/benchmarking/replay_test.py
@@ -55,12 +55,14 @@ def create_trace_replay_test_code(
"""Create a replay test for functions based on trace data.

Args:
----
trace_file: Path to the SQLite database file
functions_data: List of dictionaries with function info extracted from DB
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include in the test

Returns:
-------
A string containing the test code

"""
@@ -218,12 +220,14 @@ def generate_replay_test(
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.

Args:
----
trace_file_path: Path to the SQLite database file
output_dir: Directory to write the generated tests (if None, only returns the code)
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include per function

Returns:
-------
Dictionary mapping benchmark names to generated test code

"""
2 changes: 2 additions & 0 deletions codeflash/benchmarking/utils.py
@@ -83,11 +83,13 @@ def process_benchmark_data(
"""Process benchmark data and generate detailed benchmark information.

Args:
----
replay_performance_gain: The performance gain from replay
fto_benchmark_timings: Function to optimize benchmark timings
total_benchmark_timings: Total benchmark timings

Returns:
-------
ProcessedBenchmarkInfo containing processed benchmark details

"""
4 changes: 2 additions & 2 deletions codeflash/cli_cmds/logging_config.py
@@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
],
force=True,
)
logging.info("Verbose DEBUG logging enabled") # noqa: LOG015
logging.info("Verbose DEBUG logging enabled")
else:
logging.info("Logging level set to INFO") # noqa: LOG015
logging.info("Logging level set to INFO")
console.rule()
4 changes: 3 additions & 1 deletion codeflash/code_utils/checkpoint.py
@@ -47,6 +47,7 @@ def add_function_to_checkpoint(
"""Add a function to the checkpoint after it has been processed.
Args:
----
function_fully_qualified_name: The fully qualified name of the function
status: Status of optimization (e.g., "optimized", "failed", "skipped")
additional_info: Any additional information to store about the function
@@ -104,7 +105,8 @@ def cleanup(self) -> None:
def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
"""Get information about all processed functions, regardless of status.
Returns:
Returns
-------
Dictionary mapping function names to their processing information
"""
141 changes: 141 additions & 0 deletions codeflash/code_utils/edit_generated_tests.py
@@ -0,0 +1,141 @@
import re

import libcst as cst

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_time
from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults


def remove_functions_from_generated_tests(
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
new_generated_tests = []
for generated_test in generated_tests.generated_tests:
for test_function in test_functions_to_remove:
function_pattern = re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
re.DOTALL,
)

match = function_pattern.search(generated_test.generated_original_test_source)

if match is None or "@pytest.mark.parametrize" in match.group(0):
continue

generated_test.generated_original_test_source = function_pattern.sub(
"", generated_test.generated_original_test_source
)

new_generated_tests.append(generated_test)

return GeneratedTestsList(generated_tests=new_generated_tests)


def add_runtime_comments_to_generated_tests(
generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
) -> GeneratedTestsList:
"""Add runtime performance comments to function calls in generated tests."""
# Create dictionaries for fast lookup of runtime data
original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()

class RuntimeCommentTransformer(cst.CSTTransformer):
def __init__(self) -> None:
self.in_test_function = False
self.current_test_name: str | None = None

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
if node.name.value.startswith("test_"):
self.in_test_function = True
self.current_test_name = node.name.value
else:
self.in_test_function = False
self.current_test_name = None

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
if original_node.name.value.startswith("test_"):
self.in_test_function = False
self.current_test_name = None
return updated_node

def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine, # noqa: ARG002
updated_node: cst.SimpleStatementLine,
) -> cst.SimpleStatementLine:
if not self.in_test_function or not self.current_test_name:
return updated_node

# Look for assignment statements that assign to codeflash_output
# Handle both single statements and multiple statements on one line
codeflash_assignment_found = False
for stmt in updated_node.body:
if isinstance(stmt, cst.Assign) and (
len(stmt.targets) == 1
and isinstance(stmt.targets[0].target, cst.Name)
and stmt.targets[0].target.value == "codeflash_output"
):
codeflash_assignment_found = True
break

if codeflash_assignment_found:
# Find matching test cases by looking for this test function name in the test results
matching_original_times = []
matching_optimized_times = []

for invocation_id, runtimes in original_runtime_by_test.items():
if invocation_id.test_function_name == self.current_test_name:
matching_original_times.extend(runtimes)

for invocation_id, runtimes in optimized_runtime_by_test.items():
if invocation_id.test_function_name == self.current_test_name:
matching_optimized_times.extend(runtimes)

if matching_original_times and matching_optimized_times:
original_time = min(matching_original_times)
optimized_time = min(matching_optimized_times)

# Create the runtime comment
comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"

# Add comment to the trailing whitespace
new_trailing_whitespace = cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment(comment_text),
newline=updated_node.trailing_whitespace.newline,
)

return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)

return updated_node

# Process each generated test
modified_tests = []
for test in generated_tests.generated_tests:
try:
# Parse the test source code
tree = cst.parse_module(test.generated_original_test_source)

# Transform the tree to add runtime comments
transformer = RuntimeCommentTransformer()
modified_tree = tree.visit(transformer)

# Convert back to source code
modified_source = modified_tree.code

# Create a new GeneratedTests object with the modified source
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
instrumented_perf_test_source=test.instrumented_perf_test_source,
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
modified_tests.append(modified_test)
except Exception as e:
# If parsing fails, keep the original test
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)

return GeneratedTestsList(generated_tests=modified_tests)
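A minimal sketch of driving the new helper, assuming GeneratedTests accepts the keyword fields used above and that the two file-path fields take Path objects (the test source and paths below are made up):

from pathlib import Path

from codeflash.code_utils.edit_generated_tests import remove_functions_from_generated_tests
from codeflash.models.models import GeneratedTests, GeneratedTestsList

source = (
    "def test_keep():\n"
    "    codeflash_output = compute()\n"
    "    assert codeflash_output == 4\n"
    "\n"
    "def test_drop():\n"
    "    assert compute() == 4\n"
)
tests = GeneratedTestsList(
    generated_tests=[
        GeneratedTests(
            generated_original_test_source=source,
            instrumented_behavior_test_source=source,
            instrumented_perf_test_source=source,
            behavior_file_path=Path("tests/test_compute_behavior.py"),  # hypothetical path
            perf_file_path=Path("tests/test_compute_perf.py"),  # hypothetical path
        )
    ]
)
trimmed = remove_functions_from_generated_tests(tests, ["test_drop"])
# test_drop is stripped; test_keep and any parametrized tests are left untouched
print(trimmed.generated_tests[0].generated_original_test_source)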
3 changes: 3 additions & 0 deletions codeflash/code_utils/line_profile_utils.py
@@ -24,6 +24,7 @@ def __init__(self, qualified_name: str, decorator_name: str) -> None:
"""Initialize the transformer.
Args:
----
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
decorator_name: The name of the decorator to add.
@@ -144,11 +145,13 @@ def add_decorator_to_qualified_function(module: cst.Module, qualified_name: str,
"""Add a decorator to a function with the exact qualified name in the source code.
Args:
----
module: The Python source code as a CST module.
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
decorator_name: The name of the decorator to add.
Returns:
-------
The modified CST module.
"""
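A small sketch of how this helper can be invoked, assuming the truncated signature above ends with a decorator_name: str parameter as the docstring indicates:

import libcst as cst

from codeflash.code_utils.line_profile_utils import add_decorator_to_qualified_function

module = cst.parse_module(
    "class MyClass:\n"
    "    def target_func(self):\n"
    "        return 1\n"
)
# Adds the decorator only to the exact qualified name, not to same-named functions elsewhere
modified = add_decorator_to_qualified_function(module, "MyClass.target_func", "profile")
print(modified.code)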
28 changes: 0 additions & 28 deletions codeflash/code_utils/remove_generated_tests.py

This file was deleted.

37 changes: 37 additions & 0 deletions codeflash/code_utils/time_utils.py
@@ -49,3 +49,40 @@ def humanize_runtime(time_in_ns: int) -> str:
runtime_human = runtime_human_parts[0]

return f"{runtime_human} {units}"


def format_time(nanoseconds: int) -> str:
"""Format nanoseconds into a human-readable string with 3 significant digits when needed."""
# Inlined significant digit check: >= 3 digits if value >= 100
if nanoseconds < 1_000:
return f"{nanoseconds}ns"
if nanoseconds < 1_000_000:
microseconds_int = nanoseconds // 1_000
if microseconds_int >= 100:
return f"{microseconds_int}μs"
microseconds = nanoseconds / 1_000
# Format with precision: 3 significant digits
if microseconds >= 100:
return f"{microseconds:.0f}μs"
if microseconds >= 10:
return f"{microseconds:.1f}μs"
return f"{microseconds:.2f}μs"
if nanoseconds < 1_000_000_000:
milliseconds_int = nanoseconds // 1_000_000
if milliseconds_int >= 100:
return f"{milliseconds_int}ms"
milliseconds = nanoseconds / 1_000_000
if milliseconds >= 100:
return f"{milliseconds:.0f}ms"
if milliseconds >= 10:
return f"{milliseconds:.1f}ms"
return f"{milliseconds:.2f}ms"
seconds_int = nanoseconds // 1_000_000_000
if seconds_int >= 100:
return f"{seconds_int}s"
seconds = nanoseconds / 1_000_000_000
if seconds >= 100:
return f"{seconds:.0f}s"
if seconds >= 10:
return f"{seconds:.1f}s"
return f"{seconds:.2f}s"