diff --git a/pyproject.toml b/pyproject.toml index a6d5c80..eaf764b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,20 +5,18 @@ build-backend = "setuptools.build_meta" [project] name = "weco" -authors = [ - {name = "Weco AI Team", email = "contact@weco.ai"}, -] +authors = [{ name = "Weco AI Team", email = "contact@weco.ai" }] description = "Documentation for `weco`, a CLI for using Weco AI's code optimizer." readme = "README.md" -version = "0.2.17" -license = {text = "MIT"} +version = "0.2.18" +license = { text = "MIT" } requires-python = ">=3.8" dependencies = ["requests", "rich"] keywords = ["AI", "Code Optimization", "Code Generation"] classifiers = [ "Programming Language :: Python :: 3", "Operating System :: OS Independent", - "License :: OSI Approved :: MIT License" + "License :: OSI Approved :: MIT License", ] [project.scripts] diff --git a/weco/__init__.py b/weco/__init__.py index fc06bd0..a5595f0 100644 --- a/weco/__init__.py +++ b/weco/__init__.py @@ -1,7 +1,8 @@ import os +import importlib.metadata # DO NOT EDIT -__pkg_version__ = "0.2.17" +__pkg_version__ = importlib.metadata.version("weco") __api_version__ = "v1" __base_url__ = f"https://api.weco.ai/{__api_version__}" diff --git a/weco/api.py b/weco/api.py index 99646ce..5d709d3 100644 --- a/weco/api.py +++ b/weco/api.py @@ -1,14 +1,20 @@ -from typing import Dict, Any +from typing import Dict, Any, Optional import rich import requests from weco import __pkg_version__, __base_url__ import sys +from rich.console import Console def handle_api_error(e: requests.exceptions.HTTPError, console: rich.console.Console) -> None: """Extract and display error messages from API responses in a structured format.""" - console.print(f"[bold red]{e.response.json()['detail']}[/]") - sys.exit(1) + try: + detail = e.response.json()["detail"] + except (ValueError, KeyError): # Handle cases where response is not JSON or detail key is missing + detail = f"HTTP {e.response.status_code} Error: {e.response.text}" + console.print(f"[bold red]{detail}[/]") + # Avoid exiting here, let the caller decide if the error is fatal + # sys.exit(1) def start_optimization_session( @@ -28,25 +34,32 @@ def start_optimization_session( ) -> Dict[str, Any]: """Start the optimization session.""" with console.status("[bold green]Starting Optimization..."): - response = requests.post( - f"{__base_url__}/sessions", # Path is relative to base_url - json={ - "source_code": source_code, - "additional_instructions": additional_instructions, - "objective": {"evaluation_command": evaluation_command, "metric_name": metric_name, "maximize": maximize}, - "optimizer": { - "steps": steps, - "code_generator": code_generator_config, - "evaluator": evaluator_config, - "search_policy": search_policy_config, + try: + response = requests.post( + f"{__base_url__}/sessions", # Path is relative to base_url + json={ + "source_code": source_code, + "additional_instructions": additional_instructions, + "objective": {"evaluation_command": evaluation_command, "metric_name": metric_name, "maximize": maximize}, + "optimizer": { + "steps": steps, + "code_generator": code_generator_config, + "evaluator": evaluator_config, + "search_policy": search_policy_config, + }, + "metadata": {"client_name": "cli", "client_version": __pkg_version__, **api_keys}, }, - "metadata": {"client_name": "cli", "client_version": __pkg_version__, **api_keys}, - }, - headers=auth_headers, # Add headers - timeout=timeout, - ) - response.raise_for_status() - return response.json() + headers=auth_headers, # Add headers + timeout=timeout, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + handle_api_error(e, console) + sys.exit(1) # Exit if starting session fails + except requests.exceptions.RequestException as e: + console.print(f"[bold red]Network Error starting session: {e}[/]") + sys.exit(1) def evaluate_feedback_then_suggest_next_solution( @@ -58,29 +71,92 @@ def evaluate_feedback_then_suggest_next_solution( timeout: int = 800, ) -> Dict[str, Any]: """Evaluate the feedback and suggest the next solution.""" - response = requests.post( - f"{__base_url__}/sessions/{session_id}/suggest", # Path is relative to base_url - json={ - "execution_output": execution_output, - "additional_instructions": additional_instructions, - "metadata": {**api_keys}, - }, - headers=auth_headers, # Add headers - timeout=timeout, - ) - response.raise_for_status() - return response.json() + try: + response = requests.post( + f"{__base_url__}/sessions/{session_id}/suggest", # Path is relative to base_url + json={ + "execution_output": execution_output, + "additional_instructions": additional_instructions, + "metadata": {**api_keys}, + }, + headers=auth_headers, # Add headers + timeout=timeout, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + # Allow caller to handle suggest errors, maybe retry or terminate + handle_api_error(e, Console()) # Use default console if none passed + raise # Re-raise the exception + except requests.exceptions.RequestException as e: + print(f"Network Error during suggest: {e}") # Use print as console might not be available + raise # Re-raise the exception def get_optimization_session_status( session_id: str, include_history: bool = False, auth_headers: dict = {}, timeout: int = 800 ) -> Dict[str, Any]: """Get the current status of the optimization session.""" - response = requests.get( - f"{__base_url__}/sessions/{session_id}", # Path is relative to base_url - params={"include_history": include_history}, - headers=auth_headers, - timeout=timeout, - ) - response.raise_for_status() - return response.json() + try: + response = requests.get( + f"{__base_url__}/sessions/{session_id}", # Path is relative to base_url + params={"include_history": include_history}, + headers=auth_headers, + timeout=timeout, + ) + response.raise_for_status() + return response.json() + except requests.exceptions.HTTPError as e: + handle_api_error(e, Console()) # Use default console + raise # Re-raise + except requests.exceptions.RequestException as e: + print(f"Network Error getting status: {e}") + raise # Re-raise + + +def send_heartbeat( + session_id: str, + auth_headers: dict = {}, + timeout: int = 10, # Shorter timeout for non-critical heartbeat +) -> bool: + """Send a heartbeat signal to the backend.""" + try: + response = requests.put(f"{__base_url__}/sessions/{session_id}/heartbeat", headers=auth_headers, timeout=timeout) + response.raise_for_status() # Raises HTTPError for bad responses (4xx or 5xx) + return True + except requests.exceptions.HTTPError as e: + # Log non-critical errors like 409 Conflict (session not running) + if e.response.status_code == 409: + print(f"Heartbeat ignored: Session {session_id} is not running.", file=sys.stderr) + else: + print(f"Heartbeat failed for session {session_id}: HTTP {e.response.status_code}", file=sys.stderr) + # Don't exit, just report failure + return False + except requests.exceptions.RequestException as e: + # Network errors are also non-fatal for heartbeats + print(f"Heartbeat network error for session {session_id}: {e}", file=sys.stderr) + return False + + +def report_termination( + session_id: str, + status_update: str, + reason: str, + details: Optional[str] = None, + auth_headers: dict = {}, + timeout: int = 30, # Reasonably longer timeout for important termination message +) -> bool: + """Report the termination reason to the backend.""" + try: + response = requests.post( + f"{__base_url__}/sessions/{session_id}/terminate", + json={"status_update": status_update, "termination_reason": reason, "termination_details": details}, + headers=auth_headers, + timeout=timeout, + ) + response.raise_for_status() + return True + except requests.exceptions.RequestException as e: + # Log failure, but don't prevent CLI exit + print(f"Warning: Failed to report termination to backend for session {session_id}: {e}", file=sys.stderr) + return False diff --git a/weco/cli.py b/weco/cli.py index 87b40d7..e5ac583 100644 --- a/weco/cli.py +++ b/weco/cli.py @@ -5,6 +5,9 @@ import time import requests import webbrowser +import threading +import signal +import traceback from rich.console import Console from rich.live import Live from rich.panel import Panel @@ -15,6 +18,8 @@ evaluate_feedback_then_suggest_next_solution, get_optimization_session_status, handle_api_error, + send_heartbeat, + report_termination, ) from . import __base_url__ @@ -42,6 +47,67 @@ install(show_locals=True) console = Console() +# --- Global variable for heartbeat thread --- +heartbeat_thread = None +stop_heartbeat_event = threading.Event() +current_session_id_for_heartbeat = None +current_auth_headers_for_heartbeat = {} + + +# --- Heartbeat Sender Class --- +class HeartbeatSender(threading.Thread): + def __init__(self, session_id: str, auth_headers: dict, stop_event: threading.Event, interval: int = 30): + super().__init__(daemon=True) # Daemon thread exits when main thread exits + self.session_id = session_id + self.auth_headers = auth_headers + self.interval = interval + self.stop_event = stop_event + + def run(self): + try: + while not self.stop_event.is_set(): + if not send_heartbeat(self.session_id, self.auth_headers): + # send_heartbeat itself prints errors to stderr if it returns False + # No explicit HeartbeatSender log needed here unless more detail is desired for a False return + pass # Continue trying as per original logic + + if self.stop_event.is_set(): # Check before waiting for responsiveness + break + + self.stop_event.wait(self.interval) # Wait for interval or stop signal + + except Exception as e: + # Catch any unexpected error in the loop to prevent silent thread death + print( + f"[ERROR HeartbeatSender] Unhandled exception in run loop for session {self.session_id}: {e}", file=sys.stderr + ) + traceback.print_exc(file=sys.stderr) + # The loop will break due to the exception, and thread will terminate via finally. + + +# --- Signal Handling --- +def signal_handler(signum, frame): + signal_name = signal.Signals(signum).name + console.print(f"\n[bold yellow]Termination signal ({signal_name}) received. Shutting down...[/]") + + # Stop heartbeat thread + stop_heartbeat_event.set() + if heartbeat_thread and heartbeat_thread.is_alive(): + heartbeat_thread.join(timeout=2) # Give it a moment to stop + + # Report termination (best effort) + if current_session_id_for_heartbeat: + report_termination( + session_id=current_session_id_for_heartbeat, + status_update="terminated", + reason=f"user_terminated_{signal_name.lower()}", + details=f"Process terminated by signal {signal_name} ({signum}).", + auth_headers=current_auth_headers_for_heartbeat, + ) + + # Exit gracefully + sys.exit(0) + def perform_login(console: Console): """Handles the device login flow.""" @@ -161,6 +227,10 @@ def perform_login(console: Console): def main() -> None: """Main function for the Weco CLI.""" + # Setup signal handlers + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + # --- Argument Parsing --- parser = argparse.ArgumentParser( description="[bold cyan]Weco CLI[/]", formatter_class=argparse.RawDescriptionHelpFormatter @@ -207,6 +277,10 @@ def main() -> None: # --- Handle Run Command --- elif args.command == "run": + global heartbeat_thread, current_session_id_for_heartbeat, current_auth_headers_for_heartbeat # Allow modification of globals + + session_id = None # Initialize session_id + optimization_completed_normally = False # Flag for finally block # --- Check Authentication --- weco_api_key = load_weco_api_key() llm_api_keys = read_api_keys_from_env() # Read keys from client environment @@ -238,10 +312,9 @@ def main() -> None: # --- Prepare API Call Arguments --- auth_headers = {} - if weco_api_key: auth_headers["Authorization"] = f"Bearer {weco_api_key}" - # Backend will decide whether to use client keys based on auth status + current_auth_headers_for_heartbeat = auth_headers # Store for signal handler # --- Main Run Logic --- try: @@ -289,16 +362,22 @@ def main() -> None: evaluator_config=evaluator_config, search_policy_config=search_policy_config, additional_instructions=additional_instructions, - api_keys=llm_api_keys, # Pass client LLM keys - auth_headers=auth_headers, # Pass Weco key if logged in + api_keys=llm_api_keys, + auth_headers=auth_headers, timeout=timeout, ) + session_id = session_response["session_id"] + current_session_id_for_heartbeat = session_id # Store for signal handler/finally + + # --- Start Heartbeat Thread --- + stop_heartbeat_event.clear() # Ensure event is clear before starting + heartbeat_thread = HeartbeatSender(session_id, auth_headers, stop_heartbeat_event) + heartbeat_thread.start() # --- Live Update Loop --- refresh_rate = 4 with Live(layout, refresh_per_second=refresh_rate, screen=True) as live: # Define the runs directory (.runs/) - session_id = session_response["session_id"] runs_dir = pathlib.Path(args.log_dir) / session_id runs_dir.mkdir(parents=True, exist_ok=True) @@ -358,6 +437,9 @@ def main() -> None: transition_delay=0.1, ) + # # Send initial heartbeat immediately after starting + # send_heartbeat(session_id, auth_headers) + # Run evaluation on the initial solution term_out = run_evaluation(eval_command=args.eval_command) @@ -386,6 +468,7 @@ def main() -> None: auth_headers=auth_headers, # Pass Weco key if logged in timeout=timeout, ) + # Save next solution (.runs//step_.) write_to_path( fp=runs_dir / f"step_{step}{source_fp.suffix}", content=eval_and_next_solution_response["code"] @@ -476,7 +559,7 @@ def main() -> None: additional_instructions=args.additional_instructions ) - # Ensure we pass evaluation results for the last step's generated solution + # Final evaluation report eval_and_next_solution_response = evaluate_feedback_then_suggest_next_solution( session_id=session_id, execution_output=term_out, @@ -555,14 +638,74 @@ def main() -> None: # write the best solution to the source file write_to_path(fp=source_fp, content=best_solution_content) + # Mark as completed normally for the finally block + optimization_completed_normally = True + console.print(end_optimization_layout) except Exception as e: + # Catch errors during the main optimization loop or setup try: - error_message = e.response.json()["detail"] + error_message = e.response.json()["detail"] # Try to get API error detail except Exception: - error_message = str(e) - console.print(Panel(f"[bold red]Error: {error_message}", title="[bold red]Error", border_style="red")) - # Print traceback for debugging + error_message = str(e) # Otherwise, use the exception string + console.print(Panel(f"[bold red]Error: {error_message}", title="[bold red]Optimization Error", border_style="red")) + # Print traceback for debugging if needed (can be noisy) # console.print_exception(show_locals=False) - sys.exit(1) + + # Ensure optimization_completed_normally is False + optimization_completed_normally = False + + # Prepare details for termination report + error_details = traceback.format_exc() + + # Exit code will be handled by finally block or sys.exit below + exit_code = 1 # Indicate error + # No sys.exit here, let finally block run + + finally: + # This block runs whether the try block completed normally or raised an exception + + # Stop heartbeat thread + stop_heartbeat_event.set() + if heartbeat_thread and heartbeat_thread.is_alive(): + heartbeat_thread.join(timeout=2) # Give it a moment to stop + + # Report final status if a session was started + if session_id: + final_status = "unknown" + final_reason = "unknown_termination" + final_details = None + + if optimization_completed_normally: + final_status = "completed" + final_reason = "completed_successfully" + else: + # If an exception was caught and we have details + if "error_details" in locals(): + final_status = "error" + final_reason = "error_cli_internal" + final_details = error_details + # else: # Should have been handled by signal handler if terminated by user + # Keep default 'unknown' if we somehow end up here without error/completion/signal + + # Avoid reporting if terminated by signal handler (already reported) + # Check a flag or rely on status not being 'unknown' + if final_status != "unknown": + report_termination( + session_id=session_id, + status_update=final_status, + reason=final_reason, + details=final_details, + auth_headers=auth_headers, + ) + + # Ensure proper exit code if an error occurred + if not optimization_completed_normally and "exit_code" in locals() and exit_code != 0: + sys.exit(exit_code) + elif not optimization_completed_normally: + # Generic error exit if no specific code was set but try block failed + sys.exit(1) + else: + # Normal exit + sys.exit(0) diff --git a/weco/panels.py b/weco/panels.py index f086aee..7d50d34 100644 --- a/weco/panels.py +++ b/weco/panels.py @@ -253,7 +253,7 @@ def get_display(self, is_done: bool) -> Panel: # Make sure the metric tree is built before calling build_rich_tree return Panel( self._build_rich_tree(), - title="[bold]🔎 Exploring Solutions..." if not is_done else "[bold]🔎 Optimization Complete!", + title=("[bold]🔎 Exploring Solutions..." if not is_done else "[bold]🔎 Optimization Complete!"), border_style="green", expand=True, padding=(0, 1),