Skip to content

Commit aa33d77

Browse files
authored
Merge branch 'main' into benchmark-fixture
2 parents 5220b5e + 62a4575 commit aa33d77

20 files changed

+1657
-200
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ https://github.com/user-attachments/assets/38f44f4e-be1c-4f84-8db9-63d5ee3e61e5
6565

6666
Join our community for support and discussions. If you have any questions, feel free to reach out to us using one of the following methods:
6767

68+
- [Free live Installation Support](https://calendly.com/codeflash-saurabh/codeflash-setup)
6869
- [Join our Discord](https://www.codeflash.ai/discord)
6970
- [Follow us on Twitter](https://x.com/codeflashAI)
7071
- [Follow us on LinkedIn](https://www.linkedin.com/in/saurabh-misra/)
71-
- [Email founders](mailto:[email protected])
7272

7373
## License
7474

codeflash/LICENSE

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ Business Source License 1.1
33
Parameters
44

55
Licensor: CodeFlash Inc.
6-
Licensed Work: Codeflash Client version 0.13.x
6+
Licensed Work: Codeflash Client version 0.14.x
77
The Licensed Work is (c) 2024 CodeFlash Inc.
88

99
Additional Use Grant: None. Production use of the Licensed Work is only permitted
@@ -13,7 +13,7 @@ Additional Use Grant: None. Production use of the Licensed Work is only permitte
1313
Platform. Please visit codeflash.ai for further
1414
information.
1515

16-
Change Date: 2029-06-03
16+
Change Date: 2029-06-09
1717

1818
Change License: MIT
1919

codeflash/api/aiservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def optimize_python_code( # noqa: D417
118118

119119
if response.status_code == 200:
120120
optimizations_json = response.json()["optimizations"]
121-
logger.info(f"Generated {len(optimizations_json)} candidates.")
121+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
122122
console.rule()
123123
end_time = time.perf_counter()
124124
logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.")
@@ -189,7 +189,7 @@ def optimize_python_code_line_profiler( # noqa: D417
189189

190190
if response.status_code == 200:
191191
optimizations_json = response.json()["optimizations"]
192-
logger.info(f"Generated {len(optimizations_json)} candidates.")
192+
logger.info(f"Generated {len(optimizations_json)} candidate optimizations.")
193193
console.rule()
194194
return [
195195
OptimizedCandidate(

codeflash/api/cfapi.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pathlib import Path
88
from typing import TYPE_CHECKING, Any, Optional
99

10+
import git
1011
import requests
1112
import sentry_sdk
1213
from pydantic.json import pydantic_encoder
@@ -191,3 +192,35 @@ def get_blocklisted_functions() -> dict[str, set[str]] | dict[str, Any]:
191192
return {}
192193

193194
return {Path(k).name: {v.replace("()", "") for v in values} for k, values in content.items()}
195+
196+
197+
def is_function_being_optimized_again(
    owner: str, repo: str, pr_number: int, code_contexts: list[dict[str, str]]
) -> Any:  # noqa: ANN401
    """Ask the CF API whether the given code contexts were already optimized.

    Sends a POST request to ``/is-already-optimized`` carrying the repository
    owner, repository name, pull request number and the code contexts.

    Returns
    -------
    The decoded JSON body of the response.

    Raises
    ------
    Whatever ``response.raise_for_status()`` raises for an error status.

    """
    payload = {"owner": owner, "repo": repo, "pr_number": pr_number, "code_contexts": code_contexts}
    response = make_cfapi_request("/is-already-optimized", "POST", payload)
    # Surface HTTP errors to the caller instead of decoding an error body.
    response.raise_for_status()
    return response.json()
208+
209+
210+
def add_code_context_hash(code_context_hash: str) -> None:
    """Add code context to the DB cache.

    Posts the hash to ``/add-code-hash`` together with the repository owner,
    repository name and pull request number.

    The call is a silent no-op when:
    - ``get_pr_number()`` returns ``None``,
    - resolving the repository raises ``git.exc.InvalidGitRepositoryError``, or
    - the resolved owner or repository name is empty.
    """
    pr_number = get_pr_number()
    if pr_number is None:
        return
    try:
        # Only the repository lookup is guarded; the PR number was already
        # fetched and checked above, so it is not queried a second time.
        owner, repo = get_repo_owner_and_name()
    except git.exc.InvalidGitRepositoryError:
        return

    if owner and repo:
        make_cfapi_request(
            "/add-code-hash",
            "POST",
            {"owner": owner, "repo": repo, "pr_number": pr_number, "code_hash": code_context_hash},
        )

codeflash/cli_cmds/console.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,34 @@ def code_print(code_str: str) -> None:
6666

6767

6868
@contextmanager
69-
def progress_bar(message: str, *, transient: bool = False) -> Generator[TaskID, None, None]:
70-
"""Display a progress bar with a spinner and elapsed time."""
71-
progress = Progress(
72-
SpinnerColumn(next(spinners)),
73-
*Progress.get_default_columns(),
74-
TimeElapsedColumn(),
75-
console=console,
76-
transient=transient,
77-
)
78-
task = progress.add_task(message, total=None)
79-
with progress:
80-
yield task
69+
def progress_bar(
    message: str, *, transient: bool = False, revert_to_print: bool = False
) -> Generator[TaskID, None, None]:
    """Display a progress bar with a spinner and elapsed time.

    If revert_to_print is True, falls back to printing a single logger.info message
    instead of showing a progress bar.

    Parameters
    ----------
    message:
        Text shown next to the spinner, or logged in the print fallback.
    transient:
        Forwarded to ``Progress(transient=...)``.
    revert_to_print:
        When True, log ``message`` once and yield a placeholder task id.

    Yields
    ------
    The task id returned by ``progress.add_task`` or, in the print fallback,
    the placeholder value ``0``.

    """
    if revert_to_print:
        logger.info(message)
        # No real task exists here, but callers still receive a value from the
        # context manager, so hand back a fixed placeholder id.
        yield 0
        return

    progress = Progress(
        SpinnerColumn(next(spinners)),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
        console=console,
        transient=transient,
    )
    # total=None: the amount of work is not known up front.
    task = progress.add_task(message, total=None)
    with progress:
        yield task
8197

8298

8399
@contextmanager

codeflash/code_utils/config_consts.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
1010
COVERAGE_THRESHOLD = 60.0
1111
MIN_TESTCASE_PASSED_THRESHOLD = 6
12+
REPEAT_OPTIMIZATION_PROBABILITY = 0.1

codeflash/code_utils/git_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import sys
66
import tempfile
77
import time
8+
from functools import cache
89
from io import StringIO
910
from pathlib import Path
1011
from typing import TYPE_CHECKING
@@ -79,6 +80,7 @@ def get_git_remotes(repo: Repo) -> list[str]:
7980
return [remote.name for remote in repository.remotes]
8081

8182

83+
@cache
8284
def get_repo_owner_and_name(repo: Repo | None = None, git_remote: str | None = "origin") -> tuple[str, str]:
8385
remote_url = get_remote_url(repo, git_remote) # call only once
8486
remote_url = remote_url.removesuffix(".git") if remote_url.endswith(".git") else remote_url

codeflash/context/code_context_extractor.py

Lines changed: 116 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3+
import ast
4+
import hashlib
35
import os
46
from collections import defaultdict
57
from itertools import chain
6-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, cast
79

810
import libcst as cst
911

@@ -31,8 +33,8 @@
3133
def get_code_optimization_context(
3234
function_to_optimize: FunctionToOptimize,
3335
project_root_path: Path,
34-
optim_token_limit: int = 8000,
35-
testgen_token_limit: int = 8000,
36+
optim_token_limit: int = 16000,
37+
testgen_token_limit: int = 16000,
3638
) -> CodeOptimizationContext:
3739
# Get FunctionSource representation of helpers of FTO
3840
helpers_of_fto_dict, helpers_of_fto_list = get_function_sources_from_jedi(
@@ -73,6 +75,13 @@ def get_code_optimization_context(
7375
remove_docstrings=False,
7476
code_context_type=CodeContextType.READ_ONLY,
7577
)
78+
hashing_code_context = extract_code_markdown_context_from_files(
79+
helpers_of_fto_dict,
80+
helpers_of_helpers_dict,
81+
project_root_path,
82+
remove_docstrings=True,
83+
code_context_type=CodeContextType.HASHING,
84+
)
7685

7786
# Handle token limits
7887
final_read_writable_tokens = encoded_tokens_len(final_read_writable_code)
@@ -125,11 +134,15 @@ def get_code_optimization_context(
125134
testgen_context_code_tokens = encoded_tokens_len(testgen_context_code)
126135
if testgen_context_code_tokens > testgen_token_limit:
127136
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
137+
code_hash_context = hashing_code_context.markdown
138+
code_hash = hashlib.sha256(code_hash_context.encode("utf-8")).hexdigest()
128139

129140
return CodeOptimizationContext(
130141
testgen_context_code=testgen_context_code,
131142
read_writable_code=final_read_writable_code,
132143
read_only_context_code=read_only_context_code,
144+
hashing_code_context=code_hash_context,
145+
hashing_code_context_hash=code_hash,
133146
helper_functions=helpers_of_fto_list,
134147
preexisting_objects=preexisting_objects,
135148
)
@@ -309,8 +322,8 @@ def extract_code_markdown_context_from_files(
309322
logger.debug(f"Error while getting read-only code: {e}")
310323
continue
311324
if code_context.strip():
312-
code_context_with_imports = CodeString(
313-
code=add_needed_imports_from_module(
325+
if code_context_type != CodeContextType.HASHING:
326+
code_context = add_needed_imports_from_module(
314327
src_module_code=original_code,
315328
dst_module_code=code_context,
316329
src_path=file_path,
@@ -319,10 +332,9 @@ def extract_code_markdown_context_from_files(
319332
helper_functions=list(
320333
helpers_of_fto.get(file_path, set()) | helpers_of_helpers.get(file_path, set())
321334
),
322-
),
323-
file_path=file_path.relative_to(project_root_path),
324-
)
325-
code_context_markdown.code_strings.append(code_context_with_imports)
335+
)
336+
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
337+
code_context_markdown.code_strings.append(code_string_context)
326338
# Extract code from file paths containing helpers of helpers
327339
for file_path, helper_function_sources in helpers_of_helpers_no_overlap.items():
328340
try:
@@ -343,18 +355,17 @@ def extract_code_markdown_context_from_files(
343355
continue
344356

345357
if code_context.strip():
346-
code_context_with_imports = CodeString(
347-
code=add_needed_imports_from_module(
358+
if code_context_type != CodeContextType.HASHING:
359+
code_context = add_needed_imports_from_module(
348360
src_module_code=original_code,
349361
dst_module_code=code_context,
350362
src_path=file_path,
351363
dst_path=file_path,
352364
project_root=project_root_path,
353365
helper_functions=list(helpers_of_helpers_no_overlap.get(file_path, set())),
354-
),
355-
file_path=file_path.relative_to(project_root_path),
356-
)
357-
code_context_markdown.code_strings.append(code_context_with_imports)
366+
)
367+
code_string_context = CodeString(code=code_context, file_path=file_path.relative_to(project_root_path))
368+
code_context_markdown.code_strings.append(code_string_context)
358369
return code_context_markdown
359370

360371

@@ -492,13 +503,18 @@ def parse_code_and_prune_cst(
492503
filtered_node, found_target = prune_cst_for_testgen_code(
493504
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
494505
)
506+
elif code_context_type == CodeContextType.HASHING:
507+
filtered_node, found_target = prune_cst_for_code_hashing(module, target_functions)
495508
else:
496509
raise ValueError(f"Unknown code_context_type: {code_context_type}") # noqa: EM102
497510

498511
if not found_target:
499512
raise ValueError("No target functions found in the provided code")
500513
if filtered_node and isinstance(filtered_node, cst.Module):
501-
return str(filtered_node.code)
514+
code = str(filtered_node.code)
515+
if code_context_type == CodeContextType.HASHING:
516+
code = ast.unparse(ast.parse(code)) # Makes it standard
517+
return code
502518
return ""
503519

504520

@@ -583,6 +599,90 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
583599
return (node.with_changes(**updates) if updates else node), True
584600

585601

602+
def prune_cst_for_code_hashing(  # noqa: PLR0911
    node: cst.CSTNode, target_functions: set[str], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
    """Recursively filter the node and its children to build the code used for hashing.

    Import statements are dropped, target functions are kept with their
    docstrings removed, and a top-level class is reduced to just its target
    methods. Nested classes and non-target functions are removed.

    Parameters
    ----------
    node:
        The CST node to filter.
    target_functions:
        Qualified names (``func`` or ``Class.method``) of the functions to keep.
    prefix:
        Qualified-name prefix of the enclosing scope; empty at module level.

    Returns
    -------
    (filtered_node, found_target):
        filtered_node: The modified CST node or None if it should be removed.
        found_target: True if a target function was found in this node's subtree.

    Raises
    ------
    ValueError
        If a class body is not a ``cst.IndentedBlock``.

    """
    if isinstance(node, (cst.Import, cst.ImportFrom)):
        return None, False

    if isinstance(node, cst.FunctionDef):
        qualified_name = f"{prefix}.{node.name.value}" if prefix else node.name.value
        if qualified_name in target_functions:
            return _function_without_docstring(node), True
        return None, False

    if isinstance(node, cst.ClassDef):
        # Do not recurse into nested classes
        if prefix:
            return None, False
        # Assuming always an IndentedBlock
        if not isinstance(node.body, cst.IndentedBlock):
            raise ValueError("ClassDef body is not an IndentedBlock")  # noqa: TRY004
        # prefix is empty here (checked above), so the class name is the whole prefix.
        class_prefix = node.name.value
        # Only direct methods that are targets survive; every other class member is dropped.
        new_class_body: list[cst.CSTNode] = [
            _function_without_docstring(stmt)
            for stmt in node.body.body
            if isinstance(stmt, cst.FunctionDef) and f"{class_prefix}.{stmt.name.value}" in target_functions
        ]
        # If no target functions found, remove the class entirely
        if not new_class_body:
            return None, False
        return node.with_changes(body=cst.IndentedBlock(cast("list[cst.BaseStatement]", new_class_body))), True

    # For other nodes, recurse into their sections; a node without sections is returned unchanged.
    section_names = get_section_names(node)
    if not section_names:
        return node, False

    updates: dict[str, list[cst.CSTNode] | cst.CSTNode] = {}
    found_any_target = False

    for section in section_names:
        original_content = getattr(node, section, None)
        if isinstance(original_content, (list, tuple)):
            new_children = []
            section_found_target = False
            for child in original_content:
                filtered, found_target = prune_cst_for_code_hashing(child, target_functions, prefix)
                if filtered:
                    new_children.append(filtered)
                section_found_target |= found_target

            # A section is only rewritten when a target was found somewhere inside it.
            if section_found_target:
                found_any_target = True
                updates[section] = new_children
        elif original_content is not None:
            filtered, found_target = prune_cst_for_code_hashing(original_content, target_functions, prefix)
            if found_target:
                found_any_target = True
                if filtered:
                    updates[section] = filtered

    if not found_any_target:
        return None, False

    return (node.with_changes(**updates) if updates else node), True


def _function_without_docstring(func: cst.FunctionDef) -> cst.FunctionDef:
    """Return ``func`` with its docstring removed when its body is an indented block."""
    # Any other body type is left as it is rather than passed to the docstring remover.
    body = remove_docstring_from_body(func.body) if isinstance(func.body, cst.IndentedBlock) else func.body
    return func.with_changes(body=body)
684+
685+
586686
def prune_cst_for_read_only_code( # noqa: PLR0911
587687
node: cst.CSTNode,
588688
target_functions: set[str],

0 commit comments

Comments
 (0)