Skip to content

Commit 6fc8926

Browse files
Merge branch 'main' of github.com:codeflash-ai/codeflash into feat/markdown-read-writable-context
2 parents 684661e + 8fdc591 commit 6fc8926

File tree

14 files changed

+110
-28
lines changed

14 files changed

+110
-28
lines changed

.pre-commit-config.yaml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
repos:
2-
- repo: https://github.com/astral-sh/ruff-pre-commit
3-
rev: "v0.11.0"
4-
hooks:
5-
- id: ruff
6-
args: [--fix, --exit-non-zero-on-fix, --config=pyproject.toml]
7-
- id: ruff-format
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
rev: v0.12.7
4+
hooks:
5+
# Run the linter.
6+
- id: ruff-check
7+
# Run the formatter.
8+
- id: ruff-format

codeflash/api/cfapi.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ def get_user_id() -> Optional[str]:
101101
if min_version and version.parse(min_version) > version.parse(__version__):
102102
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
103103
console.print(f"[bold red]{msg}[/bold red]")
104+
if console.quiet: # lsp
105+
logger.debug(msg)
106+
return f"Error: {msg}"
104107
sys.exit(1)
105108
return userid
106109

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@ def leave_ClassDef(
3232
def visit_ClassDef(self, node: ClassDef) -> Optional[bool]:
3333
if self.class_name: # Don't go into nested class
3434
return False
35-
self.class_name = node.name.value # noqa: RET503
35+
self.class_name = node.name.value
36+
return None
3637

3738
def visit_FunctionDef(self, node: FunctionDef) -> Optional[bool]:
3839
if self.function_name: # Don't go into nested function
3940
return False
40-
self.function_name = node.name.value # noqa: RET503
41+
self.function_name = node.name.value
42+
return None
4143

4244
def leave_FunctionDef(self, original_node: FunctionDef, updated_node: FunctionDef) -> FunctionDef:
4345
if self.function_name == original_node.name.value:

codeflash/cli_cmds/cmd_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from argparse import Namespace
3939

4040
CODEFLASH_LOGO: str = (
41-
f"{LF}" # noqa: ISC003
41+
f"{LF}"
4242
r" _ ___ _ _ " + f"{LF}"
4343
r" | | / __)| | | | " + f"{LF}"
4444
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"

codeflash/code_utils/env_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
3434

3535
@lru_cache(maxsize=1)
3636
def get_codeflash_api_key() -> str:
37-
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
37+
if console.quiet: # lsp
38+
# prefer shell config over env var in lsp mode
39+
api_key = read_api_key_from_shell_config()
40+
else:
41+
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
42+
3843
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
3944
if not api_key:
4045
msg = (

codeflash/code_utils/shell_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_shell_rc_path() -> Path:
4242

4343

4444
def get_api_key_export_line(api_key: str) -> str:
45-
return f"{SHELL_RC_EXPORT_PREFIX}{api_key}"
45+
return f'{SHELL_RC_EXPORT_PREFIX}"{api_key}"'
4646

4747

4848
def save_api_key_to_rc(api_key: str) -> Result[str, str]:

codeflash/lsp/beta.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from pygls import uris
1010

11+
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
12+
from codeflash.code_utils.shell_utils import save_api_key_to_rc
1113
from codeflash.either import is_successful
1214
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol
1315

@@ -28,6 +30,11 @@ class FunctionOptimizationParams:
2830
functionName: str # noqa: N815
2931

3032

33+
@dataclass
34+
class ProvideApiKeyParams:
35+
api_key: str
36+
37+
3138
server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)
3239

3340

@@ -118,6 +125,53 @@ def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOpt
118125
return {"functionName": params.functionName, "status": "success", "discovered_tests": num_discovered_tests}
119126

120127

128+
def _initialize_optimizer_if_valid(server: CodeflashLanguageServer) -> dict[str, str]:
129+
user_id = get_user_id()
130+
if user_id is None:
131+
return {"status": "error", "message": "api key not found or invalid"}
132+
133+
if user_id.startswith("Error: "):
134+
error_msg = user_id[7:]
135+
return {"status": "error", "message": error_msg}
136+
137+
from codeflash.optimization.optimizer import Optimizer
138+
139+
server.optimizer = Optimizer(server.args)
140+
return {"status": "success", "user_id": user_id}
141+
142+
143+
@server.feature("apiKeyExistsAndValid")
144+
def check_api_key(server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
145+
try:
146+
return _initialize_optimizer_if_valid(server)
147+
except Exception:
148+
return {"status": "error", "message": "something went wrong while validating the api key"}
149+
150+
151+
@server.feature("provideApiKey")
152+
def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams) -> dict[str, str]:
153+
try:
154+
api_key = params.api_key
155+
if not api_key.startswith("cf-"):
156+
return {"status": "error", "message": "Api key is not valid"}
157+
158+
result = save_api_key_to_rc(api_key)
159+
if not is_successful(result):
160+
return {"status": "error", "message": result.failure()}
161+
162+
# clear cache to ensure the new api key is used
163+
get_codeflash_api_key.cache_clear()
164+
get_user_id.cache_clear()
165+
166+
init_result = _initialize_optimizer_if_valid(server)
167+
if init_result["status"] == "error":
168+
return {"status": "error", "message": "Api key is not valid"}
169+
170+
return {"status": "success", "message": "Api key saved successfully", "user_id": init_result["user_id"]}
171+
except Exception:
172+
return {"status": "error", "message": "something went wrong while saving the api key"}
173+
174+
121175
@server.feature("prepareOptimization")
122176
def prepare_optimization(server: CodeflashLanguageServer, params: FunctionOptimizationParams) -> dict[str, str]:
123177
current_function = server.optimizer.current_function_being_optimized

codeflash/lsp/server.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
2525
workspace_path = uris.to_fs_path(workspace_uri)
2626
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
2727
if pyproject_toml_path:
28-
server.initialize_optimizer(pyproject_toml_path)
28+
server.prepare_optimizer_arguments(pyproject_toml_path)
2929
server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}")
3030
else:
3131
server.show_message("No pyproject.toml found in workspace.")
@@ -45,16 +45,17 @@ class CodeflashLanguageServer(LanguageServer):
4545
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
4646
super().__init__(*args, **kwargs)
4747
self.optimizer = None
48+
self.args = None
4849

49-
def initialize_optimizer(self, config_file: Path) -> None:
50+
def prepare_optimizer_arguments(self, config_file: Path) -> None:
5051
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
51-
from codeflash.optimization.optimizer import Optimizer
5252

5353
args = parse_args()
5454
args.config_file = config_file
5555
args.no_pr = True # LSP server should not create PRs
5656
args = process_pyproject_config(args)
57-
self.optimizer = Optimizer(args)
57+
self.args = args
58+
# avoid initializing the optimizer during initialization, because it can cause an error if the api key is invalid
5859

5960
def show_message_log(self, message: str, message_type: str) -> None:
6061
"""Send a log message to the client's output channel.
@@ -70,6 +71,7 @@ def show_message_log(self, message: str, message_type: str) -> None:
7071
"Warning": MessageType.Warning,
7172
"Error": MessageType.Error,
7273
"Log": MessageType.Log,
74+
"Debug": MessageType.Debug,
7375
}
7476

7577
lsp_message_type = type_mapping.get(message_type, MessageType.Info)

codeflash/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def unique_invocation_loop_id(self) -> str:
552552
return f"{self.loop_index}:{self.id.id()}"
553553

554554

555-
class TestResults(BaseModel):
555+
class TestResults(BaseModel): # noqa: PLW1641
556556
# don't modify these directly, use the add method
557557
# also we don't support deletion of test results elements - caution is advised
558558
test_results: list[FunctionTestInvocation] = []

codeflash/result/critic.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Optional
3+
from typing import TYPE_CHECKING
44

55
from codeflash.cli_cmds.console import logger
66
from codeflash.code_utils import env_utils
@@ -29,7 +29,8 @@ def speedup_critic(
2929
candidate_result: OptimizedCandidateResult,
3030
original_code_runtime: int,
3131
best_runtime_until_now: int | None,
32-
disable_gh_action_noise: Optional[bool] = None,
32+
*,
33+
disable_gh_action_noise: bool = False,
3334
) -> bool:
3435
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
3536
@@ -39,10 +40,8 @@ def speedup_critic(
3940
The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there.
4041
"""
4142
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
42-
if not disable_gh_action_noise:
43-
in_github_actions_mode = bool(env_utils.get_pr_number())
44-
if in_github_actions_mode:
45-
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
43+
if not disable_gh_action_noise and env_utils.is_ci():
44+
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
4645

4746
perf_gain = performance_gain(
4847
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime

0 commit comments

Comments
 (0)