Skip to content

Commit cdbe3e0

Browse files
authored
Merge branch 'main' into opt-impact-aseem
2 parents 2e8414f + 443cb4d commit cdbe3e0

File tree

8 files changed

+126
-66
lines changed

8 files changed

+126
-66
lines changed

codeflash/api/aiservice.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pydantic.json import pydantic_encoder
1212

1313
from codeflash.cli_cmds.console import console, logger
14+
from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE, N_CANDIDATES_LP_EFFECTIVE
1415
from codeflash.code_utils.env_utils import get_codeflash_api_key
1516
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
1617
from codeflash.code_utils.time_utils import humanize_runtime
@@ -132,6 +133,7 @@ def optimize_python_code( # noqa: D417
132133
"current_username": get_last_commit_author_if_pr_exists(None),
133134
"repo_owner": git_repo_owner,
134135
"repo_name": git_repo_name,
136+
"n_candidates": N_CANDIDATES_EFFECTIVE,
135137
}
136138

137139
logger.info("!lsp|Generating optimized candidates…")
@@ -193,6 +195,7 @@ def optimize_python_code_line_profiler( # noqa: D417
193195
"experiment_metadata": experiment_metadata,
194196
"codeflash_version": codeflash_version,
195197
"lsp_mode": is_LSP_enabled(),
198+
"n_candidates_lp": N_CANDIDATES_LP_EFFECTIVE,
196199
}
197200

198201
console.rule()

codeflash/code_utils/config_consts.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,25 @@
1111
MIN_TESTCASE_PASSED_THRESHOLD = 6
1212
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
1313
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
14+
N_CANDIDATES_LP = 6
15+
16+
# LSP-specific
17+
N_CANDIDATES_LSP = 3
18+
N_TESTS_TO_GENERATE_LSP = 2
19+
TOTAL_LOOPING_TIME_LSP = 10.0 # Kept same timing for LSP mode to avoid in increase in performance reporting
20+
N_CANDIDATES_LP_LSP = 3
21+
22+
MAX_N_CANDIDATES = 5
23+
MAX_N_CANDIDATES_LP = 6
24+
25+
try:
26+
from codeflash.lsp.helpers import is_LSP_enabled
27+
28+
_IS_LSP_ENABLED = is_LSP_enabled()
29+
except ImportError:
30+
_IS_LSP_ENABLED = False
31+
32+
N_CANDIDATES_EFFECTIVE = min(N_CANDIDATES_LSP if _IS_LSP_ENABLED else N_CANDIDATES, MAX_N_CANDIDATES)
33+
N_CANDIDATES_LP_EFFECTIVE = min(N_CANDIDATES_LP_LSP if _IS_LSP_ENABLED else N_CANDIDATES_LP, MAX_N_CANDIDATES_LP)
34+
N_TESTS_TO_GENERATE_EFFECTIVE = N_TESTS_TO_GENERATE_LSP if _IS_LSP_ENABLED else N_TESTS_TO_GENERATE
35+
TOTAL_LOOPING_TIME_EFFECTIVE = TOTAL_LOOPING_TIME_LSP if _IS_LSP_ENABLED else TOTAL_LOOPING_TIME

codeflash/code_utils/git_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from unidiff import PatchSet
1717

1818
from codeflash.cli_cmds.console import logger
19-
from codeflash.code_utils.config_consts import N_CANDIDATES
19+
from codeflash.code_utils.config_consts import N_CANDIDATES_EFFECTIVE
2020

2121
if TYPE_CHECKING:
2222
from git import Repo
@@ -164,7 +164,7 @@ def create_git_worktrees(
164164
) -> tuple[Path | None, list[Path]]:
165165
if git_root and worktree_root_dir:
166166
worktree_root = Path(tempfile.mkdtemp(dir=worktree_root_dir))
167-
worktrees = [Path(tempfile.mkdtemp(dir=worktree_root)) for _ in range(N_CANDIDATES + 1)]
167+
worktrees = [Path(tempfile.mkdtemp(dir=worktree_root)) for _ in range(N_CANDIDATES_EFFECTIVE + 1)]
168168
for worktree in worktrees:
169169
subprocess.run(["git", "worktree", "add", "-d", worktree], cwd=module_root, check=True)
170170
else:

codeflash/lsp/beta.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
get_functions_within_git_diff,
2525
)
2626
from codeflash.either import is_successful
27-
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
27+
from codeflash.lsp.server import CodeflashLanguageServer
2828

2929
if TYPE_CHECKING:
3030
from argparse import Namespace
@@ -50,6 +50,13 @@ class ProvideApiKeyParams:
5050
api_key: str
5151

5252

53+
@dataclass
54+
class ValidateProjectParams:
55+
root_path_abs: str
56+
config_file: Optional[str] = None
57+
skip_validation: bool = False
58+
59+
5360
@dataclass
5461
class OnPatchAppliedParams:
5562
patch_id: str
@@ -60,7 +67,8 @@ class OptimizableFunctionsInCommitParams:
6067
commit_hash: str
6168

6269

63-
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
70+
# server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
71+
server = CodeflashLanguageServer("codeflash-language-server", "v1.0")
6472

6573

6674
@server.feature("getOptimizableFunctionsInCurrentDiff")
@@ -160,17 +168,60 @@ def initialize_function_optimization(
160168
return {"functionName": params.functionName, "status": "success"}
161169

162170

163-
@server.feature("validateProject")
164-
def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizationParams) -> dict[str, str]:
171+
def _find_pyproject_toml(workspace_path: str) -> Path | None:
172+
workspace_path_obj = Path(workspace_path)
173+
max_depth = 2
174+
base_depth = len(workspace_path_obj.parts)
175+
176+
for root, dirs, files in os.walk(workspace_path_obj):
177+
depth = len(Path(root).parts) - base_depth
178+
if depth > max_depth:
179+
# stop going deeper into this branch
180+
dirs.clear()
181+
continue
182+
183+
if "pyproject.toml" in files:
184+
file_path = Path(root) / "pyproject.toml"
185+
with file_path.open("r", encoding="utf-8", errors="ignore") as f:
186+
for line in f:
187+
if line.strip() == "[tool.codeflash]":
188+
return file_path.resolve()
189+
return None
190+
191+
192+
# should be called the first thing to initialize and validate the project
193+
@server.feature("initProject")
194+
def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) -> dict[str, str]:
165195
from codeflash.cli_cmds.cmd_init import is_valid_pyproject_toml
166196

197+
pyproject_toml_path: Path | None = getattr(params, "config_file", None)
198+
199+
if server.args is None:
200+
if pyproject_toml_path is not None:
201+
# if there is a config file provided use it
202+
server.prepare_optimizer_arguments(pyproject_toml_path)
203+
else:
204+
# otherwise look for it
205+
pyproject_toml_path = _find_pyproject_toml(params.root_path_abs)
206+
server.show_message_log(f"Found pyproject.toml at: {pyproject_toml_path}", "Info")
207+
if pyproject_toml_path:
208+
server.prepare_optimizer_arguments(pyproject_toml_path)
209+
else:
210+
return {
211+
"status": "error",
212+
"message": "No pyproject.toml found in workspace.",
213+
} # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
214+
215+
if getattr(params, "skip_validation", False):
216+
return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path}
217+
167218
server.show_message_log("Validating project...", "Info")
168-
config = is_valid_pyproject_toml(server.args.config_file)
219+
config = is_valid_pyproject_toml(pyproject_toml_path)
169220
if config is None:
170221
server.show_message_log("pyproject.toml is not valid", "Error")
171222
return {
172223
"status": "error",
173-
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions
224+
"message": "pyproject.toml is not valid", # keep the error message the same, the extension is matching "pyproject.toml" in the error message to show the codeflash init instructions,
174225
}
175226

176227
args = process_args(server)
@@ -183,7 +234,7 @@ def validate_project(server: CodeflashLanguageServer, _params: FunctionOptimizat
183234
except Exception:
184235
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
185236

186-
return {"status": "success", "moduleRoot": args.module_root}
237+
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path}
187238

188239

189240
def _initialize_optimizer_if_api_key_is_valid(
@@ -328,7 +379,7 @@ def perform_function_optimization( # noqa: PLR0911
328379

329380
devnull_writer = open(os.devnull, "w") # noqa
330381
with contextlib.redirect_stdout(devnull_writer):
331-
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
382+
function_to_tests, _num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
332383
function_optimizer.function_to_tests = function_to_tests
333384

334385
test_setup_result = function_optimizer.generate_and_instrument_tests(

codeflash/lsp/lsp_logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def setup_logging() -> logging.Logger:
124124
logger = logging.getLogger()
125125
logger.handlers.clear()
126126

127-
# Set up stderr handler for VS Code output channel with [LSP-Server] prefix
127+
# Set up stderr handler for VS Code output channel
128128
handler = logging.StreamHandler(sys.stderr)
129129
handler.setLevel(logging.DEBUG)
130130

codeflash/lsp/server.py

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,20 @@
11
from __future__ import annotations
22

3-
from pathlib import Path
43
from typing import TYPE_CHECKING, Any
54

6-
from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType
7-
from pygls import uris
8-
from pygls.protocol import LanguageServerProtocol, lsp_method
5+
from lsprotocol.types import LogMessageParams, MessageType
6+
from pygls.protocol import LanguageServerProtocol
97
from pygls.server import LanguageServer
108

119
if TYPE_CHECKING:
12-
from lsprotocol.types import InitializeParams, InitializeResult
10+
from pathlib import Path
1311

1412
from codeflash.optimization.optimizer import Optimizer
1513

1614

1715
class CodeflashLanguageServerProtocol(LanguageServerProtocol):
1816
_server: CodeflashLanguageServer
1917

20-
@lsp_method(INITIALIZE)
21-
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
22-
server = self._server
23-
initialize_result: InitializeResult = super().lsp_initialize(params)
24-
25-
workspace_uri = params.root_uri
26-
if workspace_uri:
27-
workspace_path = uris.to_fs_path(workspace_uri)
28-
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
29-
if pyproject_toml_path:
30-
server.prepare_optimizer_arguments(pyproject_toml_path)
31-
else:
32-
server.show_message("No pyproject.toml found in workspace.")
33-
else:
34-
server.show_message("No workspace URI provided.")
35-
36-
return initialize_result
37-
38-
def _find_pyproject_toml(self, workspace_path: str) -> Path | None:
39-
workspace_path_obj = Path(workspace_path)
40-
for file_path in workspace_path_obj.rglob("pyproject.toml"):
41-
return file_path.resolve()
42-
return None
43-
4418

4519
class CodeflashLanguageServer(LanguageServer):
4620
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401

0 commit comments

Comments
 (0)