codeflash/lsp/beta.py (81 changes: 72 additions & 9 deletions)
@@ -36,27 +36,71 @@ def get_optimizable_functions(
     server: CodeflashLanguageServer, params: OptimizableFunctionsParams
 ) -> dict[str, list[str]]:
     file_path = Path(uris.to_fs_path(params.textDocument.uri))
-    server.optimizer.args.file = file_path
-    server.optimizer.args.previous_checkpoint_functions = False
-    optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
-    path_to_qualified_names = {}
-    for path, functions in optimizable_funcs.items():
-        path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions]
-    return path_to_qualified_names
server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info")

# Save original args to restore later
original_file = getattr(server.optimizer.args, "file", None)
original_function = getattr(server.optimizer.args, "function", None)
original_checkpoint = getattr(server.optimizer.args, "previous_checkpoint_functions", None)

server.show_message_log(f"Original args - file: {original_file}, function: {original_function}", "Info")

try:
# Set temporary args for this request only
server.optimizer.args.file = file_path
server.optimizer.args.function = None # Always get ALL functions, not just one
server.optimizer.args.previous_checkpoint_functions = False

server.show_message_log("Calling get_optimizable_functions...", "Info")
optimizable_funcs, _ = server.optimizer.get_optimizable_functions()

path_to_qualified_names = {}
for path, functions in optimizable_funcs.items():
path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions]

server.show_message_log(
f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info"
)
return path_to_qualified_names
finally:
# Restore original args to prevent state corruption
if original_file is not None:
server.optimizer.args.file = original_file
if original_function is not None:
server.optimizer.args.function = original_function
else:
server.optimizer.args.function = None
if original_checkpoint is not None:
server.optimizer.args.previous_checkpoint_functions = original_checkpoint

server.show_message_log(
f"Restored args - file: {server.optimizer.args.file}, function: {server.optimizer.args.function}", "Info"
)
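The handler above saves and restores `file`, `function`, and `previous_checkpoint_functions` by hand in a try/finally. A minimal sketch of the same idea as a reusable context manager is shown below; the `temporary_args` helper is hypothetical (not part of this PR), and unlike the hand-written restore it puts back every saved value, including `None`.

```python
from contextlib import contextmanager


@contextmanager
def temporary_args(args, **overrides):
    """Hypothetical helper: temporarily override attributes on an args namespace."""
    saved = {name: getattr(args, name, None) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(args, name, value)
        yield args
    finally:
        # Restore every saved value, even if it was None.
        for name, value in saved.items():
            setattr(args, name, value)


# Sketch of how the handler body could use it:
# with temporary_args(server.optimizer.args, file=file_path, function=None,
#                     previous_checkpoint_functions=False):
#     optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
```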


@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(
server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")

# IMPORTANT: Store the specific function for optimization, but don't corrupt global state
server.optimizer.args.function = params.functionName
server.optimizer.args.file = file_path

server.show_message_log(
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)

optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
if not optimizable_funcs:
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
return {"functionName": params.functionName, "status": "not found", "args": None}

fto = optimizable_funcs.popitem()[1][0]
server.optimizer.current_function_being_optimized = fto
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
return {"functionName": params.functionName, "status": "success"}


@@ -136,11 +180,20 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization


@server.feature("performFunctionOptimization")
def perform_function_optimization(
def perform_function_optimization( # noqa: PLR0911
server: CodeflashLanguageServer, params: FunctionOptimizationParams
) -> dict[str, str]:
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
current_function = server.optimizer.current_function_being_optimized

if not current_function:
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
return {
"functionName": params.functionName,
"status": "error",
"message": "No function currently being optimized",
}

module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)

validated_original_code, original_module_ast = module_prep_result
@@ -214,19 +267,29 @@ def perform_function_optimization(
     )

     if not best_optimization:
+        server.show_message_log(
+            f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning"
+        )
         return {
             "functionName": params.functionName,
             "status": "error",
             "message": f"No best optimizations found for function {function_to_optimize_qualified_name}",
         }

     optimized_source = best_optimization.candidate.source_code
+    speedup = original_code_baseline.runtime / best_optimization.runtime
+
+    server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
+
+    # CRITICAL: Clear the function filter after optimization to prevent state corruption
+    server.optimizer.args.function = None
+    server.show_message_log("Cleared function filter to prevent state corruption", "Info")

     return {
         "functionName": params.functionName,
         "status": "success",
         "message": "Optimization completed successfully",
-        "extra": f"Speedup: {original_code_baseline.runtime / best_optimization.runtime:.2f}x faster",
+        "extra": f"Speedup: {speedup:.2f}x faster",
         "optimization": optimized_source,
     }
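Taken together, the handlers form a small two-step protocol: `initializeFunctionOptimization` selects a function and stores it in `server.optimizer.current_function_being_optimized`, and `performFunctionOptimization` then reads that state (returning an error if it is missing, as added above). A rough client-side sketch of the sequence, assuming a hypothetical `send_request(method, params)` helper for issuing custom LSP requests; only the two method names shown in this diff are used:

```python
from typing import Any, Callable


def optimize_function(
    send_request: Callable[[str, dict], dict[str, Any]], file_uri: str, function_name: str
) -> dict[str, Any]:
    # Step 1: select the function; the server stores it as current_function_being_optimized.
    params = {"textDocument": {"uri": file_uri}, "functionName": function_name}
    init = send_request("initializeFunctionOptimization", params)
    if init.get("status") != "success":
        return init  # e.g. {"status": "not found", ...}

    # Step 2: run the optimization; the server reads the function stored in step 1
    # and reports the speedup plus the optimized source on success.
    return send_request("performFunctionOptimization", params)
```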
