diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index aae2b7bee..63f811cf9 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -36,13 +36,46 @@ def get_optimizable_functions( server: CodeflashLanguageServer, params: OptimizableFunctionsParams ) -> dict[str, list[str]]: file_path = Path(uris.to_fs_path(params.textDocument.uri)) - server.optimizer.args.file = file_path - server.optimizer.args.previous_checkpoint_functions = False - optimizable_funcs, _ = server.optimizer.get_optimizable_functions() - path_to_qualified_names = {} - for path, functions in optimizable_funcs.items(): - path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions] - return path_to_qualified_names + server.show_message_log(f"Getting optimizable functions for: {file_path}", "Info") + + # Save original args to restore later + original_file = getattr(server.optimizer.args, "file", None) + original_function = getattr(server.optimizer.args, "function", None) + original_checkpoint = getattr(server.optimizer.args, "previous_checkpoint_functions", None) + + server.show_message_log(f"Original args - file: {original_file}, function: {original_function}", "Info") + + try: + # Set temporary args for this request only + server.optimizer.args.file = file_path + server.optimizer.args.function = None # Always get ALL functions, not just one + server.optimizer.args.previous_checkpoint_functions = False + + server.show_message_log("Calling get_optimizable_functions...", "Info") + optimizable_funcs, _ = server.optimizer.get_optimizable_functions() + + path_to_qualified_names = {} + for path, functions in optimizable_funcs.items(): + path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions] + + server.show_message_log( + f"Found {len(path_to_qualified_names)} files with functions: {path_to_qualified_names}", "Info" + ) + return path_to_qualified_names + finally: + # Restore original args to prevent state corruption + if original_file is not None: + server.optimizer.args.file = original_file + if original_function is not None: + server.optimizer.args.function = original_function + else: + server.optimizer.args.function = None + if original_checkpoint is not None: + server.optimizer.args.previous_checkpoint_functions = original_checkpoint + + server.show_message_log( + f"Restored args - file: {server.optimizer.args.file}, function: {server.optimizer.args.function}", "Info" + ) @server.feature("initializeFunctionOptimization") @@ -50,13 +83,24 @@ def initialize_function_optimization( server: CodeflashLanguageServer, params: FunctionOptimizationParams ) -> dict[str, str]: file_path = Path(uris.to_fs_path(params.textDocument.uri)) + server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info") + + # IMPORTANT: Store the specific function for optimization, but don't corrupt global state server.optimizer.args.function = params.functionName server.optimizer.args.file = file_path + + server.show_message_log( + f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info" + ) + optimizable_funcs, _ = server.optimizer.get_optimizable_functions() if not optimizable_funcs: + server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning") return {"functionName": params.functionName, "status": "not found", "args": None} + fto = optimizable_funcs.popitem()[1][0] server.optimizer.current_function_being_optimized = fto + server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info") return {"functionName": params.functionName, "status": "success"} @@ -136,11 +180,20 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization @server.feature("performFunctionOptimization") -def perform_function_optimization( +def perform_function_optimization( # noqa: PLR0911 server: CodeflashLanguageServer, params: FunctionOptimizationParams ) -> dict[str, str]: + server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") current_function = server.optimizer.current_function_being_optimized + if not current_function: + server.show_message_log(f"No current function being optimized for {params.functionName}", "Error") + return { + "functionName": params.functionName, + "status": "error", + "message": "No function currently being optimized", + } + module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) validated_original_code, original_module_ast = module_prep_result @@ -214,6 +267,9 @@ def perform_function_optimization( ) if not best_optimization: + server.show_message_log( + f"No best optimizations found for function {function_to_optimize_qualified_name}", "Warning" + ) return { "functionName": params.functionName, "status": "error", @@ -221,12 +277,19 @@ def perform_function_optimization( } optimized_source = best_optimization.candidate.source_code + speedup = original_code_baseline.runtime / best_optimization.runtime + + server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") + + # CRITICAL: Clear the function filter after optimization to prevent state corruption + server.optimizer.args.function = None + server.show_message_log("Cleared function filter to prevent state corruption", "Info") return { "functionName": params.functionName, "status": "success", "message": "Optimization completed successfully", - "extra": f"Speedup: {original_code_baseline.runtime / best_optimization.runtime:.2f}x faster", + "extra": f"Speedup: {speedup:.2f}x faster", "optimization": optimized_source, }