Skip to content

Commit 283524d

Browse files
Merge branch 'main' into lsp/move-config-suggestion-default-values-to-the-client
2 parents 8a0a866 + 848faa5 commit 283524d

File tree

14 files changed

+1490
-860
lines changed

14 files changed

+1490
-860
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -249,25 +249,29 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path)
249249

250250

251251
def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
252-
if hasattr(args, "all"):
253-
import git
254-
255-
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
256-
from codeflash.code_utils.github_utils import require_github_app_or_exit
257-
258-
# Ensure that the user can actually open PRs on the repo.
259-
try:
260-
git_repo = git.Repo(search_parent_directories=True)
261-
except git.exc.InvalidGitRepositoryError:
262-
logger.exception(
263-
"I couldn't find a git repository in the current directory. "
264-
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
265-
)
266-
apologize_and_exit()
267-
if not args.no_pr and not check_and_push_branch(git_repo, git_remote=args.git_remote):
268-
exit_with_message("Branch is not pushed...", error_on_exit=True)
269-
owner, repo = get_repo_owner_and_name(git_repo)
270-
if not args.no_pr:
252+
if hasattr(args, "all") or (hasattr(args, "file") and args.file):
253+
no_pr = getattr(args, "no_pr", False)
254+
255+
if not no_pr:
256+
import git
257+
258+
from codeflash.code_utils.git_utils import check_and_push_branch, get_repo_owner_and_name
259+
from codeflash.code_utils.github_utils import require_github_app_or_exit
260+
261+
# Ensure that the user can actually open PRs on the repo.
262+
try:
263+
git_repo = git.Repo(search_parent_directories=True)
264+
except git.exc.InvalidGitRepositoryError:
265+
mode = "--all" if hasattr(args, "all") else "--file"
266+
logger.exception(
267+
f"I couldn't find a git repository in the current directory. "
268+
f"I need a git repository to run {mode} and open PRs for optimizations. Exiting..."
269+
)
270+
apologize_and_exit()
271+
git_remote = getattr(args, "git_remote", None)
272+
if not check_and_push_branch(git_repo, git_remote=git_remote):
273+
exit_with_message("Branch is not pushed...", error_on_exit=True)
274+
owner, repo = get_repo_owner_and_name(git_repo)
271275
require_github_app_or_exit(owner, repo)
272276
if not hasattr(args, "all"):
273277
args.all = None

codeflash/code_utils/config_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def parse_config_file(
105105
if lsp_mode:
106106
# don't fail in lsp mode if codeflash config is not found.
107107
return {}, config_file_path
108-
msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to create the config file."
108+
msg = f"Could not find the 'codeflash' block in the config file {config_file_path}. Please run 'codeflash init' to add Codeflash config in the pyproject.toml config file."
109109
raise ValueError(msg) from e
110110
assert isinstance(config, dict)
111111

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -684,27 +684,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
684684
)
685685

686686

687-
def instrument_source_module_with_async_decorators(
688-
source_path: Path, function_to_optimize: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
689-
) -> tuple[bool, str | None]:
690-
if not function_to_optimize.is_async:
691-
return False, None
692-
693-
try:
694-
with source_path.open(encoding="utf8") as f:
695-
source_code = f.read()
696-
697-
modified_code, decorator_added = add_async_decorator_to_function(source_code, function_to_optimize, mode)
698-
699-
if decorator_added:
700-
return True, modified_code
701-
702-
except Exception:
703-
return False, None
704-
else:
705-
return False, None
706-
707-
708687
def inject_async_profiling_into_existing_test(
709688
test_path: Path,
710689
call_positions: list[CodePosition],
@@ -1288,25 +1267,29 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
12881267

12891268

12901269
def add_async_decorator_to_function(
1291-
source_code: str, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1292-
) -> tuple[str, bool]:
1293-
"""Add async decorator to an async function definition.
1270+
source_path: Path, function: FunctionToOptimize, mode: TestingMode = TestingMode.BEHAVIOR
1271+
) -> bool:
1272+
"""Add async decorator to an async function definition and write back to file.
12941273
12951274
Args:
12961275
----
1297-
source_code: The source code to modify.
1276+
source_path: Path to the source file to modify in-place.
12981277
function: The FunctionToOptimize object representing the target async function.
12991278
mode: The testing mode to determine which decorator to apply.
13001279
13011280
Returns:
13021281
-------
1303-
Tuple of (modified_source_code, was_decorator_added).
1282+
Boolean indicating whether the decorator was successfully added.
13041283
13051284
"""
13061285
if not function.is_async:
1307-
return source_code, False
1286+
return False
13081287

13091288
try:
1289+
# Read source code
1290+
with source_path.open(encoding="utf8") as f:
1291+
source_code = f.read()
1292+
13101293
module = cst.parse_module(source_code)
13111294

13121295
# Add the decorator to the function
@@ -1318,10 +1301,17 @@ def add_async_decorator_to_function(
13181301
import_transformer = AsyncDecoratorImportAdder(mode)
13191302
module = module.visit(import_transformer)
13201303

1321-
return sort_imports(code=module.code, float_to_top=True), decorator_transformer.added_decorator
1304+
modified_code = sort_imports(code=module.code, float_to_top=True)
13221305
except Exception as e:
13231306
logger.exception(f"Error adding async decorator to function {function.qualified_name}: {e}")
1324-
return source_code, False
1307+
return False
1308+
else:
1309+
if decorator_transformer.added_decorator:
1310+
with source_path.open("w", encoding="utf8") as f:
1311+
f.write(modified_code)
1312+
logger.debug(f"Applied async {mode.value} instrumentation to {source_path}")
1313+
return True
1314+
return False
13251315

13261316

13271317
def create_instrumented_source_module_path(source_path: Path, temp_dir: Path) -> Path:

codeflash/context/unused_definition_remover.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -469,22 +469,32 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
469469
qualified_function_names: Set of function names to keep. For methods, use format 'classname.methodname'
470470
471471
"""
472-
module = cst.parse_module(code)
473-
# Collect all definitions (top level classes, variables or function)
474-
definitions = collect_top_level_definitions(module)
472+
try:
473+
module = cst.parse_module(code)
474+
except Exception as e:
475+
logger.debug(f"Failed to parse code with libcst: {type(e).__name__}: {e}")
476+
return code
475477

476-
# Collect dependencies between definitions using the visitor pattern
477-
dependency_collector = DependencyCollector(definitions)
478-
module.visit(dependency_collector)
478+
try:
479+
# Collect all definitions (top level classes, variables or function)
480+
definitions = collect_top_level_definitions(module)
479481

480-
# Mark definitions used by specified functions, and their dependencies recursively
481-
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
482-
usage_marker.mark_used_definitions()
482+
# Collect dependencies between definitions using the visitor pattern
483+
dependency_collector = DependencyCollector(definitions)
484+
module.visit(dependency_collector)
483485

484-
# Apply the recursive removal transformation
485-
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
486+
# Mark definitions used by specified functions, and their dependencies recursively
487+
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
488+
usage_marker.mark_used_definitions()
486489

487-
return modified_module.code if modified_module else ""
490+
# Apply the recursive removal transformation
491+
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
492+
493+
return modified_module.code if modified_module else "" # noqa: TRY300
494+
except Exception as e:
495+
# If any other error occurs during processing, return the original code
496+
logger.debug(f"Error processing code to remove unused definitions: {type(e).__name__}: {e}")
497+
return code
488498

489499

490500
def print_definitions(definitions: dict[str, UsageInfo]) -> None:

codeflash/discovery/functions_to_optimize.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def get_functions_to_optimize(
201201
elif file is not None:
202202
logger.info("!lsp|Finding all functions in the file '%s'…", file)
203203
console.rule()
204-
functions = find_all_functions_in_file(file)
204+
functions: dict[Path, list[FunctionToOptimize]] = find_all_functions_in_file(file)
205205
if only_get_this_function is not None:
206206
split_function = only_get_this_function.split(".")
207207
if len(split_function) > 2:
@@ -224,8 +224,16 @@ def get_functions_to_optimize(
224224
if found_function is None:
225225
if is_lsp:
226226
return functions, 0, None
227+
found = closest_matching_file_function_name(only_get_this_function, functions)
228+
if found is not None:
229+
file, found_function = found
230+
exit_with_message(
231+
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property.\n"
232+
f"Did you mean {found_function.qualified_name} instead?"
233+
)
234+
227235
exit_with_message(
228-
f"Function {only_function_name} not found in file {file}\nor the function does not have a 'return' statement or is a property"
236+
f"Function {only_get_this_function} not found in file {file}\nor the function does not have a 'return' statement or is a property"
229237
)
230238
functions[file] = [found_function]
231239
else:
@@ -259,6 +267,76 @@ def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[F
259267
return get_functions_within_lines(modified_lines)
260268

261269

270+
def closest_matching_file_function_name(
    qualified_fn_to_find: str, found_fns: dict[Path, list[FunctionToOptimize]]
) -> tuple[Path, FunctionToOptimize] | None:
    """Find the function whose qualified name is closest to the requested one.

    Qualified names are compared case-insensitively using Levenshtein distance.
    A candidate is accepted only when its distance is at most 3; among the
    accepted candidates, the first one (in iteration order of ``found_fns``)
    with the smallest distance wins.

    Args:
        qualified_fn_to_find: Function name to find in format "Class.function" or "function"
        found_fns: Dictionary of file paths to list of functions

    Returns:
        Tuple of (file_path, function) for closest match, or None if no matches found

    """
    # Exclusive upper bound on an acceptable distance; it shrinks every time a
    # better candidate is found, so later candidates must be strictly closer.
    min_distance = 4
    closest = None
    target = qualified_fn_to_find.lower()

    for file_path, functions in found_fns.items():
        for function in functions:
            # Only the lower-cased qualified name is compared.
            candidate = function.qualified_name.lower()
            # The edit distance is never smaller than the length difference, so a
            # candidate whose length differs this much cannot beat the current best.
            if abs(len(target) - len(candidate)) >= min_distance:
                continue
            dist = levenshtein_distance(target, candidate)
            if dist < min_distance:
                min_distance = dist
                closest = (file_path, function)

    return closest
309+
310+
311+
def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein (edit) distance between two strings.

    The distance is the minimum number of single-character insertions,
    deletions and substitutions needed to turn one string into the other.
    Only two rows of the dynamic-programming table are kept, each sized by
    the shorter of the two strings.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The edit distance between ``s1`` and ``s2``; 0 when they are equal.

    """
    # Make s1 the shorter string so the row buffers are as small as possible.
    if len(s1) > len(s2):
        s1, s2 = s2, s1
    if not s1:
        # Turning an empty string into s2 takes one insertion per character.
        return len(s2)

    # previous[j] holds the distance between the processed prefix of s2 and s1[:j].
    previous = list(range(len(s1) + 1))
    current = [0] * (len(s1) + 1)

    for row, char2 in enumerate(s2, start=1):
        current[0] = row
        for col, char1 in enumerate(s1, start=1):
            if char1 == char2:
                # Matching characters cost nothing: carry the diagonal value over.
                current[col] = previous[col - 1]
            else:
                # Cheapest of substitution, deletion and insertion, plus one edit.
                current[col] = 1 + min(previous[col - 1], previous[col], current[col - 1])
        # Swap the row buffers instead of copying; the old row is overwritten next pass.
        previous, current = current, previous
    return previous[-1]
338+
339+
262340
def get_functions_inside_a_commit(commit_hash: str) -> dict[str, list[FunctionToOptimize]]:
263341
modified_lines: dict[str, list[int]] = get_git_diff(only_this_commit=commit_hash)
264342
return get_functions_within_lines(modified_lines)

codeflash/github/PrComment.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ class PrComment:
2121
winning_behavior_test_results: TestResults
2222
winning_benchmarking_test_results: TestResults
2323
benchmark_details: Optional[list[BenchmarkDetail]] = None
24+
original_async_throughput: Optional[int] = None
25+
best_async_throughput: Optional[int] = None
2426

25-
def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Optional[list[BenchmarkDetail]]]]:
27+
def to_json(self) -> dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]]:
2628
report_table = {
2729
test_type.to_name(): result
2830
for test_type, result in self.winning_behavior_test_results.get_test_pass_fail_report_by_type().items()
2931
if test_type.to_name()
3032
}
3133

32-
return {
34+
result: dict[str, Union[str, int, dict[str, dict[str, int]], list[BenchmarkDetail], None]] = {
3335
"optimization_explanation": self.optimization_explanation,
3436
"best_runtime": humanize_runtime(self.best_runtime),
3537
"original_runtime": humanize_runtime(self.original_runtime),
@@ -42,6 +44,12 @@ def to_json(self) -> dict[str, Union[dict[str, dict[str, int]], int, str, Option
4244
"benchmark_details": self.benchmark_details if self.benchmark_details else None,
4345
}
4446

47+
if self.original_async_throughput is not None and self.best_async_throughput is not None:
48+
result["original_async_throughput"] = str(self.original_async_throughput)
49+
result["best_async_throughput"] = str(self.best_async_throughput)
50+
51+
return result
52+
4553

4654
class FileDiffContent(BaseModel):
4755
oldContent: str # noqa: N815

0 commit comments

Comments
 (0)