4 changes: 4 additions & 0 deletions codeflash/benchmarking/codeflash_trace.py
@@ -25,6 +25,7 @@ def setup(self, trace_path: str) -> None:
"""Set up the database connection for direct writing.

Args:
+----
trace_path: Path to the trace database file

"""
@@ -52,6 +53,7 @@ def write_function_timings(self) -> None:
"""Write function call data directly to the database.

Args:
+----
data: List of function call data tuples to write

"""
@@ -94,9 +96,11 @@ def __call__(self, func: Callable) -> Callable:
"""Use as a decorator to trace function execution.

Args:
+----
func: The function to be decorated

Returns:
+-------
The wrapped function

"""
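Note: the hunks in this file, and the matching docstring hunks in the files below, only add underline rows beneath the `Args:` and `Returns:` headers, presumably to satisfy a docstring linter's section-underline checks. A minimal sketch of the resulting docstring shape, using an invented function:

def example(trace_path: str) -> bool:
    """Illustrate the underlined section style these hunks introduce.

    Args:
    ----
        trace_path: Path to the trace database file

    Returns:
    -------
        True if setup succeeded

    """
    return True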
2 changes: 2 additions & 0 deletions codeflash/benchmarking/instrument_codeflash_trace.py
@@ -76,10 +76,12 @@ def add_codeflash_decorator_to_code(code: str, functions_to_optimize: list[Funct
"""Add codeflash_trace to a function.

Args:
+----
code: The source code as a string
functions_to_optimize: List of FunctionToOptimize instances containing function details

Returns:
+-------
The modified source code as a string

"""
4 changes: 4 additions & 0 deletions codeflash/benchmarking/plugin/plugin.py
@@ -74,9 +74,11 @@ def get_function_benchmark_timings(trace_path: Path) -> dict[str, dict[Benchmark
"""Process the trace file and extract timing data for all functions.

Args:
+----
trace_path: Path to the trace file

Returns:
+-------
A nested dictionary where:
- Outer keys are module_name.qualified_name (module.class.function)
- Inner keys are of type BenchmarkKey
@@ -132,9 +134,11 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
"""Extract total benchmark timings from trace files.

Args:
+----
trace_path: Path to the trace file

Returns:
+-------
A dictionary mapping where:
- Keys are of type BenchmarkKey
- Values are total benchmark timing in milliseconds (with overhead subtracted)
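A sketch of how these two readers could be called, with the return shapes described by the docstrings above (the trace path is invented, and module-level importability of both functions is assumed):

from pathlib import Path

trace_path = Path("codeflash.trace")  # hypothetical trace file

# Outer keys are "module.Class.function"; inner keys are BenchmarkKey instances.
per_function = get_function_benchmark_timings(trace_path)
for qualified_name, by_benchmark in per_function.items():
    for benchmark_key, function_time in by_benchmark.items():
        print(f"{qualified_name} in {benchmark_key}: {function_time}")

# Keys are BenchmarkKey instances; values are totals with tracing overhead subtracted.
totals = get_benchmark_timings(trace_path)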
4 changes: 4 additions & 0 deletions codeflash/benchmarking/replay_test.py
@@ -55,12 +55,14 @@ def create_trace_replay_test_code(
"""Create a replay test for functions based on trace data.

Args:
+----
trace_file: Path to the SQLite database file
functions_data: List of dictionaries with function info extracted from DB
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include in the test

Returns:
+-------
A string containing the test code

"""
@@ -218,12 +220,14 @@ def generate_replay_test(
"""Generate multiple replay tests from the traced function calls, grouped by benchmark.

Args:
+----
trace_file_path: Path to the SQLite database file
output_dir: Directory to write the generated tests (if None, only returns the code)
test_framework: 'pytest' or 'unittest'
max_run_count: Maximum number of runs to include per function

Returns:
+-------
Dictionary mapping benchmark names to generated test code

"""
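An illustrative call matching the documented parameters (paths and the run count are invented, not taken from the PR):

from pathlib import Path

tests_by_benchmark = generate_replay_test(
    trace_file_path=Path("codeflash.trace"),  # SQLite trace from the benchmark run
    output_dir=Path("tests/replay"),          # per the docstring, None returns code only
    test_framework="pytest",                  # or "unittest"
    max_run_count=100,
)
# Returns a dict mapping benchmark names to generated test code.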
2 changes: 2 additions & 0 deletions codeflash/benchmarking/utils.py
@@ -83,11 +83,13 @@ def process_benchmark_data(
"""Process benchmark data and generate detailed benchmark information.

Args:
+----
replay_performance_gain: The performance gain from replay
fto_benchmark_timings: Function to optimize benchmark timings
total_benchmark_timings: Total benchmark timings

Returns:
+-------
ProcessedBenchmarkInfo containing processed benchmark details

"""
4 changes: 2 additions & 2 deletions codeflash/cli_cmds/logging_config.py
@@ -27,7 +27,7 @@ def set_level(level: int, *, echo_setting: bool = True) -> None:
],
force=True,
)
logging.info("Verbose DEBUG logging enabled") # noqa: LOG015
logging.info("Verbose DEBUG logging enabled")
else:
logging.info("Logging level set to INFO") # noqa: LOG015
logging.info("Logging level set to INFO")
console.rule()
4 changes: 3 additions & 1 deletion codeflash/code_utils/checkpoint.py
@@ -47,6 +47,7 @@ def add_function_to_checkpoint(
"""Add a function to the checkpoint after it has been processed.

Args:
+----
function_fully_qualified_name: The fully qualified name of the function
status: Status of optimization (e.g., "optimized", "failed", "skipped")
additional_info: Any additional information to store about the function
@@ -104,7 +105,8 @@ def cleanup(self) -> None:
def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
"""Get information about all processed functions, regardless of status.

-Returns:
+Returns
+-------
Dictionary mapping function names to their processing information

"""
141 changes: 141 additions & 0 deletions codeflash/code_utils/edit_generated_tests.py
@@ -0,0 +1,141 @@
import re

import libcst as cst

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.time_utils import format_time
from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults


def remove_functions_from_generated_tests(
generated_tests: GeneratedTestsList, test_functions_to_remove: list[str]
) -> GeneratedTestsList:
new_generated_tests = []
for generated_test in generated_tests.generated_tests:
for test_function in test_functions_to_remove:
function_pattern = re.compile(
rf"(@pytest\.mark\.parametrize\(.*?\)\s*)?def\s+{re.escape(test_function)}\(.*?\):.*?(?=\ndef\s|$)",
re.DOTALL,
)

match = function_pattern.search(generated_test.generated_original_test_source)

if match is None or "@pytest.mark.parametrize" in match.group(0):
continue

generated_test.generated_original_test_source = function_pattern.sub(
"", generated_test.generated_original_test_source
)

new_generated_tests.append(generated_test)

return GeneratedTestsList(generated_tests=new_generated_tests)
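
A quick sketch of the intended behavior with invented test sources; the GeneratedTests fields mirror the construction further down in this file, and the field types are assumed:

from pathlib import Path

source = (
    "def test_add():\n"
    "    codeflash_output = add(1, 2)\n"
    "\n"
    "def test_keep():\n"
    "    assert True\n"
)
tests = GeneratedTestsList(
    generated_tests=[
        GeneratedTests(
            generated_original_test_source=source,
            instrumented_behavior_test_source="",
            instrumented_perf_test_source="",
            behavior_file_path=Path("tests/test_example__behavior.py"),  # assumed Path-typed
            perf_file_path=Path("tests/test_example__perf.py"),
        )
    ]
)
cleaned = remove_functions_from_generated_tests(tests, ["test_add"])
# "test_add" is stripped from the source; "test_keep" survives. A function carrying
# @pytest.mark.parametrize is deliberately skipped (kept), since its match is rejected above.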


def add_runtime_comments_to_generated_tests(
generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
) -> GeneratedTestsList:
"""Add runtime performance comments to function calls in generated tests."""
# Create dictionaries for fast lookup of runtime data
original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()

class RuntimeCommentTransformer(cst.CSTTransformer):
def __init__(self) -> None:
self.in_test_function = False
self.current_test_name: str | None = None

def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
if node.name.value.startswith("test_"):
self.in_test_function = True
self.current_test_name = node.name.value
else:
self.in_test_function = False
self.current_test_name = None

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
if original_node.name.value.startswith("test_"):
self.in_test_function = False
self.current_test_name = None
return updated_node

def leave_SimpleStatementLine(
self,
original_node: cst.SimpleStatementLine, # noqa: ARG002
updated_node: cst.SimpleStatementLine,
) -> cst.SimpleStatementLine:
if not self.in_test_function or not self.current_test_name:
return updated_node

# Look for assignment statements that assign to codeflash_output
# Handle both single statements and multiple statements on one line
codeflash_assignment_found = False
for stmt in updated_node.body:
if isinstance(stmt, cst.Assign) and (
len(stmt.targets) == 1
and isinstance(stmt.targets[0].target, cst.Name)
and stmt.targets[0].target.value == "codeflash_output"
):
codeflash_assignment_found = True
break

if codeflash_assignment_found:
# Find matching test cases by looking for this test function name in the test results
matching_original_times = []
matching_optimized_times = []

for invocation_id, runtimes in original_runtime_by_test.items():
if invocation_id.test_function_name == self.current_test_name:
matching_original_times.extend(runtimes)

for invocation_id, runtimes in optimized_runtime_by_test.items():
if invocation_id.test_function_name == self.current_test_name:
matching_optimized_times.extend(runtimes)

if matching_original_times and matching_optimized_times:
original_time = min(matching_original_times)
optimized_time = min(matching_optimized_times)

# Create the runtime comment
comment_text = f"# {format_time(original_time)} -> {format_time(optimized_time)}"

# Add comment to the trailing whitespace
new_trailing_whitespace = cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment(comment_text),
newline=updated_node.trailing_whitespace.newline,
)

return updated_node.with_changes(trailing_whitespace=new_trailing_whitespace)

return updated_node

# Process each generated test
modified_tests = []
for test in generated_tests.generated_tests:
try:
# Parse the test source code
tree = cst.parse_module(test.generated_original_test_source)

# Transform the tree to add runtime comments
transformer = RuntimeCommentTransformer()
modified_tree = tree.visit(transformer)

# Convert back to source code
modified_source = modified_tree.code

# Create a new GeneratedTests object with the modified source
modified_test = GeneratedTests(
generated_original_test_source=modified_source,
instrumented_behavior_test_source=test.instrumented_behavior_test_source,
instrumented_perf_test_source=test.instrumented_perf_test_source,
behavior_file_path=test.behavior_file_path,
perf_file_path=test.perf_file_path,
)
modified_tests.append(modified_test)
except Exception as e:
# If parsing fails, keep the original test
logger.debug(f"Failed to add runtime comments to test: {e}")
modified_tests.append(test)

return GeneratedTestsList(generated_tests=modified_tests)
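
The net effect on a generated test, sketched with invented timings (the comment uses a single leading space and format_time's output, per the transformer above):

# Before:
#     def test_sort():
#         codeflash_output = sorter(data)
#         assert codeflash_output == expected
# After, if the best original runtime was 1.2 ms and the best optimized runtime 340 us:
#     def test_sort():
#         codeflash_output = sorter(data) # 1.20ms -> 340μs
#         assert codeflash_output == expected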
3 changes: 3 additions & 0 deletions codeflash/code_utils/line_profile_utils.py
Expand Up @@ -24,6 +24,7 @@ def __init__(self, qualified_name: str, decorator_name: str) -> None:
"""Initialize the transformer.

Args:
----
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
decorator_name: The name of the decorator to add.

Expand Down Expand Up @@ -144,11 +145,13 @@ def add_decorator_to_qualified_function(module: cst.Module, qualified_name: str,
"""Add a decorator to a function with the exact qualified name in the source code.

Args:
----
module: The Python source code as a CST module.
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
decorator_name: The name of the decorator to add.

Returns:
-------
The modified CST module.

"""
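Hypothetical usage of the documented API (the source snippet and decorator name are invented; the qualified name is the docstring's own example):

import libcst as cst

source = '''
class MyClass:
    def nested_func(self):
        def target_func():
            return 1
        return target_func
'''
module = cst.parse_module(source)
modified = add_decorator_to_qualified_function(module, "MyClass.nested_func.target_func", "profile")
print(modified.code)  # target_func should now carry @profile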
28 changes: 0 additions & 28 deletions codeflash/code_utils/remove_generated_tests.py

This file was deleted.

45 changes: 45 additions & 0 deletions codeflash/code_utils/time_utils.py
@@ -49,3 +49,48 @@ def humanize_runtime(time_in_ns: int) -> str:
runtime_human = runtime_human_parts[0]

return f"{runtime_human} {units}"


def format_time(nanoseconds: int) -> str:
"""Format nanoseconds into a human-readable string with 3 significant digits when needed."""

def count_significant_digits(num: int) -> int:
"""Count significant digits in an integer."""
return len(str(abs(num)))

def format_with_precision(value: float, unit: str) -> str:
"""Format a value with 3 significant digits precision."""
if value >= 100:
return f"{value:.0f}{unit}"
if value >= 10:
return f"{value:.1f}{unit}"
return f"{value:.2f}{unit}"

result = ""
if nanoseconds < 1_000:
result = f"{nanoseconds}ns"
elif nanoseconds < 1_000_000:
# Convert to microseconds
microseconds_int = nanoseconds // 1_000
if count_significant_digits(microseconds_int) >= 3:
result = f"{microseconds_int}μs"
else:
microseconds_float = nanoseconds / 1_000
result = format_with_precision(microseconds_float, "μs")
elif nanoseconds < 1_000_000_000:
# Convert to milliseconds
milliseconds_int = nanoseconds // 1_000_000
if count_significant_digits(milliseconds_int) >= 3:
result = f"{milliseconds_int}ms"
else:
milliseconds_float = nanoseconds / 1_000_000
result = format_with_precision(milliseconds_float, "ms")
else:
# Convert to seconds
seconds_int = nanoseconds // 1_000_000_000
if count_significant_digits(seconds_int) >= 3:
result = f"{seconds_int}s"
else:
seconds_float = nanoseconds / 1_000_000_000
result = format_with_precision(seconds_float, "s")
return result
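
Spot checks of the branching above; each value is picked to exercise one branch and is easy to verify by hand:

assert format_time(950) == "950ns"            # below 1 us: raw nanosecond count
assert format_time(1_500) == "1.50μs"         # fewer than 3 integer digits -> 2 decimals
assert format_time(12_340) == "12.3μs"        # value >= 10 -> 1 decimal
assert format_time(123_456) == "123μs"        # integer part already has 3 digits
assert format_time(1_234_567) == "1.23ms"     # same rules in the millisecond branch
assert format_time(2_500_000_000) == "2.50s"  # and in the seconds branch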