Skip to content

Commit c4e3e00

Browse files
authored
Merge branch 'main' into alpha-async
2 parents a5182c6 + fdaf6c0 commit c4e3e00

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1217
-567
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,4 @@ Join our community for support and discussions. If you have any questions, feel
8181

8282
## License
8383

84-
Codeflash is licensed under the BSL-1.1 License. See the LICENSE file for details.
84+
Codeflash is licensed under the BSL-1.1 License. See the [LICENSE](https://github.com/codeflash-ai/codeflash/blob/main/codeflash/LICENSE) file for details.

codeflash/api/aiservice.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from pydantic.json import pydantic_encoder
1111

1212
from codeflash.cli_cmds.console import console, logger
13-
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
13+
from codeflash.code_utils.env_utils import get_codeflash_api_key
1414
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
15+
from codeflash.lsp.helpers import is_LSP_enabled
1516
from codeflash.models.ExperimentMetadata import ExperimentMetadata
1617
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
1718
from codeflash.telemetry.posthog_cf import ph
@@ -80,6 +81,19 @@ def make_ai_service_request(
8081
# response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
8182
return response
8283

84+
def _get_valid_candidates(self, optimizations_json: list[dict[str, Any]]) -> list[OptimizedCandidate]:
85+
candidates: list[OptimizedCandidate] = []
86+
for opt in optimizations_json:
87+
code = CodeStringsMarkdown.parse_markdown_code(opt["source_code"])
88+
if not code.code_strings:
89+
continue
90+
candidates.append(
91+
OptimizedCandidate(
92+
source_code=code, explanation=opt["explanation"], optimization_id=opt["optimization_id"]
93+
)
94+
)
95+
return candidates
96+
8397
def optimize_python_code( # noqa: D417
8498
self,
8599
source_code: str,
@@ -134,14 +148,7 @@ def optimize_python_code( # noqa: D417
134148
console.rule()
135149
end_time = time.perf_counter()
136150
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
137-
return [
138-
OptimizedCandidate(
139-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
140-
explanation=opt["explanation"],
141-
optimization_id=opt["optimization_id"],
142-
)
143-
for opt in optimizations_json
144-
]
151+
return self._get_valid_candidates(optimizations_json)
145152
try:
146153
error = response.json()["error"]
147154
except Exception:
@@ -204,14 +211,7 @@ def optimize_python_code_line_profiler( # noqa: D417
204211
optimizations_json = response.json()["optimizations"]
205212
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
206213
console.rule()
207-
return [
208-
OptimizedCandidate(
209-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
210-
explanation=opt["explanation"],
211-
optimization_id=opt["optimization_id"],
212-
)
213-
for opt in optimizations_json
214-
]
214+
return self._get_valid_candidates(optimizations_json)
215215
try:
216216
error = response.json()["error"]
217217
except Exception:
@@ -261,14 +261,17 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
261261
refined_optimizations = response.json()["refinements"]
262262
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
263263
console.rule()
264+
265+
refinements = self._get_valid_candidates(refined_optimizations)
264266
return [
265267
OptimizedCandidate(
266-
source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]),
267-
explanation=opt["explanation"],
268-
optimization_id=opt["optimization_id"][:-4] + "refi",
268+
source_code=c.source_code,
269+
explanation=c.explanation,
270+
optimization_id=c.optimization_id[:-4] + "refi",
269271
)
270-
for opt in refined_optimizations
272+
for c in refinements
271273
]
274+
272275
try:
273276
error = response.json()["error"]
274277
except Exception:

codeflash/api/cfapi.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
17+
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
1818
from codeflash.github.PrComment import FileDiffContent, PrComment
19+
from codeflash.lsp.helpers import is_LSP_enabled
1920
from codeflash.version import __version__
2021

2122
if TYPE_CHECKING:
@@ -101,7 +102,7 @@ def get_user_id() -> Optional[str]:
101102
if min_version and version.parse(min_version) > version.parse(__version__):
102103
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
103104
console.print(f"[bold red]{msg}[/bold red]")
104-
if console.quiet: # lsp
105+
if is_LSP_enabled():
105106
logger.debug(msg)
106107
return f"Error: {msg}"
107108
sys.exit(1)
@@ -203,8 +204,9 @@ def create_staging(
203204
generated_original_test_source: str,
204205
function_trace_id: str,
205206
coverage_message: str,
206-
replay_tests: str = "",
207-
concolic_tests: str = "",
207+
replay_tests: str,
208+
concolic_tests: str,
209+
root_dir: Path,
208210
) -> Response:
209211
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
210212
@@ -217,12 +219,10 @@ def create_staging(
217219
:param coverage_message: Coverage report or summary.
218220
:return: The response object from the backend.
219221
"""
220-
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
222+
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
221223

222224
build_file_changes = {
223-
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
224-
oldContent=original_code[p], newContent=new_code[p]
225-
)
225+
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(oldContent=original_code[p], newContent=new_code[p])
226226
for p in original_code
227227
}
228228

codeflash/cli_cmds/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from codeflash.code_utils import env_utils
1111
from codeflash.code_utils.code_utils import exit_with_message
1212
from codeflash.code_utils.config_parser import parse_config_file
13+
from codeflash.lsp.helpers import is_LSP_enabled
1314
from codeflash.version import __version__ as version
1415

1516

@@ -94,6 +95,7 @@ def parse_args() -> Namespace:
9495
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
9596
)
9697
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
98+
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
9799

98100
args, unknown_args = parser.parse_known_args()
99101
sys.argv[:] = [sys.argv[0], *unknown_args]
@@ -210,6 +212,9 @@ def process_pyproject_config(args: Namespace) -> Namespace:
210212
if args.benchmarks_root:
211213
args.benchmarks_root = Path(args.benchmarks_root).resolve()
212214
args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path)
215+
if is_LSP_enabled():
216+
args.all = None
217+
return args
213218
return handle_optimize_all_arg_parsing(args)
214219

215220

codeflash/cli_cmds/cmd_init.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,22 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
155155
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
156156

157157

158+
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> dict[str, Any] | None:
159+
if not pyproject_toml_path.exists():
160+
return None
161+
try:
162+
config, _ = parse_config_file(pyproject_toml_path)
163+
except Exception:
164+
return None
165+
166+
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
167+
return None
168+
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
169+
return None
170+
171+
return config
172+
173+
158174
def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
159175
"""Check if the current directory contains a valid pyproject.toml file with codeflash config.
160176
@@ -163,16 +179,9 @@ def should_modify_pyproject_toml() -> tuple[bool, dict[str, Any] | None]:
163179
from rich.prompt import Confirm
164180

165181
pyproject_toml_path = Path.cwd() / "pyproject.toml"
166-
if not pyproject_toml_path.exists():
167-
return True, None
168-
try:
169-
config, config_file_path = parse_config_file(pyproject_toml_path)
170-
except Exception:
171-
return True, None
172182

173-
if "module_root" not in config or config["module_root"] is None or not Path(config["module_root"]).is_dir():
174-
return True, None
175-
if "tests_root" not in config or config["tests_root"] is None or not Path(config["tests_root"]).is_dir():
183+
config = is_valid_pyproject_toml(pyproject_toml_path)
184+
if config is None:
176185
return True, None
177186

178187
return Confirm.ask(
@@ -968,6 +977,11 @@ def install_github_app(git_remote: str) -> None:
968977
except git.InvalidGitRepositoryError:
969978
click.echo("Skipping GitHub app installation because you're not in a git repository.")
970979
return
980+
981+
if git_remote not in get_git_remotes(git_repo):
982+
click.echo(f"Skipping GitHub app installation, remote ({git_remote}) does not exist in this repository.")
983+
return
984+
971985
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
972986

973987
if is_github_app_installed_on_repo(owner, repo, suppress_errors=True):

codeflash/cli_cmds/console.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import os
45
from contextlib import contextmanager
56
from itertools import cycle
67
from typing import TYPE_CHECKING
@@ -28,6 +29,10 @@
2829
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
2930

3031
console = Console()
32+
33+
if os.getenv("CODEFLASH_LSP"):
34+
console.quiet = True
35+
3136
logging.basicConfig(
3237
level=logging.INFO,
3338
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],

codeflash/code_utils/code_extractor.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import ast
5+
from itertools import chain
56
from typing import TYPE_CHECKING, Optional
67

78
import libcst as cst
@@ -119,6 +120,32 @@ def leave_Assign(self, original_node: cst.Assign, updated_node: cst.Assign) -> c
119120

120121
return updated_node
121122

123+
def _find_insertion_index(self, updated_node: cst.Module) -> int:
124+
"""Find the position of the last import statement in the top-level of the module."""
125+
insert_index = 0
126+
for i, stmt in enumerate(updated_node.body):
127+
is_top_level_import = isinstance(stmt, cst.SimpleStatementLine) and any(
128+
isinstance(child, (cst.Import, cst.ImportFrom)) for child in stmt.body
129+
)
130+
131+
is_conditional_import = isinstance(stmt, cst.If) and all(
132+
isinstance(inner, cst.SimpleStatementLine)
133+
and all(isinstance(child, (cst.Import, cst.ImportFrom)) for child in inner.body)
134+
for inner in stmt.body.body
135+
)
136+
137+
if is_top_level_import or is_conditional_import:
138+
insert_index = i + 1
139+
140+
# Stop scanning once we reach a class or function definition.
141+
# Imports are supposed to be at the top of the file, but they can technically appear anywhere, even at the bottom of the file.
142+
# Without this check, a stray import later in the file
143+
# would incorrectly shift our insertion index below actual code definitions.
144+
if isinstance(stmt, (cst.ClassDef, cst.FunctionDef)):
145+
break
146+
147+
return insert_index
148+
122149
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
123150
# Add any new assignments that weren't in the original file
124151
new_statements = list(updated_node.body)
@@ -131,18 +158,26 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
131158
]
132159

133160
if assignments_to_append:
134-
# Add a blank line before appending new assignments if needed
135-
if new_statements and not isinstance(new_statements[-1], cst.EmptyLine):
136-
new_statements.append(cst.SimpleStatementLine([cst.Pass()], leading_lines=[cst.EmptyLine()]))
137-
new_statements.pop() # Remove the Pass statement but keep the empty line
138-
139-
# Add the new assignments
140-
new_statements.extend(
141-
[
142-
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
143-
for assignment in assignments_to_append
144-
]
145-
)
161+
# after last top-level imports
162+
insert_index = self._find_insertion_index(updated_node)
163+
164+
assignment_lines = [
165+
cst.SimpleStatementLine([assignment], leading_lines=[cst.EmptyLine()])
166+
for assignment in assignments_to_append
167+
]
168+
169+
new_statements = list(chain(new_statements[:insert_index], assignment_lines, new_statements[insert_index:]))
170+
171+
# Add a blank line after the last assignment if needed
172+
after_index = insert_index + len(assignment_lines)
173+
if after_index < len(new_statements):
174+
next_stmt = new_statements[after_index]
175+
# If there's no empty line, add one
176+
has_empty = any(isinstance(line, cst.EmptyLine) for line in next_stmt.leading_lines)
177+
if not has_empty:
178+
new_statements[after_index] = next_stmt.with_changes(
179+
leading_lines=[cst.EmptyLine(), *next_stmt.leading_lines]
180+
)
146181

147182
return updated_node.with_changes(body=new_statements)
148183

codeflash/code_utils/coverage_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
3939
return full_name
4040

4141

42-
def generate_candidates(source_code_path: Path) -> list[str]:
42+
def generate_candidates(source_code_path: Path) -> set[str]:
4343
"""Generate all the possible candidates for coverage data based on the source code path."""
44-
candidates = [source_code_path.name]
44+
candidates = set()
45+
candidates.add(source_code_path.name)
4546
current_path = source_code_path.parent
4647

48+
last_added = source_code_path.name
4749
while current_path != current_path.parent:
48-
candidate_path = str(Path(current_path.name) / candidates[-1])
49-
candidates.append(candidate_path)
50+
candidate_path = str(Path(current_path.name) / last_added)
51+
candidates.add(candidate_path)
52+
last_added = candidate_path
5053
current_path = current_path.parent
5154

55+
candidates.add(str(source_code_path))
5256
return candidates
5357

5458

codeflash/code_utils/env_utils.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from pathlib import Path
88
from typing import Any, Optional
99

10-
from codeflash.cli_cmds.console import console, logger
10+
from codeflash.cli_cmds.console import logger
1111
from codeflash.code_utils.code_utils import exit_with_message
1212
from codeflash.code_utils.formatter import format_code
1313
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
14+
from codeflash.lsp.helpers import is_LSP_enabled
1415

1516

1617
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
@@ -34,11 +35,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
3435

3536
@lru_cache(maxsize=1)
3637
def get_codeflash_api_key() -> str:
37-
if console.quiet: # lsp
38-
# prefer shell config over env var in lsp mode
39-
api_key = read_api_key_from_shell_config()
40-
else:
41-
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
38+
# prefer shell config over env var in lsp mode
39+
api_key = (
40+
read_api_key_from_shell_config()
41+
if is_LSP_enabled()
42+
else os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
43+
)
4244

4345
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
4446
if not api_key:
@@ -125,11 +127,6 @@ def is_ci() -> bool:
125127
return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))
126128

127129

128-
@lru_cache(maxsize=1)
129-
def is_LSP_enabled() -> bool:
130-
return console.quiet
131-
132-
133130
def is_pr_draft() -> bool:
134131
"""Check if the PR is draft. in the github action context."""
135132
event = get_cached_gh_event_data()

0 commit comments

Comments
 (0)