Skip to content

Commit 29fbeec

Browse files
authored
Merge branch 'main' into feat/hypothesis-tests
2 parents db21c92 + 44e46b4 commit 29fbeec

File tree

10 files changed

+48
-60
lines changed

10 files changed

+48
-60
lines changed

codeflash/api/cfapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
9191
9292
:return: The userid or None if the request fails.
9393
"""
94-
if not ensure_codeflash_api_key():
94+
if not api_key and not ensure_codeflash_api_key():
9595
return None
9696

9797
response = make_cfapi_request(

codeflash/code_utils/edit_generated_tests.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import re
66
from pathlib import Path
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, Optional
88

99
import libcst as cst
1010
from libcst import MetadataWrapper
@@ -149,18 +149,19 @@ def leave_SimpleStatementSuite(
149149
return updated_node
150150

151151

152-
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]:
152+
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]], tests_project_rootdir: Path) -> dict[str, int]:
153153
unique_inv_ids: dict[str, int] = {}
154154
for inv_id, runtimes in inv_id_runtimes.items():
155155
test_qualified_name = (
156156
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
157157
if inv_id.test_class_name
158158
else inv_id.test_function_name
159159
)
160-
abs_path = str(Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve().with_suffix(""))
161-
if "__unit_test_" not in abs_path:
160+
abs_path = tests_project_rootdir / Path(inv_id.test_module_path.replace(".", os.sep)).with_suffix(".py")
161+
abs_path_str = str(abs_path.resolve().with_suffix(""))
162+
if "__unit_test_" not in abs_path_str or not test_qualified_name:
162163
continue
163-
key = test_qualified_name + "#" + abs_path # type: ignore[operator]
164+
key = test_qualified_name + "#" + abs_path_str
164165
parts = inv_id.iteration_id.split("_").__len__() # type: ignore[union-attr]
165166
cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr]
166167
match_key = key + "#" + cur_invid
@@ -174,10 +175,11 @@ def add_runtime_comments_to_generated_tests(
174175
generated_tests: GeneratedTestsList,
175176
original_runtimes: dict[InvocationId, list[int]],
176177
optimized_runtimes: dict[InvocationId, list[int]],
178+
tests_project_rootdir: Optional[Path] = None,
177179
) -> GeneratedTestsList:
178180
"""Add runtime performance comments to function calls in generated tests."""
179-
original_runtimes_dict = unique_inv_id(original_runtimes)
180-
optimized_runtimes_dict = unique_inv_id(optimized_runtimes)
181+
original_runtimes_dict = unique_inv_id(original_runtimes, tests_project_rootdir or Path())
182+
optimized_runtimes_dict = unique_inv_id(optimized_runtimes, tests_project_rootdir or Path())
181183
# Process each generated test
182184
modified_tests = []
183185
for test in generated_tests.generated_tests:

codeflash/code_utils/formatter.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,9 @@ def apply_formatter_cmds(
7676
logger.error(f"Failed to format code with {' '.join(formatter_cmd_list)}")
7777
except FileNotFoundError as e:
7878
from rich.panel import Panel
79-
from rich.text import Text
8079

81-
panel = Panel(
82-
Text.from_markup(f"⚠️ Formatter command not found: {' '.join(formatter_cmd_list)}", style="bold red"),
83-
expand=False,
84-
)
80+
command_str = " ".join(str(part) for part in formatter_cmd_list)
81+
panel = Panel(f"⚠️ Formatter command not found: {command_str}", expand=False, border_style="yellow")
8582
console.print(panel)
8683
if exit_on_failure:
8784
raise e from None

codeflash/code_utils/git_worktree_utils.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,18 @@
33
import subprocess
44
import tempfile
55
import time
6-
from functools import lru_cache
76
from pathlib import Path
8-
from typing import TYPE_CHECKING, Optional
7+
from typing import Optional
98

109
import git
1110

1211
from codeflash.cli_cmds.console import logger
1312
from codeflash.code_utils.compat import codeflash_cache_dir
1413
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
1514

16-
if TYPE_CHECKING:
17-
from git import Repo
18-
19-
2015
worktree_dirs = codeflash_cache_dir / "worktrees"
2116
patches_dir = codeflash_cache_dir / "patches"
2217

23-
if TYPE_CHECKING:
24-
from git import Repo
25-
26-
27-
@lru_cache(maxsize=1)
28-
def get_git_project_id() -> str:
29-
"""Return the first commit sha of the repo."""
30-
repo: Repo = git.Repo(search_parent_directories=True)
31-
root_commits = list(repo.iter_commits(rev="HEAD", max_parents=0))
32-
return root_commits[0].hexsha
33-
3418

3519
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
3620
repository = git.Repo(worktree_dir, search_parent_directories=True)
@@ -96,12 +80,6 @@ def remove_worktree(worktree_dir: Path) -> None:
9680
logger.exception(f"Failed to remove worktree: {worktree_dir}")
9781

9882

99-
@lru_cache(maxsize=1)
100-
def get_patches_dir_for_project() -> Path:
101-
project_id = get_git_project_id() or ""
102-
return Path(patches_dir / project_id)
103-
104-
10583
def create_diff_patch_from_worktree(
10684
worktree_dir: Path, files: list[str], fto_name: Optional[str] = None
10785
) -> Optional[Path]:
@@ -115,10 +93,8 @@ def create_diff_patch_from_worktree(
11593
if not uni_diff_text.endswith("\n"):
11694
uni_diff_text += "\n"
11795

118-
project_patches_dir = get_patches_dir_for_project()
119-
project_patches_dir.mkdir(parents=True, exist_ok=True)
120-
121-
patch_path = project_patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
96+
patches_dir.mkdir(parents=True, exist_ok=True)
97+
patch_path = Path(patches_dir / f"{worktree_dir.name}.{fto_name}.patch")
12298
with patch_path.open("w", encoding="utf8") as f:
12399
f.write(uni_diff_text)
124100

codeflash/lsp/beta.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,15 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
263263
def _initialize_optimizer_if_api_key_is_valid(
264264
server: CodeflashLanguageServer, api_key: Optional[str] = None
265265
) -> dict[str, str]:
266+
key_check_result = _check_api_key_validity(api_key)
267+
if key_check_result.get("status") != "success":
268+
return key_check_result
269+
270+
_initialize_optimizer(server)
271+
return key_check_result
272+
273+
274+
def _check_api_key_validity(api_key: Optional[str]) -> dict[str, str]:
266275
user_id = get_user_id(api_key=api_key)
267276
if user_id is None:
268277
return {"status": "error", "message": "api key not found or invalid"}
@@ -271,11 +280,15 @@ def _initialize_optimizer_if_api_key_is_valid(
271280
error_msg = user_id[7:]
272281
return {"status": "error", "message": error_msg}
273282

283+
return {"status": "success", "user_id": user_id}
284+
285+
286+
def _initialize_optimizer(server: CodeflashLanguageServer) -> None:
274287
from codeflash.optimization.optimizer import Optimizer
275288

276289
new_args = process_args(server)
277-
server.optimizer = Optimizer(new_args)
278-
return {"status": "success", "user_id": user_id}
290+
if not server.optimizer:
291+
server.optimizer = Optimizer(new_args)
279292

280293

281294
def process_args(server: CodeflashLanguageServer) -> Namespace:
@@ -302,16 +315,16 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
302315
if not api_key.startswith("cf-"):
303316
return {"status": "error", "message": "Api key is not valid"}
304317

305-
# clear cache to ensure the new api key is used
318+
# # clear cache to ensure the new api key is used
306319
get_codeflash_api_key.cache_clear()
307320
get_user_id.cache_clear()
308-
309-
init_result = _initialize_optimizer_if_api_key_is_valid(server, api_key)
310-
if init_result["status"] == "error":
311-
return {"status": "error", "message": "Api key is not valid"}
312-
313-
user_id = init_result["user_id"]
321+
key_check_result = _check_api_key_validity(api_key)
322+
if key_check_result.get("status") != "success":
323+
return key_check_result
324+
user_id = key_check_result["user_id"]
314325
result = save_api_key_to_rc(api_key)
326+
# initialize optimizer with the new api key
327+
_initialize_optimizer(server)
315328
if not is_successful(result):
316329
return {"status": "error", "message": result.failure()}
317330
return {"status": "success", "message": "Api key saved successfully", "user_id": user_id} # noqa: TRY300
@@ -325,21 +338,19 @@ def initialize_function_optimization(
325338
) -> dict[str, str]:
326339
document_uri = params.textDocument.uri
327340
document = server.workspace.get_text_document(document_uri)
341+
file_path = Path(document.path)
328342

329343
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {document_uri}", "Info")
330344

331345
if server.optimizer is None:
332346
_initialize_optimizer_if_api_key_is_valid(server)
333347

334-
server.optimizer.worktree_mode()
335-
336-
original_args, _ = server.optimizer.original_args_and_test_cfg
337-
348+
server.optimizer.args.file = file_path
338349
server.optimizer.args.function = params.functionName
339-
original_relative_file_path = Path(document.path).relative_to(original_args.project_root)
340-
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
341350
server.optimizer.args.previous_checkpoint_functions = False
342351

352+
server.optimizer.worktree_mode()
353+
343354
server.show_message_log(
344355
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
345356
)

codeflash/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
1010
from codeflash.cli_cmds.cmd_init import CODEFLASH_LOGO, ask_run_end_to_end_test
1111
from codeflash.cli_cmds.console import paneled_text
12+
from codeflash.code_utils import env_utils
1213
from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions
1314
from codeflash.code_utils.config_parser import parse_config_file
1415
from codeflash.code_utils.version_check import check_for_newer_minor_version
@@ -39,6 +40,8 @@ def main() -> None:
3940
ask_run_end_to_end_test(args)
4041
else:
4142
args = process_pyproject_config(args)
43+
if not env_utils.check_formatter_installed(args.formatter_cmds):
44+
return
4245
args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args)
4346
init_sentry(not args.disable_telemetry, exclude_errors=True)
4447
posthog_cf.initialize_posthog(not args.disable_telemetry)

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1429,7 +1429,7 @@ def process_review(
14291429
)
14301430

14311431
generated_tests = add_runtime_comments_to_generated_tests(
1432-
generated_tests, original_runtime_by_test, optimized_runtime_by_test
1432+
generated_tests, original_runtime_by_test, optimized_runtime_by_test, self.test_cfg.tests_project_rootdir
14331433
)
14341434

14351435
generated_tests_str = "\n#------------------------------------------------\n".join(

codeflash/optimization/optimizer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,6 @@ def run(self) -> None:
261261
console.rule()
262262
if not env_utils.ensure_codeflash_api_key():
263263
return
264-
if not env_utils.check_formatter_installed(self.args.formatter_cmds):
265-
return
266264
if self.args.no_draft and is_pr_draft():
267265
logger.warning("PR is in draft mode, skipping optimization")
268266
return

docs/crisp.js

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/docs.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
},
1010
"favicon": "/favicon.ico",
1111
"integrations": {
12-
"intercom": {
13-
"appId": "ljxo1nzr"
12+
"posthog": {
13+
"apiKey": "phc_aUO790jHd7z1SXwsYCz8dRApxueplZlZWeDSpKc5hol"
1414
}
1515
},
1616
"navigation": {

0 commit comments

Comments
 (0)