Skip to content
16 changes: 9 additions & 7 deletions codeflash/code_utils/edit_generated_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import os
import re
from pathlib import Path
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

import libcst as cst
from libcst import MetadataWrapper
Expand Down Expand Up @@ -149,18 +149,19 @@ def leave_SimpleStatementSuite(
return updated_node


def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]:
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
unique_inv_ids: dict[str, int] = {}
for inv_id, runtimes in inv_id_runtimes.items():
test_qualified_name = (
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
if inv_id.test_class_name
else inv_id.test_function_name
)
abs_path = str(Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve().with_suffix(""))
if "__unit_test_" not in abs_path:
abs_path = tests_project_rootdir / Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py")
abs_path_str = str(abs_path.resolve())
if "__unit_test_" not in abs_path_str or not test_qualified_name:
continue
key = test_qualified_name + "#" + abs_path # type: ignore[operator]
key = test_qualified_name + "#" + abs_path_str
parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr]
cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr]
match_key = key + "#" + cur_invid
Expand All @@ -174,10 +175,11 @@ def add_runtime_comments_to_generated_tests(
generated_tests: GeneratedTestsList,
original_runtimes: dict[InvocationId, list[int]],
optimized_runtimes: dict[InvocationId, list[int]],
tests_project_rootdir: Optional[Path] = None,
) -> GeneratedTestsList:
"""Add runtime performance comments to function calls in generated tests."""
original_runtimes_dict = unique_inv_id(original_runtimes)
optimized_runtimes_dict = unique_inv_id(optimized_runtimes)
original_runtimes_dict = unique_inv_id(original_runtimes, tests_project_rootdir or Path())
optimized_runtimes_dict = unique_inv_id(optimized_runtimes, tests_project_rootdir or Path())
# Process each generated test
modified_tests = []
for test in generated_tests.generated_tests:
Expand Down
10 changes: 4 additions & 6 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,21 +338,19 @@ def initialize_function_optimization(
) -> dict[str, str]:
document_uri = params.textDocument.uri
document = server.workspace.get_text_document(document_uri)
file_path = Path(document.path)

server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")

if server.optimizer is None:
_initialize_optimizer_if_api_key_is_valid(server)

server.optimizer.worktree_mode()

original_args, _ = server.optimizer.original_args_and_test_cfg

server.optimizer.args.file = file_path
server.optimizer.args.function = params.functionName
original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
server.optimizer.args.previous_checkpoint_functions = False

server.optimizer.worktree_mode()

server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)
Expand Down
2 changes: 1 addition & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,7 +1376,7 @@ def process_review(
)

generated_tests = add_runtime_comments_to_generated_tests(
generated_tests, original_runtime_by_test, optimized_runtime_by_test
generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir
)

generated_tests_str = "\n#------------------------------------------------\n".join(
Expand Down
Loading