Skip to content

Commit 91d898b

Browse files
Merge branch 'main' of github.com:codeflash-ai/codeflash into lsp/demo-optimization
2 parents 3ac68eb + 1f6cf3f commit 91d898b

20 files changed

+565
-70
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ jobs:
2828
- name: install dependencies
2929
run: uv sync
3030

31-
- name: Install test-only dependencies (Python 3.13)
32-
if: matrix.python-version == '3.13'
31+
- name: Install test-only dependencies (Python 3.9 and 3.13)
32+
if: matrix.python-version == '3.9' || matrix.python-version == '3.13'
3333
run: uv sync --group tests
3434

3535
- name: Unit tests

code_to_optimize/tests/pytest/test_topological_sort.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def test_topological_sort():
1010
g.addEdge(2, 3)
1111
g.addEdge(3, 1)
1212

13-
assert g.topologicalSort() == [5, 4, 2, 3, 1, 0]
13+
assert g.topologicalSort()[0] == [5, 4, 2, 3, 1, 0]
1414

1515

1616
def test_topological_sort_2():
@@ -20,15 +20,15 @@ def test_topological_sort_2():
2020
for j in range(i + 1, 10):
2121
g.addEdge(i, j)
2222

23-
assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
23+
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
2424

2525
g = Graph(10)
2626

2727
for i in range(10):
2828
for j in range(i + 1, 10):
2929
g.addEdge(i, j)
3030

31-
assert g.topologicalSort() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
31+
assert g.topologicalSort()[0] == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
3232

3333

3434
def test_topological_sort_3():
@@ -38,4 +38,4 @@ def test_topological_sort_3():
3838
for j in range(i + 1, 1000):
3939
g.addEdge(j, i)
4040

41-
assert g.topologicalSort() == list(reversed(range(1000)))
41+
assert g.topologicalSort()[0] == list(reversed(range(1000)))

codeflash/api/aiservice.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
255255
"optimized_code_runtime": opt.optimized_code_runtime,
256256
"speedup": opt.speedup,
257257
"trace_id": opt.trace_id,
258+
"function_references": opt.function_references,
259+
"python_version": platform.python_version(),
258260
}
259261
for opt in request
260262
]
@@ -308,6 +310,7 @@ def get_new_explanation( # noqa: D417
308310
original_throughput: str | None = None,
309311
optimized_throughput: str | None = None,
310312
throughput_improvement: str | None = None,
313+
function_references: str | None = None,
311314
) -> str:
312315
"""Optimize the given python code for performance by making a request to the Django endpoint.
313316
@@ -327,6 +330,7 @@ def get_new_explanation( # noqa: D417
327330
- original_throughput: str | None - throughput for the baseline code (operations per second)
328331
- optimized_throughput: str | None - throughput for the optimized code (operations per second)
329332
- throughput_improvement: str | None - throughput improvement percentage
333+
- function_references: str | None - where the function is called in the codebase
330334
331335
Returns
332336
-------
@@ -349,6 +353,7 @@ def get_new_explanation( # noqa: D417
349353
"original_throughput": original_throughput,
350354
"optimized_throughput": optimized_throughput,
351355
"throughput_improvement": throughput_improvement,
356+
"function_references": function_references,
352357
}
353358
logger.info("loading|Generating explanation")
354359
console.rule()
@@ -373,7 +378,12 @@ def get_new_explanation( # noqa: D417
373378
return ""
374379

375380
def generate_ranking( # noqa: D417
376-
self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[float]
381+
self,
382+
trace_id: str,
383+
diffs: list[str],
384+
optimization_ids: list[str],
385+
speedups: list[float],
386+
function_references: str | None = None,
377387
) -> list[int] | None:
378388
"""Optimize the given python code for performance by making a request to the Django endpoint.
379389
@@ -382,6 +392,7 @@ def generate_ranking( # noqa: D417
382392
- trace_id : unique uuid of function
383393
- diffs : list of unified diff strings of opt candidates
384394
- speedups : list of speedups of opt candidates
395+
- function_references : where the function is called in the codebase
385396
386397
Returns
387398
-------
@@ -394,6 +405,7 @@ def generate_ranking( # noqa: D417
394405
"speedups": speedups,
395406
"optimization_ids": optimization_ids,
396407
"python_version": platform.python_version(),
408+
"function_references": function_references,
397409
}
398410
logger.info("loading|Generating ranking")
399411
console.rule()
@@ -594,6 +606,7 @@ def get_optimization_review(
594606
"optimized_runtime": humanize_runtime(explanation.best_runtime_ns),
595607
"original_runtime": humanize_runtime(explanation.original_runtime_ns),
596608
"calling_fn_details": calling_fn_details,
609+
"python_version": platform.python_version(),
597610
}
598611
console.rule()
599612
try:

codeflash/api/cfapi.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
"""Module for interacting with the Codeflash API."""
2-
31
from __future__ import annotations
42

53
import json
64
import os
7-
import sys
85
from dataclasses import dataclass
96
from functools import lru_cache
107
from pathlib import Path
@@ -101,12 +98,13 @@ def make_cfapi_request(
10198

10299

103100
@lru_cache(maxsize=1)
104-
def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
101+
def get_user_id(api_key: Optional[str] = None) -> Optional[str]: # noqa: PLR0911
105102
"""Retrieve the user's userid by making a request to the /cfapi/cli-get-user endpoint.
106103
107104
:param api_key: The API key to use. If None, uses get_codeflash_api_key().
108105
:return: The userid or None if the request fails.
109106
"""
107+
lsp_enabled = is_LSP_enabled()
110108
if not api_key and not ensure_codeflash_api_key():
111109
return None
112110

@@ -127,19 +125,21 @@ def get_user_id(api_key: Optional[str] = None) -> Optional[str]:
127125
if min_version and version.parse(min_version) > version.parse(__version__):
128126
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
129127
console.print(f"[bold red]{msg}[/bold red]")
130-
if is_LSP_enabled():
128+
if lsp_enabled:
131129
logger.debug(msg)
132130
return f"Error: {msg}"
133-
sys.exit(1)
131+
exit_with_message(msg, error_on_exit=True)
134132
return userid
135133

136134
logger.error("Failed to retrieve userid from the response.")
137135
return None
138136

139-
# Handle 403 (Invalid API key) - exit with error message
140137
if response.status_code == 403:
138+
error_title = "Invalid Codeflash API key. The API key you provided is not valid."
139+
if lsp_enabled:
140+
return f"Error: {error_title}"
141141
msg = (
142-
"Invalid Codeflash API key. The API key you provided is not valid.\n"
142+
f"{error_title}\n"
143143
"Please generate a new one at https://app.codeflash.ai/app/apikeys ,\n"
144144
"then set it as a CODEFLASH_API_KEY environment variable.\n"
145145
"For more information, refer to the documentation at \n"

codeflash/cli_cmds/cmd_init.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def get_valid_subdirs(current_dir: Optional[Path] = None) -> list[str]:
264264
]
265265

266266

267-
def get_suggestions(section: str) -> tuple(list[str], Optional[str]):
267+
def get_suggestions(section: str) -> tuple[list[str], Optional[str]]:
268268
valid_subdirs = get_valid_subdirs()
269269
if section == CommonSections.module_root:
270270
return [d for d in valid_subdirs if d != "tests"], None
@@ -391,7 +391,7 @@ def collect_setup_info() -> CLISetupInfo:
391391
tests_root_answer = tests_answers["tests_root"]
392392

393393
if tests_root_answer == create_for_me_option:
394-
tests_root = Path(curdir) / default_tests_subdir
394+
tests_root = Path(curdir) / (default_tests_subdir or "tests")
395395
tests_root.mkdir()
396396
click.echo(f"✅ Created directory {tests_root}{os.path.sep}{LF}")
397397
elif tests_root_answer == custom_dir_option:

codeflash/code_utils/code_extractor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import time
45
from dataclasses import dataclass
56
from itertools import chain
67
from pathlib import Path
@@ -1138,6 +1139,7 @@ def find_specific_function_in_file(
11381139
def get_fn_references_jedi(
11391140
source_code: str, file_path: Path, project_root: Path, target_function: str, target_class: str | None
11401141
) -> list[Path]:
1142+
start_time = time.perf_counter()
11411143
function_position: CodePosition = find_specific_function_in_file(
11421144
source_code, file_path, target_function, target_class
11431145
)
@@ -1146,6 +1148,8 @@ def get_fn_references_jedi(
11461148
# Get references to the function
11471149
references = script.get_references(line=function_position.line_no, column=function_position.col_no)
11481150
# Collect unique file paths where references are found
1151+
end_time = time.perf_counter()
1152+
logger.debug(f"Jedi for function references ran in {end_time - start_time:.2f} seconds")
11491153
reference_files = set()
11501154
for ref in references:
11511155
if ref.module_path:
@@ -1163,6 +1167,7 @@ def get_fn_references_jedi(
11631167
def get_opt_review_metrics(
11641168
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path
11651169
) -> str:
1170+
start_time = time.perf_counter()
11661171
try:
11671172
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
11681173
if len(qualified_name_split) == 1:
@@ -1176,4 +1181,6 @@ def get_opt_review_metrics(
11761181
except Exception as e:
11771182
calling_fns_details = ""
11781183
logger.debug(f"Investigate {e}")
1184+
end_time = time.perf_counter()
1185+
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
11791186
return calling_fns_details

codeflash/code_utils/code_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from codeflash.cli_cmds.console import logger, paneled_text
1919
from codeflash.code_utils.config_parser import find_pyproject_toml, get_all_closest_config_files
20+
from codeflash.lsp.helpers import is_LSP_enabled
2021

2122
ImportErrorPattern = re.compile(r"ModuleNotFoundError.*$", re.MULTILINE)
2223

@@ -352,6 +353,10 @@ def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
352353

353354

354355
def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
356+
"""Don't Call it inside the lsp process, it will terminate the lsp server."""
357+
if is_LSP_enabled():
358+
logger.error(message)
359+
return
355360
paneled_text(message, panel_args={"style": "red"})
356361

357362
sys.exit(1 if error_on_exit else 0)

codeflash/code_utils/env_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from codeflash.code_utils.code_utils import exit_with_message
1414
from codeflash.code_utils.formatter import format_code
1515
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc
16+
from codeflash.lsp.helpers import is_LSP_enabled
1617

1718

1819
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
@@ -70,7 +71,10 @@ def get_codeflash_api_key() -> str:
7071
except Exception as e:
7172
logger.debug(f"Failed to automatically save API key to shell config: {e}")
7273

73-
api_key = env_api_key or shell_api_key
74+
# Prefer the shell configuration over environment variables for lsp,
75+
# as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart
76+
# within the same process, the environment variable could become outdated.
77+
api_key = shell_api_key or env_api_key if is_LSP_enabled() else env_api_key or shell_api_key
7478

7579
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
7680
if not api_key:

0 commit comments

Comments
 (0)