Skip to content

Commit 099cd00

Browse files
authored
Merge branch 'main' into chore/add-staging/docs
2 parents bd1de12 + 217ced2 commit 099cd00

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+1174
-535
lines changed

README.md

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
<a href="https://github.com/codeflash-ai/codeflash">
44
<img src="https://img.shields.io/github/commit-activity/m/codeflash-ai/codeflash" alt="GitHub commit activity">
55
</a>
6-
<a href="https://pypi.org/project/codeflash/">
7-
<img src="https://img.shields.io/pypi/dm/codeflash" alt="PyPI Downloads">
8-
</a>
6+
<a href="https://pypi.org/project/codeflash/"><img src="https://static.pepy.tech/badge/codeflash" alt="PyPI Downloads"></a>
97
<a href="https://pypi.org/project/codeflash/">
108
<img src="https://img.shields.io/pypi/v/codeflash?label=PyPI%20version" alt="PyPI Downloads">
119
</a>
@@ -83,4 +81,4 @@ Join our community for support and discussions. If you have any questions, feel
8381

8482
## License
8583

86-
Codeflash is licensed under the BSL-1.1 License. See the LICENSE file for details.
84+
Codeflash is licensed under the BSL-1.1 License. See the [LICENSE](https://github.com/codeflash-ai/codeflash/blob/main/codeflash/LICENSE) file for details.

codeflash/LICENSE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Business Source License 1.1
33
Parameters
44

55
Licensor: CodeFlash Inc.
6-
Licensed Work: Codeflash Client version 0.15.x
6+
Licensed Work: Codeflash Client version 0.16.x
77
The Licensed Work is (c) 2024 CodeFlash Inc.
88

99
Additional Use Grant: None. Production use of the Licensed Work is only permitted
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
1313
Platform. Please visit codeflash.ai for further
1414
information.
1515

16-
Change Date: 2029-07-03
16+
Change Date: 2029-08-14
1717

1818
Change License: MIT
1919

codeflash/api/aiservice.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from pydantic.json import pydantic_encoder
1111

1212
from codeflash.cli_cmds.console import console, logger
13-
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
13+
from codeflash.code_utils.env_utils import get_codeflash_api_key
1414
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
15+
from codeflash.lsp.helpers import is_LSP_enabled
1516
from codeflash.models.ExperimentMetadata import ExperimentMetadata
1617
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
1718
from codeflash.telemetry.posthog_cf import ph
@@ -202,7 +203,7 @@ def optimize_python_code_line_profiler( # noqa: D417
202203

203204
if response.status_code == 200:
204205
optimizations_json = response.json()["optimizations"]
205-
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
206+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations using line profiler information.")
206207
console.rule()
207208
return [
208209
OptimizedCandidate(
@@ -248,7 +249,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
248249
}
249250
for opt in request
250251
]
251-
logger.info(f"Refining {len(request)} optimizations…")
252+
logger.debug(f"Refining {len(request)} optimizations…")
252253
console.rule()
253254
try:
254255
response = self.make_ai_service_request("/refinement", payload=payload, timeout=600)
@@ -259,7 +260,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
259260

260261
if response.status_code == 200:
261262
refined_optimizations = response.json()["refinements"]
262-
logger.info(f"Generated {len(refined_optimizations)} candidate refinements.")
263+
logger.debug(f"Generated {len(refined_optimizations)} candidate refinements.")
263264
console.rule()
264265
return [
265266
OptimizedCandidate(
@@ -339,7 +340,6 @@ def get_new_explanation( # noqa: D417
339340

340341
if response.status_code == 200:
341342
explanation: str = response.json()["explanation"]
342-
logger.debug(f"New Explanation: {explanation}")
343343
console.rule()
344344
return explanation
345345
try:

codeflash/api/cfapi.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
17+
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
1818
from codeflash.github.PrComment import FileDiffContent, PrComment
19+
from codeflash.lsp.helpers import is_LSP_enabled
1920
from codeflash.version import __version__
2021

2122
if TYPE_CHECKING:
@@ -101,7 +102,7 @@ def get_user_id() -> Optional[str]:
101102
if min_version and version.parse(min_version) > version.parse(__version__):
102103
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
103104
console.print(f"[bold red]{msg}[/bold red]")
104-
if console.quiet: # lsp
105+
if is_LSP_enabled():
105106
logger.debug(msg)
106107
return f"Error: {msg}"
107108
sys.exit(1)
@@ -203,6 +204,9 @@ def create_staging(
203204
generated_original_test_source: str,
204205
function_trace_id: str,
205206
coverage_message: str,
207+
replay_tests: str,
208+
concolic_tests: str,
209+
root_dir: Path,
206210
) -> Response:
207211
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
208212
@@ -215,12 +219,10 @@ def create_staging(
215219
:param coverage_message: Coverage report or summary.
216220
:return: The response object from the backend.
217221
"""
218-
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
222+
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
219223

220224
build_file_changes = {
221-
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
222-
oldContent=original_code[p], newContent=new_code[p]
223-
)
225+
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(oldContent=original_code[p], newContent=new_code[p])
224226
for p in original_code
225227
}
226228

@@ -243,6 +245,8 @@ def create_staging(
243245
"generatedTests": generated_original_test_source,
244246
"traceId": function_trace_id,
245247
"coverage_message": coverage_message,
248+
"replayTests": replay_tests,
249+
"concolicTests": concolic_tests,
246250
}
247251

248252
return make_cfapi_request(endpoint="/create-staging", method="POST", payload=payload)

codeflash/cli_cmds/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def parse_args() -> Namespace:
9494
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
9595
)
9696
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
97+
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
9798

9899
args, unknown_args = parser.parse_known_args()
99100
sys.argv[:] = [sys.argv[0], *unknown_args]

codeflash/cli_cmds/console.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import os
45
from contextlib import contextmanager
56
from itertools import cycle
67
from typing import TYPE_CHECKING
@@ -28,6 +29,10 @@
2829
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
2930

3031
console = Console()
32+
33+
if os.getenv("CODEFLASH_LSP"):
34+
console.quiet = True
35+
3136
logging.basicConfig(
3237
level=logging.INFO,
3338
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],

codeflash/code_utils/code_extractor.py

Lines changed: 94 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,79 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine) -> None:
195195
self.last_import_line = self.current_line
196196

197197

198+
class DottedImportCollector(cst.CSTVisitor):
199+
"""Collects all top-level imports from a Python module in normalized dotted format, including top-level conditional imports like `if TYPE_CHECKING:`.
200+
201+
Examples
202+
--------
203+
import os ==> "os"
204+
import dbt.adapters.factory ==> "dbt.adapters.factory"
205+
from pathlib import Path ==> "pathlib.Path"
206+
from recce.adapter.base import BaseAdapter ==> "recce.adapter.base.BaseAdapter"
207+
from typing import Any, List, Optional ==> "typing.Any", "typing.List", "typing.Optional"
208+
from recce.util.lineage import ( build_column_key, filter_dependency_maps) ==> "recce.util.lineage.build_column_key", "recce.util.lineage.filter_dependency_maps"
209+
210+
"""
211+
212+
def __init__(self) -> None:
213+
self.imports: set[str] = set()
214+
self.depth = 0 # top-level
215+
216+
def get_full_dotted_name(self, expr: cst.BaseExpression) -> str:
217+
if isinstance(expr, cst.Name):
218+
return expr.value
219+
if isinstance(expr, cst.Attribute):
220+
return f"{self.get_full_dotted_name(expr.value)}.{expr.attr.value}"
221+
return ""
222+
223+
def _collect_imports_from_block(self, block: cst.IndentedBlock) -> None:
224+
for statement in block.body:
225+
if isinstance(statement, cst.SimpleStatementLine):
226+
for child in statement.body:
227+
if isinstance(child, cst.Import):
228+
for alias in child.names:
229+
module = self.get_full_dotted_name(alias.name)
230+
asname = alias.asname.name.value if alias.asname else alias.name.value
231+
if isinstance(asname, cst.Attribute):
232+
self.imports.add(module)
233+
else:
234+
self.imports.add(module if module == asname else f"{module}.{asname}")
235+
236+
elif isinstance(child, cst.ImportFrom):
237+
if child.module is None:
238+
continue
239+
module = self.get_full_dotted_name(child.module)
240+
for alias in child.names:
241+
if isinstance(alias, cst.ImportAlias):
242+
name = alias.name.value
243+
asname = alias.asname.name.value if alias.asname else name
244+
self.imports.add(f"{module}.{asname}")
245+
246+
def visit_Module(self, node: cst.Module) -> None:
247+
self.depth = 0
248+
self._collect_imports_from_block(node)
249+
250+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
251+
self.depth += 1
252+
253+
def leave_FunctionDef(self, node: cst.FunctionDef) -> None:
254+
self.depth -= 1
255+
256+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
257+
self.depth += 1
258+
259+
def leave_ClassDef(self, node: cst.ClassDef) -> None:
260+
self.depth -= 1
261+
262+
def visit_If(self, node: cst.If) -> None:
263+
if self.depth == 0:
264+
self._collect_imports_from_block(node.body)
265+
266+
def visit_Try(self, node: cst.Try) -> None:
267+
if self.depth == 0:
268+
self._collect_imports_from_block(node.body)
269+
270+
198271
class ImportInserter(cst.CSTTransformer):
199272
"""Transformer that inserts global statements after the last import."""
200273

@@ -329,38 +402,49 @@ def add_needed_imports_from_module(
329402
except Exception as e:
330403
logger.error(f"Error parsing source module code: {e}")
331404
return dst_module_code
405+
406+
dotted_import_collector = DottedImportCollector()
407+
try:
408+
parsed_dst_module = cst.parse_module(dst_module_code)
409+
parsed_dst_module.visit(dotted_import_collector)
410+
except cst.ParserSyntaxError as e:
411+
logger.exception(f"Syntax error in destination module code: {e}")
412+
return dst_module_code # Return the original code if there's a syntax error
413+
332414
try:
333415
for mod in gatherer.module_imports:
334-
AddImportsVisitor.add_needed_import(dst_context, mod)
416+
if mod not in dotted_import_collector.imports:
417+
AddImportsVisitor.add_needed_import(dst_context, mod)
335418
RemoveImportsVisitor.remove_unused_import(dst_context, mod)
336419
for mod, obj_seq in gatherer.object_mapping.items():
337420
for obj in obj_seq:
338421
if (
339422
f"{mod}.{obj}" in helper_functions_fqn or dst_context.full_module_name == mod # avoid circular deps
340423
):
341424
continue # Skip adding imports for helper functions already in the context
342-
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
425+
if f"{mod}.{obj}" not in dotted_import_collector.imports:
426+
AddImportsVisitor.add_needed_import(dst_context, mod, obj)
343427
RemoveImportsVisitor.remove_unused_import(dst_context, mod, obj)
344428
except Exception as e:
345429
logger.exception(f"Error adding imports to destination module code: {e}")
346430
return dst_module_code
431+
347432
for mod, asname in gatherer.module_aliases.items():
348-
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
433+
if f"{mod}.{asname}" not in dotted_import_collector.imports:
434+
AddImportsVisitor.add_needed_import(dst_context, mod, asname=asname)
349435
RemoveImportsVisitor.remove_unused_import(dst_context, mod, asname=asname)
436+
350437
for mod, alias_pairs in gatherer.alias_mapping.items():
351438
for alias_pair in alias_pairs:
352439
if f"{mod}.{alias_pair[0]}" in helper_functions_fqn:
353440
continue
354-
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
441+
442+
if f"{mod}.{alias_pair[1]}" not in dotted_import_collector.imports:
443+
AddImportsVisitor.add_needed_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
355444
RemoveImportsVisitor.remove_unused_import(dst_context, mod, alias_pair[0], asname=alias_pair[1])
356445

357446
try:
358-
parsed_module = cst.parse_module(dst_module_code)
359-
except cst.ParserSyntaxError as e:
360-
logger.exception(f"Syntax error in destination module code: {e}")
361-
return dst_module_code # Return the original code if there's a syntax error
362-
try:
363-
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_module)
447+
transformed_module = AddImportsVisitor(dst_context).transform_module(parsed_dst_module)
364448
transformed_module = RemoveImportsVisitor(dst_context).transform_module(transformed_module)
365449
return transformed_module.code.lstrip("\n")
366450
except Exception as e:

codeflash/code_utils/coverage_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
3939
return full_name
4040

4141

42-
def generate_candidates(source_code_path: Path) -> list[str]:
42+
def generate_candidates(source_code_path: Path) -> set[str]:
4343
"""Generate all the possible candidates for coverage data based on the source code path."""
44-
candidates = [source_code_path.name]
44+
candidates = set()
45+
candidates.add(source_code_path.name)
4546
current_path = source_code_path.parent
4647

48+
last_added = source_code_path.name
4749
while current_path != current_path.parent:
48-
candidate_path = str(Path(current_path.name) / candidates[-1])
49-
candidates.append(candidate_path)
50+
candidate_path = str(Path(current_path.name) / last_added)
51+
candidates.add(candidate_path)
52+
last_added = candidate_path
5053
current_path = current_path.parent
5154

55+
candidates.add(str(source_code_path))
5256
return candidates
5357

5458

codeflash/code_utils/env_utils.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from pathlib import Path
88
from typing import Any, Optional
99

10-
from codeflash.cli_cmds.console import console, logger
10+
from codeflash.cli_cmds.console import logger
1111
from codeflash.code_utils.code_utils import exit_with_message
1212
from codeflash.code_utils.formatter import format_code
1313
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
14+
from codeflash.lsp.helpers import is_LSP_enabled
1415

1516

1617
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
@@ -34,11 +35,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
3435

3536
@lru_cache(maxsize=1)
3637
def get_codeflash_api_key() -> str:
37-
if console.quiet: # lsp
38-
# prefer shell config over env var in lsp mode
39-
api_key = read_api_key_from_shell_config()
40-
else:
41-
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
38+
# prefer shell config over env var in lsp mode
39+
api_key = (
40+
read_api_key_from_shell_config()
41+
if is_LSP_enabled()
42+
else os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
43+
)
4244

4345
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
4446
if not api_key:
@@ -125,11 +127,6 @@ def is_ci() -> bool:
125127
return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))
126128

127129

128-
@lru_cache(maxsize=1)
129-
def is_LSP_enabled() -> bool:
130-
return console.quiet
131-
132-
133130
def is_pr_draft() -> bool:
134131
"""Check if the PR is draft. in the github action context."""
135132
event = get_cached_gh_event_data()

0 commit comments

Comments
 (0)