From ebe5cefda7f06b242ca7f8898cfcff05add4090d Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sat, 3 Jan 2026 00:26:38 +0200 Subject: [PATCH 01/14] feat: Add computer use tools with screenshot streaming and remote GUI control --- .infer/config.yaml | 28 ++ cmd/chat.go | 42 ++ cmd/root.go | 19 + config/config.go | 102 ++++ examples/computer-use/.env.example | 4 + examples/computer-use/.gitignore | 1 + examples/computer-use/.infer/.gitignore | 2 + examples/computer-use/.infer/config.yaml | 198 ++++++++ examples/computer-use/Dockerfile.ubuntu-gui | 103 +++++ examples/computer-use/README.md | 436 ++++++++++++++++++ examples/computer-use/docker-compose.yml | 122 +++++ go.mod | 4 + go.sum | 8 + internal/domain/events.go | 13 + internal/domain/interfaces.go | 64 +++ .../services/circular_screenshot_buffer.go | 194 ++++++++ internal/services/screenshot_server.go | 307 ++++++++++++ .../services/tools/computer_use_common.go | 123 +++++ .../services/tools/computer_use_wayland.go | 269 +++++++++++ internal/services/tools/computer_use_x11.go | 176 +++++++ .../services/tools/get_latest_screenshot.go | 179 +++++++ internal/services/tools/keyboard_type.go | 280 +++++++++++ internal/services/tools/mouse_click.go | 325 +++++++++++++ internal/services/tools/mouse_move.go | 247 ++++++++++ internal/services/tools/registry.go | 27 ++ internal/services/tools/screenshot.go | 424 +++++++++++++++++ internal/ui/autocomplete/autocomplete.go | 6 +- internal/web/pty_manager.go | 58 ++- internal/web/server.go | 176 +++++-- internal/web/session_manager.go | 71 ++- internal/web/ssh_session.go | 262 ++++++++++- internal/web/static/app.js | 66 +++ internal/web/static/screenshot-overlay.js | 227 +++++++++ internal/web/templates/index.html | 132 ++++++ 34 files changed, 4634 insertions(+), 61 deletions(-) create mode 100644 examples/computer-use/.env.example create mode 100644 examples/computer-use/.gitignore create mode 100644 examples/computer-use/.infer/.gitignore create mode 100644 
examples/computer-use/.infer/config.yaml create mode 100644 examples/computer-use/Dockerfile.ubuntu-gui create mode 100644 examples/computer-use/README.md create mode 100644 examples/computer-use/docker-compose.yml create mode 100644 internal/services/circular_screenshot_buffer.go create mode 100644 internal/services/screenshot_server.go create mode 100644 internal/services/tools/computer_use_common.go create mode 100644 internal/services/tools/computer_use_wayland.go create mode 100644 internal/services/tools/computer_use_x11.go create mode 100644 internal/services/tools/get_latest_screenshot.go create mode 100644 internal/services/tools/keyboard_type.go create mode 100644 internal/services/tools/mouse_click.go create mode 100644 internal/services/tools/mouse_move.go create mode 100644 internal/services/tools/screenshot.go create mode 100644 internal/web/static/screenshot-overlay.js diff --git a/.infer/config.yaml b/.infer/config.yaml index d24ab414..090a283b 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -638,3 +638,31 @@ web: auto_install: true install_version: latest servers: [] +computer_use: + enabled: false + display: :0 + screenshot: + enabled: true + max_width: 1920 + max_height: 1080 + format: jpeg + quality: 80 + require_approval: false + streaming_enabled: false + capture_interval: 3 + buffer_size: 30 + temp_dir: /tmp/infer-screenshots + mouse_move: + enabled: true + require_approval: true + mouse_click: + enabled: true + require_approval: true + keyboard_type: + enabled: true + max_text_length: 1000 + require_approval: true + rate_limit: + enabled: true + max_actions_per_minute: 60 + window_seconds: 60 diff --git a/cmd/chat.go b/cmd/chat.go index 8db4f3c1..24cbeafb 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -13,6 +13,7 @@ import ( "time" tea "github.com/charmbracelet/bubbletea" + uuid "github.com/google/uuid" cobra "github.com/spf13/cobra" viper "github.com/spf13/viper" @@ -22,6 +23,8 @@ import ( container 
"github.com/inference-gateway/cli/internal/container" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" + screenshotsvc "github.com/inference-gateway/cli/internal/services" + tools "github.com/inference-gateway/cli/internal/services/tools" web "github.com/inference-gateway/cli/internal/web" sdk "github.com/inference-gateway/sdk" ) @@ -155,6 +158,23 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { agentManager := services.GetAgentManager() conversationOptimizer := services.GetConversationOptimizer() + var screenshotServer *screenshotsvc.ScreenshotServer + logger.Info("Checking screenshot streaming config", + "computer_use_enabled", config.ComputerUse.Enabled, + "screenshot_enabled", config.ComputerUse.Screenshot.Enabled, + "streaming_enabled", config.ComputerUse.Screenshot.StreamingEnabled) + + if config.ComputerUse.Enabled && config.ComputerUse.Screenshot.StreamingEnabled { + screenshotServer = startScreenshotServer(config, imageService, toolRegistry) + if screenshotServer != nil { + defer func() { + if err := screenshotServer.Stop(); err != nil { + logger.Error("Failed to stop screenshot server", "error", err) + } + }() + } + } + versionInfo := GetVersionInfo() application := app.NewChatApplication( models, @@ -369,6 +389,28 @@ func processStreamingOutput(events <-chan domain.ChatEvent) error { return nil } +// startScreenshotServer initializes and starts the screenshot streaming server +func startScreenshotServer(config *config.Config, imageService domain.ImageService, toolRegistry *tools.Registry) *screenshotsvc.ScreenshotServer { + logger.Info("Screenshot streaming conditions met, starting server") + sessionID := fmt.Sprintf("%d-%s", time.Now().Unix(), uuid.New().String()[:8]) + screenshotServer := screenshotsvc.NewScreenshotServer(config, imageService, sessionID) + + if err := screenshotServer.Start(); err != nil { + logger.Warn("Failed to start screenshot server", "error", 
err) + return nil + } + + fmt.Printf("• Screenshot API: http://localhost:%d\n", screenshotServer.Port()) + toolRegistry.SetScreenshotServer(screenshotServer) + logger.Info("Registered GetLatestScreenshot tool with tool registry") + + if os.Getenv("INFER_GATEWAY_MODE") == "remote" { + fmt.Printf("\x1b]5555;screenshot_port=%d\x07", screenshotServer.Port()) + } + + return screenshotServer +} + func init() { rootCmd.AddCommand(chatCmd) chatCmd.Flags().Bool("web", false, "Start web terminal interface") diff --git a/cmd/root.go b/cmd/root.go index 9bf77423..11322a0a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -91,6 +91,25 @@ func initConfig() { // nolint:funlen v.SetDefault("web.ssh.auto_install", defaults.Web.SSH.AutoInstall) v.SetDefault("web.ssh.install_version", defaults.Web.SSH.InstallVersion) v.SetDefault("web.servers", defaults.Web.Servers) + v.SetDefault("computer_use", defaults.ComputerUse) + v.SetDefault("computer_use.enabled", defaults.ComputerUse.Enabled) + v.SetDefault("computer_use.display", defaults.ComputerUse.Display) + v.SetDefault("computer_use.screenshot.enabled", defaults.ComputerUse.Screenshot.Enabled) + v.SetDefault("computer_use.screenshot.max_width", defaults.ComputerUse.Screenshot.MaxWidth) + v.SetDefault("computer_use.screenshot.max_height", defaults.ComputerUse.Screenshot.MaxHeight) + v.SetDefault("computer_use.screenshot.format", defaults.ComputerUse.Screenshot.Format) + v.SetDefault("computer_use.screenshot.quality", defaults.ComputerUse.Screenshot.Quality) + v.SetDefault("computer_use.screenshot.streaming_enabled", defaults.ComputerUse.Screenshot.StreamingEnabled) + v.SetDefault("computer_use.screenshot.capture_interval", defaults.ComputerUse.Screenshot.CaptureInterval) + v.SetDefault("computer_use.screenshot.buffer_size", defaults.ComputerUse.Screenshot.BufferSize) + v.SetDefault("computer_use.screenshot.temp_dir", defaults.ComputerUse.Screenshot.TempDir) + v.SetDefault("computer_use.mouse_move.enabled", 
defaults.ComputerUse.MouseMove.Enabled) + v.SetDefault("computer_use.mouse_click.enabled", defaults.ComputerUse.MouseClick.Enabled) + v.SetDefault("computer_use.keyboard_type.enabled", defaults.ComputerUse.KeyboardType.Enabled) + v.SetDefault("computer_use.keyboard_type.max_text_length", defaults.ComputerUse.KeyboardType.MaxTextLength) + v.SetDefault("computer_use.rate_limit.enabled", defaults.ComputerUse.RateLimit.Enabled) + v.SetDefault("computer_use.rate_limit.max_actions_per_minute", defaults.ComputerUse.RateLimit.MaxActionsPerMinute) + v.SetDefault("computer_use.rate_limit.window_seconds", defaults.ComputerUse.RateLimit.WindowSeconds) v.SetDefault("git", defaults.Git) v.SetDefault("storage", defaults.Storage) v.SetDefault("conversation", defaults.Conversation) diff --git a/config/config.go b/config/config.go index 512978b9..88575e87 100644 --- a/config/config.go +++ b/config/config.go @@ -42,6 +42,7 @@ type Config struct { Init InitConfig `yaml:"init" mapstructure:"init"` Compact CompactConfig `yaml:"compact" mapstructure:"compact"` Web WebConfig `yaml:"web" mapstructure:"web"` + ComputerUse ComputerUseConfig `yaml:"computer_use" mapstructure:"computer_use"` configDir string } @@ -247,6 +248,57 @@ type SandboxConfig struct { ProtectedPaths []string `yaml:"protected_paths" mapstructure:"protected_paths"` } +// ComputerUseConfig contains computer use tool settings +type ComputerUseConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + Display string `yaml:"display" mapstructure:"display"` + Screenshot ScreenshotToolConfig `yaml:"screenshot" mapstructure:"screenshot"` + MouseMove MouseMoveToolConfig `yaml:"mouse_move" mapstructure:"mouse_move"` + MouseClick MouseClickToolConfig `yaml:"mouse_click" mapstructure:"mouse_click"` + KeyboardType KeyboardTypeToolConfig `yaml:"keyboard_type" mapstructure:"keyboard_type"` + RateLimit RateLimitConfig `yaml:"rate_limit" mapstructure:"rate_limit"` +} + +// ScreenshotToolConfig contains screenshot-specific 
tool settings +type ScreenshotToolConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + MaxWidth int `yaml:"max_width" mapstructure:"max_width"` + MaxHeight int `yaml:"max_height" mapstructure:"max_height"` + Format string `yaml:"format" mapstructure:"format"` + Quality int `yaml:"quality" mapstructure:"quality"` + RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` + StreamingEnabled bool `yaml:"streaming_enabled" mapstructure:"streaming_enabled"` + CaptureInterval int `yaml:"capture_interval" mapstructure:"capture_interval"` // seconds + BufferSize int `yaml:"buffer_size" mapstructure:"buffer_size"` // number of screenshots + TempDir string `yaml:"temp_dir" mapstructure:"temp_dir"` // path for disk storage +} + +// MouseMoveToolConfig contains mouse move-specific tool settings +type MouseMoveToolConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` +} + +// MouseClickToolConfig contains mouse click-specific tool settings +type MouseClickToolConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` +} + +// KeyboardTypeToolConfig contains keyboard type-specific tool settings +type KeyboardTypeToolConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + MaxTextLength int `yaml:"max_text_length" mapstructure:"max_text_length"` + RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` +} + +// RateLimitConfig contains rate limiting settings +type RateLimitConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + MaxActionsPerMinute int `yaml:"max_actions_per_minute" mapstructure:"max_actions_per_minute"` + WindowSeconds int `yaml:"window_seconds" mapstructure:"window_seconds"` +} + // SafetyConfig 
contains safety approval settings type SafetyConfig struct { RequireApproval bool `yaml:"require_approval" mapstructure:"require_approval"` @@ -977,6 +1029,40 @@ Write the AGENTS.md file to the project root when you have gathered enough infor }, Servers: []SSHServerConfig{}, }, + ComputerUse: ComputerUseConfig{ + Enabled: false, // Security: disabled by default + Display: ":0", + Screenshot: ScreenshotToolConfig{ + Enabled: true, + MaxWidth: 1920, + MaxHeight: 1080, + Format: "jpeg", + Quality: 80, + RequireApproval: &[]bool{false}[0], + StreamingEnabled: false, + CaptureInterval: 3, + BufferSize: 30, + TempDir: "/tmp/infer-screenshots", + }, + MouseMove: MouseMoveToolConfig{ + Enabled: true, + RequireApproval: &[]bool{true}[0], + }, + MouseClick: MouseClickToolConfig{ + Enabled: true, + RequireApproval: &[]bool{true}[0], + }, + KeyboardType: KeyboardTypeToolConfig{ + Enabled: true, + MaxTextLength: 1000, + RequireApproval: &[]bool{true}[0], + }, + RateLimit: RateLimitConfig{ + Enabled: true, + MaxActionsPerMinute: 60, + WindowSeconds: 60, + }, + }, } } @@ -1045,6 +1131,22 @@ func (c *Config) IsApprovalRequired(toolName string) bool { // nolint:gocyclo,cy if c.A2A.Tools.SubmitTask.RequireApproval != nil { return *c.A2A.Tools.SubmitTask.RequireApproval } + case "Screenshot": + if c.ComputerUse.Screenshot.RequireApproval != nil { + return *c.ComputerUse.Screenshot.RequireApproval + } + case "MouseMove": + if c.ComputerUse.MouseMove.RequireApproval != nil { + return *c.ComputerUse.MouseMove.RequireApproval + } + case "MouseClick": + if c.ComputerUse.MouseClick.RequireApproval != nil { + return *c.ComputerUse.MouseClick.RequireApproval + } + case "KeyboardType": + if c.ComputerUse.KeyboardType.RequireApproval != nil { + return *c.ComputerUse.KeyboardType.RequireApproval + } } return globalApproval diff --git a/examples/computer-use/.env.example b/examples/computer-use/.env.example new file mode 100644 index 00000000..70b625a5 --- /dev/null +++ 
b/examples/computer-use/.env.example @@ -0,0 +1,4 @@ +ANTHROPIC_API_KEY= +OLLAMA_CLOUD_API_KEY= +OPENAI_API_KEY= +GOOGLE_API_KEY= diff --git a/examples/computer-use/.gitignore b/examples/computer-use/.gitignore new file mode 100644 index 00000000..d81e3b39 --- /dev/null +++ b/examples/computer-use/.gitignore @@ -0,0 +1 @@ +.ssh-keys diff --git a/examples/computer-use/.infer/.gitignore b/examples/computer-use/.infer/.gitignore new file mode 100644 index 00000000..a9d86493 --- /dev/null +++ b/examples/computer-use/.infer/.gitignore @@ -0,0 +1,2 @@ +logs +tmp diff --git a/examples/computer-use/.infer/config.yaml b/examples/computer-use/.infer/config.yaml new file mode 100644 index 00000000..d668c8b8 --- /dev/null +++ b/examples/computer-use/.infer/config.yaml @@ -0,0 +1,198 @@ +container_runtime: + type: docker +gateway: + url: http://inference-gateway:8080 + api_key: "" + timeout: 200 + run: false + vision_enabled: true +client: + timeout: 200 + retry: + enabled: true + max_attempts: 3 + initial_backoff_sec: 5 + max_backoff_sec: 60 + backoff_multiplier: 2 + retryable_status_codes: + - 400 + - 408 + - 429 + - 500 + - 502 + - 503 + - 504 +logging: + debug: false + dir: "" +tools: + enabled: true + sandbox: + directories: + - . 
+ - /tmp + - .infer/tmp + protected_paths: + - .infer/config.yaml + - .infer/*.db + - .git/ + - '*.env' + bash: + enabled: true + timeout: 120 + background_shells: + enabled: true + max_concurrent: 5 + max_output_buffer_mb: 10 + retention_minutes: 60 + read: + enabled: true + require_approval: false + write: + enabled: true + require_approval: true + edit: + enabled: true + require_approval: true + delete: + enabled: true + require_approval: true + grep: + enabled: true + backend: auto + require_approval: false + tree: + enabled: true + require_approval: false + web_fetch: + enabled: true + safety: + max_size: 10485760 + timeout: 30 + allow_redirect: true + cache: + enabled: true + ttl: 3600 + max_size: 52428800 + web_search: + enabled: true + default_engine: duckduckgo + max_results: 10 + engines: + - duckduckgo + - google + timeout: 10 + todo_write: + enabled: true + require_approval: false + safety: + require_approval: true +image: + max_size: 5242880 + timeout: 30 + clipboard_optimize: + enabled: true + max_width: 1920 + max_height: 1080 + quality: 75 + convert_jpeg: true +export: + output_dir: .infer/tmp + summary_model: "" +agent: + verbose_tools: false + max_turns: 50 + max_tokens: 8192 + max_concurrent_tools: 5 +storage: + enabled: true + type: jsonl + jsonl: + path: .infer/conversations +conversation: + title_generation: + enabled: true + batch_size: 10 + interval: 0 +chat: + theme: tokyo-night + keybindings: + enabled: false + status_bar: + enabled: true + indicators: + model: true + theme: true + max_output: false + a2a_agents: true + tools: true + background_shells: true + a2a_tasks: true + mcp: true + context_usage: true + session_tokens: true + cost: true + git_branch: true +a2a: + enabled: false +mcp: + enabled: false +pricing: + enabled: true + currency: USD + custom_prices: {} +compact: + enabled: true + auto_at: 80 + keep_first_messages: 2 +web: + enabled: true + port: 3000 + host: 0.0.0.0 + session_inactivity_mins: 5 + ssh: + enabled: true + 
use_ssh_config: false + auto_install: false + servers: + - name: "Ubuntu GUI Desktop" + id: "ubuntu-gui-desktop" + remote_host: "ubuntu-gui" + remote_user: "ubuntu" + remote_port: 22 + command_path: "/usr/local/bin/infer" + auto_install: false + description: "Ubuntu 24.04 Desktop with Computer Use capabilities" + tags: + - local + - computer-use + - ubuntu +computer_use: + enabled: true + display: "ubuntu-gui:1" + screenshot: + enabled: true + max_width: 1920 + max_height: 1080 + format: "jpeg" + quality: 80 + require_approval: false + # Screenshot streaming for web UI + streaming_enabled: true + capture_interval: 3 + buffer_size: 30 + temp_dir: "/tmp/infer-screenshots" + mouse_move: + enabled: true + require_approval: true + mouse_click: + enabled: true + require_approval: true + keyboard_type: + enabled: true + max_text_length: 1000 + require_approval: true + rate_limit: + enabled: true + max_actions_per_minute: 60 + window_seconds: 60 diff --git a/examples/computer-use/Dockerfile.ubuntu-gui b/examples/computer-use/Dockerfile.ubuntu-gui new file mode 100644 index 00000000..3c73c139 --- /dev/null +++ b/examples/computer-use/Dockerfile.ubuntu-gui @@ -0,0 +1,103 @@ +FROM ubuntu:24.04 + +ENV DEBIAN_FRONTEND=noninteractive + +ENV TZ=UTC +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +RUN apt-get update && apt-get install -y \ + xfce4 \ + xfce4-terminal \ + xfce4-goodies \ + xvfb \ + x11-xserver-utils \ + x11-apps \ + xdotool \ + xterm \ + firefox \ + mousepad \ + thunar \ + git \ + curl \ + wget \ + vim \ + nano \ + dbus-x11 \ + net-tools \ + procps \ + sudo \ + pm-utils \ + openssh-server \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN useradd -m -s /bin/bash ubuntu 2>/dev/null || true && \ + echo "ubuntu:ubuntu" | chpasswd && \ + usermod -aG sudo ubuntu && \ + echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers + +RUN mkdir -p /var/run/sshd && \ + mkdir -p /home/ubuntu/.ssh && \ + chmod 700 /home/ubuntu/.ssh && \ + 
chown ubuntu:ubuntu /home/ubuntu/.ssh && \ + sed -i 's/#PasswordAuthentication yes/PasswordAuthentication yes/' /etc/ssh/sshd_config && \ + sed -i 's/PasswordAuthentication no/PasswordAuthentication yes/' /etc/ssh/sshd_config + +COPY <<'EOF' /usr/local/bin/start-services.sh +#!/bin/bash +set -e + +# Export INFER_* and DISPLAY environment variables to /etc/environment for SSH sessions +echo "Exporting environment variables to /etc/environment..." +env | grep -E '^(INFER_|DISPLAY=)' | while IFS='=' read -r key value; do + sed -i "/^${key}=/d" /etc/environment 2>/dev/null || true + echo "${key}=${value}" >> /etc/environment + echo " Exported: ${key}=${value}" +done + +if [ -f /tmp/authorized_keys ]; then + cp /tmp/authorized_keys /home/ubuntu/.ssh/authorized_keys + chmod 600 /home/ubuntu/.ssh/authorized_keys + chown ubuntu:ubuntu /home/ubuntu/.ssh/authorized_keys + echo "✓ SSH keys configured" +fi + +echo "Starting SSH server..." +/usr/sbin/sshd + +SCREEN_RESOLUTION="${SCREEN_RESOLUTION:-1280x720}" +SCREEN_DEPTH="${SCREEN_DEPTH:-24}" + +echo "Starting Xvfb (headless X11 server)..." +Xvfb :1 -screen 0 ${SCREEN_RESOLUTION}x${SCREEN_DEPTH} -ac +extension GLX +render -noreset & +XVFB_PID=$! + +sleep 2 + +export DISPLAY=:1 + +echo "Starting XFCE desktop environment with D-Bus session..." +runuser -u ubuntu -- env DISPLAY=:1 dbus-launch --exit-with-session startxfce4 & +XFCE_PID=$! + +echo "✓ Headless X11 server running on display :1 (${SCREEN_RESOLUTION}x${SCREEN_DEPTH})" +echo "✓ XFCE desktop started" +echo "✓ Ready for computer use tools" + +while kill -0 $XVFB_PID 2>/dev/null && kill -0 $XFCE_PID 2>/dev/null; do + sleep 5 +done + +echo "X11 server or desktop manager stopped, exiting..." 
+exit 1 +EOF + +RUN chmod +x /usr/local/bin/start-services.sh + +WORKDIR /home/ubuntu + +ENV DISPLAY=:1 + +EXPOSE 22 + +CMD ["/usr/local/bin/start-services.sh"] diff --git a/examples/computer-use/README.md b/examples/computer-use/README.md new file mode 100644 index 00000000..11218f54 --- /dev/null +++ b/examples/computer-use/README.md @@ -0,0 +1,436 @@ +# Computer Use Example + +This example demonstrates the Inference Gateway CLI's computer use capabilities on a remote Ubuntu +desktop environment. The agent can control the GUI by taking screenshots, moving the mouse cursor, +clicking, and typing - all viewable through live screenshot streaming in your browser. + +## Overview + +This setup includes: + +- **Ubuntu Desktop Server** with XFCE desktop environment +- **Xvfb (Headless X11 Server)** - lightweight headless display server +- **Screenshot Streaming** - live desktop view updated every 3 seconds +- **Inference Gateway** with computer use tools enabled +- **Web Interface** for interacting with the agent and viewing screenshots + +The agent can control the remote Ubuntu GUI through specialized tools: + +- `MouseMove`: Move the cursor to specific coordinates +- `MouseClick`: Click mouse buttons (left/right/middle, single/double/triple) +- `KeyboardType`: Type text or send key combinations (e.g., "ctrl+c") + +**Note:** The Screenshot tool is automatically disabled when streaming is enabled - the LLM can see +live screenshots in the web UI overlay updated every 3 seconds. 
+ +## Architecture + +```text +┌─────────────────┐ ┌──────────────────────┐ ┌──────────────────────┐ +│ Your Browser │◄─────────┤ Inference Gateway │◄─────────┤ Ubuntu GUI Server │ +│ (Web UI) │ WebSocket│ (Computer Use CLI) │ SSH+HTTP │ - XFCE Desktop │ +│ + Screenshot │ │ + Screenshot Proxy │ Tunnel │ - Xvfb Display :1 │ +│ Overlay │ └──────────────────────┘ │ - Screenshot API │ +└─────────────────┘ └──────────────────────┘ +``` + +**Key Features:** + +- **Live Screenshot Streaming**: View the desktop in real-time via browser overlay +- **Resource Efficient**: Xvfb uses ~10MB RAM vs VNC's ~100MB +- **HTTP Tunneling**: Screenshot streaming over SSH port forwarding +- **No VNC Client Needed**: Everything accessible via web browser + +## Prerequisites + +- Docker and Docker Compose +- At least 1GB RAM available +- Port 3000 (CLI web UI) available + +## Quick Start + +### 1. Start the Environment + +```bash +docker-compose up -d --build +``` + +This will: + +- Build the Ubuntu GUI server with XFCE desktop and Xvfb +- Start the Inference Gateway with computer use enabled +- Start screenshot streaming server on remote +- Set up SSH tunnel for screenshot API +- Expose the web interface on <http://localhost:3000> + +### 2. Access the Web Interface + +Open your browser to: + +```text +http://localhost:3000 +``` + +Click the "Screenshots" button in the top-right to toggle the live screenshot overlay. + +### 3. View Live Screenshots + +Once connected to the Ubuntu GUI Desktop server: + +1. Click the **"Screenshots"** button in the top-right corner +2. The overlay will appear showing the remote desktop +3. Screenshots update every 3 seconds automatically +4. You can see exactly what the agent is doing in real-time + +### 4. 
Try Computer Use Commands + +In the web interface, try asking the agent to: + +**Open an application:** + +```text +Please click on the Applications menu at the top left (around coordinates 10, 10), +then navigate to and click on Firefox +``` + +**Type text:** + +```text +Please type "Hello from the agent!" in the currently focused application +``` + +**Complex workflow:** + +```text +Please open the Terminal Emulator application, +type "echo 'Agent was here' > /tmp/test.txt", and press Enter +``` + +**Note:** You don't need to ask the agent to take screenshots - they're automatically captured every 3 seconds and visible in the overlay! + +## Configuration + +The computer use tools are configured via environment variables in `docker-compose.yml`: + +```yaml +environment: + # X11 Display Configuration + DISPLAY: :1 + SCREEN_RESOLUTION: 1280x720 + SCREEN_DEPTH: 24 + + # Computer Use Tools + INFER_COMPUTER_USE_ENABLED: "true" + INFER_COMPUTER_USE_DISPLAY: ":1" + + # Screenshot Streaming (automatically disables Screenshot tool) + INFER_COMPUTER_USE_SCREENSHOT_ENABLED: "true" + INFER_COMPUTER_USE_SCREENSHOT_STREAMING_ENABLED: "true" + INFER_COMPUTER_USE_SCREENSHOT_CAPTURE_INTERVAL: "3" # seconds + INFER_COMPUTER_USE_SCREENSHOT_BUFFER_SIZE: "30" # keep last 30 screenshots + + # Mouse and Keyboard + INFER_COMPUTER_USE_MOUSE_MOVE_ENABLED: "true" + INFER_COMPUTER_USE_MOUSE_CLICK_ENABLED: "true" + INFER_COMPUTER_USE_KEYBOARD_TYPE_ENABLED: "true" +``` + +You can also configure via `.infer/config.yaml`: + +```yaml +computer_use: + enabled: true + display: ":1" + + screenshot: + enabled: true + streaming_enabled: true # Disables Screenshot tool, enables streaming + capture_interval: 3 # Capture every 3 seconds + buffer_size: 30 # Keep last 30 screenshots + + mouse_move: + enabled: true + require_approval: true + + mouse_click: + enabled: true + require_approval: true + + keyboard_type: + enabled: true + require_approval: true + + rate_limit: + enabled: true + 
max_actions_per_minute: 60 +``` + +## Security Features + +### Approval System + +By default, all mouse and keyboard actions require user approval: + +- Screenshots are read-only and don't require approval +- Mouse movements/clicks require approval in the web UI +- Keyboard typing requires approval +- Use `--auto-accept` mode to bypass approvals (use with caution!) + +### Rate Limiting + +Prevents excessive actions: + +- Default: 60 actions per minute +- Sliding window algorithm +- Configurable per deployment + +### Auto-Accept Mode + +For fully autonomous operation: + +```bash +docker-compose exec cli infer chat --web --agent-mode auto-accept +``` + +**⚠️ Warning:** Auto-accept mode allows the agent to control the GUI without approval. Only use in isolated/sandboxed environments. + +## Available Tools + +### Screenshot + +```yaml +Name: Screenshot +Parameters: + - region (optional): {x, y, width, height} + - display (optional): default ":1" +``` + +### MouseMove + +```yaml +Name: MouseMove +Parameters: + - x: integer (required) + - y: integer (required) + - display (optional): default ":1" +``` + +### MouseClick + +```yaml +Name: MouseClick +Parameters: + - button: "left" | "right" | "middle" (required) + - clicks: 1 | 2 | 3 (optional, default: 1) + - x: integer (optional, move before clicking) + - y: integer (optional, move before clicking) + - display (optional): default ":1" +``` + +### KeyboardType + +```yaml +Name: KeyboardType +Parameters: + - text: string (optional, mutually exclusive with key_combo) + - key_combo: string (optional, e.g., "ctrl+c", "alt+tab") + - display (optional): default ":1" +``` + +## Example Workflows + +### 1. Open and Use Firefox + +```text +Agent: Please help me open Firefox and navigate to example.com + +Steps: +1. Take a screenshot to see the current state +2. Click on Applications menu (top-left, ~10, 10) +3. Click on Web Browser +4. Wait for Firefox to open (take another screenshot) +5. Click on the address bar +6. 
Type "example.com" +7. Press Enter (send key combo "Return") +``` + +### 2. Create and Edit a Text File + +```text +Agent: Create a text file with some content + +Steps: +1. Take a screenshot +2. Click on Applications menu +3. Navigate to Accessories → Text Editor +4. Type "This is a test file created by the agent" +5. Press Ctrl+S to save +6. Type "/tmp/agent-test.txt" as filename +7. Press Enter to confirm +``` + +### 3. Run Terminal Commands + +```text +Agent: Please run "ls -la" in the terminal + +Steps: +1. Take a screenshot +2. Click on Applications menu +3. Click on Terminal Emulator +4. Type "ls -la" +5. Press Enter +6. Take a screenshot of the output +``` + +## Troubleshooting + +### GUI Server Not Responding + +```bash +# Check if the GUI server is running +docker-compose ps ubuntu-gui + +# View logs +docker-compose logs ubuntu-gui +``` + +### X11 Connection Issues + +```bash +# Verify DISPLAY is set correctly (single quotes so the container's shell expands it, not the host's) +docker-compose exec ubuntu-gui sh -c 'echo $DISPLAY' + +# Should output: :1 +``` + +### Computer Use Tools Not Available + +```bash +# Check configuration +docker-compose exec cli cat /workspace/.infer/config.yaml | grep -A 20 computer_use + +# Ensure enabled: true +``` + +### Agent Can't Control GUI + +```bash +# Verify X11 is running +docker-compose exec ubuntu-gui ps aux | grep X + +# Check the XFCE desktop session +docker-compose exec ubuntu-gui ps aux | grep xfce +``` + +## Cleanup + +```bash +# Stop all services +docker-compose down + +# Remove volumes (resets desktop state) +docker-compose down -v +``` + +## Advanced Usage + +### Custom Desktop Environment + +Edit `Dockerfile.ubuntu-gui` to use a different desktop environment: + +```dockerfile +# Replace XFCE with MATE, LXDE, etc. 
+RUN apt-get install -y ubuntu-mate-desktop +``` + +### Custom Screen Resolution + +Edit `docker-compose.yml`: + +```yaml +environment: + SCREEN_RESOLUTION: 1920x1080 +``` + +### Persistent Desktop State + +Mount a volume for the user home directory: + +```yaml +volumes: + - ubuntu-home:/home/ubuntu +``` + +## Implementation Details + +### Display Server Detection + +The CLI automatically detects X11 vs Wayland: + +- Ubuntu GUI server uses X11 (DISPLAY=:1) +- Pure Go implementation using `xgb`/`xgbutil` libraries +- No CGO required for X11 support + +### Screenshot Streaming + +Screenshots are: + +1. Captured via X11 protocol +2. Optimized (resized/compressed per config) +3. Base64 encoded +4. Sent via WebSocket to browser +5. Displayed in the web interface + +### Mouse/Keyboard Control + +Uses X11 protocol directly: + +- `xproto.WarpPointer` for mouse movement +- `xproto.ButtonPress`/`ButtonRelease` for clicks (requires xdotool) +- `KeyPress`/`KeyRelease` for typing (requires keysym mapping) + +**Note:** Full keyboard and mouse click support in pure Go is limited. The example uses command-line tools (`xdotool`, `xte`) as a fallback. + +## Security Considerations + +### Isolation + +- Run in isolated Docker network +- Don't expose the SSH port publicly +- Replace the demo `ubuntu:ubuntu` password with key-based SSH authentication in production +- Keep screenshot streaming on the SSH tunnel instead of exposing the screenshot API directly + +### Action Approval + +- Always require approval for mouse/keyboard in production +- Use auto-accept mode only in sandboxed environments +- Monitor agent actions via web interface +- Enable rate limiting to prevent abuse + +### Audit Logging + +All computer use actions are logged: + +- Timestamps +- Tool name (Screenshot, MouseMove, etc.) 
+- Arguments (coordinates, text, key combos) +- Success/failure status +- User approval decisions + +Check logs: + +```bash +docker-compose logs gateway | grep "computer_use" +``` + +## Further Reading + +- [Anthropic Computer Use Guide](https://docs.anthropic.com/claude/docs/computer-use) +- [X11 Protocol Documentation](https://www.x.org/releases/current/doc/) +- [VNC Protocol](https://en.wikipedia.org/wiki/Virtual_Network_Computing) +- [XFCE Desktop Environment](https://www.xfce.org/) + +## Contributing + +Found an issue or want to improve this example? Please open an issue or PR in the main repository. diff --git a/examples/computer-use/docker-compose.yml b/examples/computer-use/docker-compose.yml new file mode 100644 index 00000000..6e1eb24e --- /dev/null +++ b/examples/computer-use/docker-compose.yml @@ -0,0 +1,122 @@ +--- +services: + ssh-keygen: + image: alpine:3.23.0 + command: + - sh + - -c + - | + apk add --no-cache openssh-keygen + if [ ! -f /keys/id_rsa ]; then + echo "Generating SSH keys for demo..." + ssh-keygen -t rsa -b 2048 -f /keys/id_rsa -N "" + chmod 600 /keys/id_rsa + chmod 644 /keys/id_rsa.pub + echo "✓ SSH keys generated in .ssh-keys/" + else + echo "✓ Using existing SSH keys from .ssh-keys/" + fi + volumes: + - ./.ssh-keys:/keys + + ubuntu-gui: + build: + context: . 
+ dockerfile: Dockerfile.ubuntu-gui + hostname: ubuntu-gui + environment: + DISPLAY: :1 + SCREEN_RESOLUTION: 1280x720 + SCREEN_DEPTH: 24 + INFER_COMPUTER_USE_ENABLED: "true" + INFER_COMPUTER_USE_DISPLAY: ":1" + INFER_COMPUTER_USE_SCREENSHOT_ENABLED: "true" + INFER_COMPUTER_USE_SCREENSHOT_STREAMING_ENABLED: "true" + INFER_COMPUTER_USE_SCREENSHOT_CAPTURE_INTERVAL: "3" + INFER_COMPUTER_USE_SCREENSHOT_BUFFER_SIZE: "30" + INFER_COMPUTER_USE_MOUSE_MOVE_ENABLED: "true" + INFER_COMPUTER_USE_MOUSE_CLICK_ENABLED: "true" + INFER_COMPUTER_USE_KEYBOARD_TYPE_ENABLED: "true" + volumes: + - ubuntu-home:/home/ubuntu + - ../../dist/infer-linux-arm64:/usr/local/bin/infer:ro + - ./.ssh-keys/id_rsa.pub:/tmp/authorized_keys:ro + networks: + - computer-use-network + depends_on: + ssh-keygen: + condition: service_completed_successfully + healthcheck: + test: + - CMD + - pgrep + - Xvfb + interval: 10s + timeout: 5s + retries: 5 + restart: unless-stopped + + cli: + image: ghcr.io/inference-gateway/cli:local + hostname: cli + working_dir: /workspace + environment: + INFER_GATEWAY_URL: http://inference-gateway:8080 + INFER_GATEWAY_MODE: docker + INFER_COMPUTER_USE_ENABLED: true + INFER_COMPUTER_USE_DISPLAY: ubuntu-gui:1 + INFER_WEB_ENABLED: true + INFER_WEB_HOST: 0.0.0.0 + INFER_WEB_PORT: 3000 + INFER_STORAGE_TYPE: jsonl + INFER_DEFAULT_MODEL: anthropic/claude-sonnet-4.5 + ports: + - "3000:3000" + volumes: + - ./:/workspace + - gateway-data:/root/.infer + - ./.ssh-keys:/home/infer/.ssh:ro + networks: + - computer-use-network + depends_on: + ssh-keygen: + condition: service_completed_successfully + ubuntu-gui: + condition: service_healthy + inference-gateway: + condition: service_started + entrypoint: + - /bin/sh + - -c + command: + - | + echo 'Waiting for Ubuntu GUI server to be ready...' + sleep 5 + echo 'Starting Inference Gateway CLI in web mode...' 
+ infer chat --web + restart: unless-stopped + + inference-gateway: + image: ghcr.io/inference-gateway/inference-gateway:latest + hostname: inference-gateway + environment: + PORT: 8080 + LOG_LEVEL: info + ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-} + OPENAI_API_KEY: ${OPENAI_API_KEY:-} + GOOGLE_API_KEY: ${GOOGLE_API_KEY:-} + ports: + - "8080:8080" + networks: + - computer-use-network + restart: unless-stopped + +volumes: + ubuntu-home: + driver: local + gateway-data: + driver: local + +networks: + computer-use-network: + driver: bridge diff --git a/go.mod b/go.mod index 07954ad9..7c3667d5 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,8 @@ go 1.25.4 tool github.com/maxbrunsfeld/counterfeiter/v6 require ( + github.com/BurntSushi/xgb v0.0.0-20210121224620-deaf085860bc + github.com/BurntSushi/xgbutil v0.0.0-20190907113008-ad855c713046 github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/glamour v0.10.0 @@ -33,6 +35,8 @@ require ( ) require ( + github.com/BurntSushi/freetype-go v0.0.0-20160129220410-b763ddbfe298 // indirect + github.com/BurntSushi/graphics-go v0.0.0-20160129215708-b43f31a4a966 // indirect github.com/alecthomas/chroma/v2 v2.14.0 // indirect github.com/apapsch/go-jsonmerge/v2 v2.0.0 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect diff --git a/go.sum b/go.sum index 3592cb76..51f9ff39 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,11 @@ +github.com/BurntSushi/freetype-go v0.0.0-20160129220410-b763ddbfe298 h1:1qlsVAQJXZHsaM8b6OLVo6muQUQd4CwkH/D3fnnbHXA= +github.com/BurntSushi/freetype-go v0.0.0-20160129220410-b763ddbfe298/go.mod h1:D+QujdIlUNfa0igpNMk6UIvlb6C252URs4yupRUV4lQ= +github.com/BurntSushi/graphics-go v0.0.0-20160129215708-b43f31a4a966 h1:lTG4HQym5oPKjL7nGs+csTgiDna685ZXjxijkne828g= +github.com/BurntSushi/graphics-go v0.0.0-20160129215708-b43f31a4a966/go.mod h1:Mid70uvE93zn9wgF92A/r5ixgnvX8Lh68fxp9KQBaI0= +github.com/BurntSushi/xgb v0.0.0-20210121224620-deaf085860bc 
h1:7D+Bh06CRPCJO3gr2F7h1sriovOZ8BMhca2Rg85c2nk= +github.com/BurntSushi/xgb v0.0.0-20210121224620-deaf085860bc/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/BurntSushi/xgbutil v0.0.0-20190907113008-ad855c713046 h1:O/r2Sj+8QcMF7V5IcmiE2sMFV2q3J47BEirxbXJAdzA= +github.com/BurntSushi/xgbutil v0.0.0-20190907113008-ad855c713046/go.mod h1:uw9h2sd4WWHOPdJ13MQpwK5qYWKYDumDqxWWIknEQ+k= github.com/RaveNoX/go-jsoncommentstrip v1.0.0/go.mod h1:78ihd09MekBnJnxpICcwzCMzGrKSKYe4AqU6PDYYpjk= github.com/alecthomas/assert/v2 v2.7.0 h1:QtqSACNS3tF7oasA8CU6A6sXZSBDqnm7RfpLl9bZqbE= github.com/alecthomas/assert/v2 v2.7.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= diff --git a/internal/domain/events.go b/internal/domain/events.go index 417cb391..3ac7346d 100644 --- a/internal/domain/events.go +++ b/internal/domain/events.go @@ -391,3 +391,16 @@ type MessageEditSubmitEvent struct { func (e MessageEditSubmitEvent) GetRequestID() string { return e.RequestID } func (e MessageEditSubmitEvent) GetTimestamp() time.Time { return e.Timestamp } + +// ComputerUseScreenshotEvent is emitted when a screenshot is captured +type ComputerUseScreenshotEvent struct { + RequestID string + Timestamp time.Time + Width int + Height int + Region *ScreenRegion + ImageData string +} + +func (e ComputerUseScreenshotEvent) GetRequestID() string { return e.RequestID } +func (e ComputerUseScreenshotEvent) GetTimestamp() time.Time { return e.Timestamp } diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 16ea3496..353243b4 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -28,6 +28,70 @@ type ImageAttachment struct { SourcePath string `json:"-"` } +// Computer use result types + +// ScreenRegion represents a rectangular region of the screen +type ScreenRegion struct { + X int `json:"x"` + Y int `json:"y"` + Width int `json:"width"` + Height int `json:"height"` +} + +// Screenshot represents a captured screenshot with metadata 
+type Screenshot struct { + ID string `json:"id"` + Timestamp time.Time `json:"timestamp"` + Data string `json:"data"` // base64 encoded image + Width int `json:"width"` + Height int `json:"height"` + Format string `json:"format"` // "png" or "jpeg" + Method string `json:"method"` // "x11" or "wayland" +} + +// ScreenshotProvider defines the interface for getting screenshots from a buffer +type ScreenshotProvider interface { + GetLatestScreenshot() (*Screenshot, error) +} + +// ScreenshotToolResult represents the result of a screenshot capture +type ScreenshotToolResult struct { + Display string `json:"display"` + Region *ScreenRegion `json:"region,omitempty"` + Width int `json:"width"` + Height int `json:"height"` + Format string `json:"format"` + Method string `json:"method"` +} + +// MouseMoveToolResult represents the result of a mouse move operation +type MouseMoveToolResult struct { + FromX int `json:"from_x"` + FromY int `json:"from_y"` + ToX int `json:"to_x"` + ToY int `json:"to_y"` + Display string `json:"display"` + Method string `json:"method"` +} + +// MouseClickToolResult represents the result of a mouse click operation +type MouseClickToolResult struct { + Button string `json:"button"` + Clicks int `json:"clicks"` + X int `json:"x"` + Y int `json:"y"` + Display string `json:"display"` + Method string `json:"method"` +} + +// KeyboardTypeToolResult represents the result of a keyboard input operation +type KeyboardTypeToolResult struct { + Text string `json:"text,omitempty"` + KeyCombo string `json:"key_combo,omitempty"` + Display string `json:"display"` + Method string `json:"method"` +} + // ConversationEntry represents a message in the conversation with metadata type ConversationEntry struct { // Core message fields diff --git a/internal/services/circular_screenshot_buffer.go b/internal/services/circular_screenshot_buffer.go new file mode 100644 index 00000000..f2bd5632 --- /dev/null +++ b/internal/services/circular_screenshot_buffer.go @@ -0,0 +1,194 
@@ +package services + +import ( + "encoding/base64" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + uuid "github.com/google/uuid" + domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" +) + +// CircularScreenshotBuffer implements a thread-safe ring buffer for screenshots +// with optional disk persistence +type CircularScreenshotBuffer struct { + screenshots []*domain.Screenshot + maxSize int + currentIndex int + count int + mu sync.RWMutex + tempDir string + sessionID string +} + +// NewCircularScreenshotBuffer creates a new circular buffer for screenshots +func NewCircularScreenshotBuffer(maxSize int, tempDir string, sessionID string) (*CircularScreenshotBuffer, error) { + // Create temp directory for this session + sessionTempDir := filepath.Join(tempDir, fmt.Sprintf("session-%s", sessionID)) + if err := os.MkdirAll(sessionTempDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create temp directory: %w", err) + } + + return &CircularScreenshotBuffer{ + screenshots: make([]*domain.Screenshot, maxSize), + maxSize: maxSize, + currentIndex: 0, + count: 0, + tempDir: sessionTempDir, + sessionID: sessionID, + }, nil +} + +// Add adds a new screenshot to the buffer +// If the buffer is full, it evicts the oldest screenshot +func (b *CircularScreenshotBuffer) Add(screenshot *domain.Screenshot) error { + b.mu.Lock() + defer b.mu.Unlock() + + // Generate ID if not set + if screenshot.ID == "" { + screenshot.ID = uuid.New().String() + } + + // Set timestamp if not set + if screenshot.Timestamp.IsZero() { + screenshot.Timestamp = time.Now() + } + + // Evict old screenshot if buffer is full + if b.count >= b.maxSize { + oldScreenshot := b.screenshots[b.currentIndex] + if oldScreenshot != nil { + b.deleteFromDisk(oldScreenshot.ID) + } + } + + b.screenshots[b.currentIndex] = screenshot + + if err := b.writeToDisk(screenshot); err != nil { + logger.Warn("Failed to write screenshot to disk", 
"error", err, "screenshot_id", screenshot.ID) + } + + b.currentIndex = (b.currentIndex + 1) % b.maxSize + if b.count < b.maxSize { + b.count++ + } + + return nil +} + +// GetLatest returns the most recent screenshot +func (b *CircularScreenshotBuffer) GetLatest() (*domain.Screenshot, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + if b.count == 0 { + return nil, fmt.Errorf("buffer is empty") + } + + latestIndex := (b.currentIndex - 1 + b.maxSize) % b.maxSize + return b.screenshots[latestIndex], nil +} + +// GetByID returns a screenshot by its ID +func (b *CircularScreenshotBuffer) GetByID(id string) (*domain.Screenshot, error) { + b.mu.RLock() + defer b.mu.RUnlock() + + for i := 0; i < b.count; i++ { + if b.screenshots[i] != nil && b.screenshots[i].ID == id { + return b.screenshots[i], nil + } + } + + return nil, fmt.Errorf("screenshot not found: %s", id) +} + +// GetRecent returns the N most recent screenshots +func (b *CircularScreenshotBuffer) GetRecent(limit int) []*domain.Screenshot { + b.mu.RLock() + defer b.mu.RUnlock() + + if limit <= 0 || limit > b.count { + limit = b.count + } + + result := make([]*domain.Screenshot, 0, limit) + + for i := 0; i < limit; i++ { + index := (b.currentIndex - 1 - i + b.maxSize) % b.maxSize + if b.screenshots[index] != nil { + result = append(result, b.screenshots[index]) + } + } + + return result +} + +// Count returns the current number of screenshots in the buffer +func (b *CircularScreenshotBuffer) Count() int { + b.mu.RLock() + defer b.mu.RUnlock() + return b.count +} + +// Clear removes all screenshots from the buffer and deletes disk files +func (b *CircularScreenshotBuffer) Clear() error { + b.mu.Lock() + defer b.mu.Unlock() + + for i := 0; i < b.count; i++ { + if b.screenshots[i] != nil { + b.deleteFromDisk(b.screenshots[i].ID) + } + } + + b.screenshots = make([]*domain.Screenshot, b.maxSize) + b.currentIndex = 0 + b.count = 0 + + return nil +} + +// Cleanup removes the temp directory and all screenshots +func (b 
*CircularScreenshotBuffer) Cleanup() error { + b.mu.Lock() + defer b.mu.Unlock() + + if err := os.RemoveAll(b.tempDir); err != nil { + return fmt.Errorf("failed to cleanup temp directory: %w", err) + } + + return nil +} + +// writeToDisk writes a screenshot to disk as a PNG file +func (b *CircularScreenshotBuffer) writeToDisk(screenshot *domain.Screenshot) error { + if screenshot.Data == "" { + return nil + } + + imageData, err := base64.StdEncoding.DecodeString(screenshot.Data) + if err != nil { + return fmt.Errorf("failed to decode base64 data: %w", err) + } + + filename := filepath.Join(b.tempDir, fmt.Sprintf("screenshot-%s.png", screenshot.ID)) + if err := os.WriteFile(filename, imageData, 0644); err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + + return nil +} + +// deleteFromDisk removes a screenshot file from disk +func (b *CircularScreenshotBuffer) deleteFromDisk(id string) { + filename := filepath.Join(b.tempDir, fmt.Sprintf("screenshot-%s.png", id)) + if err := os.Remove(filename); err != nil && !os.IsNotExist(err) { + logger.Warn("Failed to delete screenshot file", "error", err, "filename", filename) + } +} diff --git a/internal/services/screenshot_server.go b/internal/services/screenshot_server.go new file mode 100644 index 00000000..cb484809 --- /dev/null +++ b/internal/services/screenshot_server.go @@ -0,0 +1,307 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "strconv" + "sync" + "time" + + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" + tools "github.com/inference-gateway/cli/internal/services/tools" +) + +// ScreenshotServer provides an HTTP API for screenshot streaming +type ScreenshotServer struct { + cfg *config.Config + port int + server *http.Server + buffer *CircularScreenshotBuffer + captureCtx context.Context + captureStop context.CancelFunc + mu 
sync.RWMutex + sessionID string + imageSvc domain.ImageService + running bool +} + +// NewScreenshotServer creates a new screenshot server +func NewScreenshotServer(cfg *config.Config, imageService domain.ImageService, sessionID string) *ScreenshotServer { + return &ScreenshotServer{ + cfg: cfg, + sessionID: sessionID, + imageSvc: imageService, + running: false, + } +} + +// Start starts the HTTP server and background capture loop +func (s *ScreenshotServer) Start() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.running { + return fmt.Errorf("screenshot server already running") + } + + logger.Info("Starting screenshot server", "session_id", s.sessionID) + + // Create circular buffer + bufferSize := s.cfg.ComputerUse.Screenshot.BufferSize + if bufferSize <= 0 { + bufferSize = 30 // default + } + + tempDir := s.cfg.ComputerUse.Screenshot.TempDir + if tempDir == "" { + tempDir = "/tmp/infer-screenshots" + } + + logger.Info("Creating screenshot buffer", "buffer_size", bufferSize, "temp_dir", tempDir) + + buffer, err := NewCircularScreenshotBuffer(bufferSize, tempDir, s.sessionID) + if err != nil { + return fmt.Errorf("failed to create screenshot buffer: %w", err) + } + s.buffer = buffer + + // Listen on random port + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + return fmt.Errorf("failed to listen: %w", err) + } + + s.port = listener.Addr().(*net.TCPAddr).Port + logger.Info("Screenshot server listening", "port", s.port) + + // Create HTTP server + mux := http.NewServeMux() + mux.HandleFunc("/api/screenshots/latest", s.handleGetLatest) + mux.HandleFunc("/api/screenshots", s.handleGetRecent) + mux.HandleFunc("/api/screenshots/status", s.handleGetStatus) + + s.server = &http.Server{ + Handler: mux, + } + + // Start HTTP server in goroutine + go func() { + logger.Info("Screenshot HTTP server started", "port", s.port) + if err := s.server.Serve(listener); err != nil && err != http.ErrServerClosed { + logger.Error("Screenshot server error", 
"error", err) + } + }() + + // Start capture loop + s.captureCtx, s.captureStop = context.WithCancel(context.Background()) + go s.startCaptureLoop() + + s.running = true + logger.Info("Screenshot server fully initialized", "port", s.port, "capture_interval", s.cfg.ComputerUse.Screenshot.CaptureInterval) + + return nil +} + +// Stop stops the HTTP server and capture loop +func (s *ScreenshotServer) Stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.running { + return nil + } + + // Stop capture loop + if s.captureStop != nil { + s.captureStop() + } + + // Shutdown HTTP server + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if s.server != nil { + if err := s.server.Shutdown(ctx); err != nil { + return fmt.Errorf("failed to shutdown server: %w", err) + } + } + + // Cleanup buffer + if s.buffer != nil { + if err := s.buffer.Cleanup(); err != nil { + logger.Warn("Failed to cleanup buffer", "error", err) + } + } + + s.running = false + + return nil +} + +// Port returns the port the server is listening on +func (s *ScreenshotServer) Port() int { + s.mu.RLock() + defer s.mu.RUnlock() + return s.port +} + +// startCaptureLoop runs the background screenshot capture loop +func (s *ScreenshotServer) startCaptureLoop() { + interval := s.cfg.ComputerUse.Screenshot.CaptureInterval + if interval <= 0 { + interval = 3 // default: 3 seconds + } + + logger.Info("Screenshot capture loop started", "interval_seconds", interval) + + ticker := time.NewTicker(time.Duration(interval) * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.captureCtx.Done(): + logger.Info("Screenshot capture loop stopped") + return + case <-ticker.C: + logger.Info("Attempting screenshot capture") + if err := s.captureScreenshot(); err != nil { + logger.Warn("Screenshot capture failed", "error", err) + } else { + logger.Info("Screenshot captured successfully") + } + } + } +} + +// captureScreenshot captures a screenshot and adds it to the 
buffer +func (s *ScreenshotServer) captureScreenshot() error { + // Use the screenshot tool to capture + tool := tools.NewScreenshotTool(s.cfg, s.imageSvc, nil) // No rate limiter for auto-capture + + // Execute with default args (full screen) + result, err := tool.Execute(s.captureCtx, map[string]any{}) + if err != nil { + return err + } + + if !result.Success { + return fmt.Errorf("screenshot capture failed: %s", result.Error) + } + + // Extract screenshot data + toolResult, ok := result.Data.(domain.ScreenshotToolResult) + if !ok { + return fmt.Errorf("unexpected result type") + } + + // Get image attachment + if len(result.Images) == 0 { + return fmt.Errorf("no image in result") + } + + imageAttachment := result.Images[0] + + // Create Screenshot object + screenshot := &domain.Screenshot{ + Timestamp: time.Now(), + Data: imageAttachment.Data, + Width: toolResult.Width, + Height: toolResult.Height, + Format: toolResult.Format, + Method: toolResult.Method, + } + + // Add to buffer + return s.buffer.Add(screenshot) +} + +// handleGetLatest handles GET /api/screenshots/latest +func (s *ScreenshotServer) handleGetLatest(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + screenshot, err := s.buffer.GetLatest() + if err != nil { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(screenshot); err != nil { + logger.Warn("Failed to encode screenshot response", "error", err) + } +} + +// handleGetRecent handles GET /api/screenshots?limit=N +func (s *ScreenshotServer) handleGetRecent(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Parse limit parameter + limit := 30 // default + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + if 
parsedLimit, err := strconv.Atoi(limitStr); err == nil { + if parsedLimit > 0 && parsedLimit <= 100 { + limit = parsedLimit + } + } + } + + screenshots := s.buffer.GetRecent(limit) + + response := map[string]interface{}{ + "screenshots": screenshots, + "count": len(screenshots), + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + logger.Warn("Failed to encode screenshots response", "error", err) + } +} + +// handleGetStatus handles GET /api/screenshots/status +func (s *ScreenshotServer) handleGetStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + + status := map[string]interface{}{ + "running": s.running, + "count": s.buffer.Count(), + "interval_sec": s.cfg.ComputerUse.Screenshot.CaptureInterval, + "port": s.port, + "session_id": s.sessionID, + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(status); err != nil { + logger.Warn("Failed to encode status response", "error", err) + } +} + +// GetLatestScreenshot retrieves the latest screenshot from the buffer +// Implements the ScreenshotProvider interface for use by GetLatestScreenshotTool +func (s *ScreenshotServer) GetLatestScreenshot() (*domain.Screenshot, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if s.buffer == nil { + return nil, fmt.Errorf("screenshot buffer not initialized") + } + + return s.buffer.GetLatest() +} diff --git a/internal/services/tools/computer_use_common.go b/internal/services/tools/computer_use_common.go new file mode 100644 index 00000000..aca53ce1 --- /dev/null +++ b/internal/services/tools/computer_use_common.go @@ -0,0 +1,123 @@ +package tools + +import ( + "fmt" + "os" + "sync" + "time" + + config "github.com/inference-gateway/cli/config" +) + +// DisplayServer represents the type of display server +type 
DisplayServer int + +const ( + DisplayServerX11 DisplayServer = iota + DisplayServerWayland + DisplayServerUnknown +) + +// DetectDisplayServer detects which display server is running +func DetectDisplayServer() DisplayServer { + // Check for Wayland first (more modern) + if os.Getenv("WAYLAND_DISPLAY") != "" { + return DisplayServerWayland + } + + // Check for X11 + if os.Getenv("DISPLAY") != "" { + return DisplayServerX11 + } + + return DisplayServerUnknown +} + +// GetDisplayName returns a string name for the display server +func (ds DisplayServer) String() string { + switch ds { + case DisplayServerX11: + return "x11" + case DisplayServerWayland: + return "wayland" + default: + return "unknown" + } +} + +// RateLimiter implements token bucket rate limiting for computer use actions +type RateLimiter struct { + cfg *config.RateLimitConfig + actionTimes []time.Time + mu sync.Mutex +} + +// NewRateLimiter creates a new rate limiter +func NewRateLimiter(cfg config.RateLimitConfig) *RateLimiter { + return &RateLimiter{ + cfg: &cfg, + actionTimes: make([]time.Time, 0), + } +} + +// CheckAndRecord checks if the action is within rate limits and records it +// Returns an error if the rate limit is exceeded +func (rl *RateLimiter) CheckAndRecord(toolName string) error { + if !rl.cfg.Enabled { + return nil // Rate limiting disabled + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + windowStart := now.Add(-time.Duration(rl.cfg.WindowSeconds) * time.Second) + + // Remove actions outside the time window + validActions := make([]time.Time, 0) + for _, t := range rl.actionTimes { + if t.After(windowStart) { + validActions = append(validActions, t) + } + } + rl.actionTimes = validActions + + // Check if at limit + if len(rl.actionTimes) >= rl.cfg.MaxActionsPerMinute { + return fmt.Errorf("rate limit exceeded: maximum %d actions per %d seconds (current: %d actions in window)", + rl.cfg.MaxActionsPerMinute, rl.cfg.WindowSeconds, len(rl.actionTimes)) + } + + // 
Record the new action + rl.actionTimes = append(rl.actionTimes, now) + return nil +} + +// GetCurrentCount returns the number of actions in the current window +func (rl *RateLimiter) GetCurrentCount() int { + if !rl.cfg.Enabled { + return 0 + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + windowStart := now.Add(-time.Duration(rl.cfg.WindowSeconds) * time.Second) + + count := 0 + for _, t := range rl.actionTimes { + if t.After(windowStart) { + count++ + } + } + + return count +} + +// Reset clears all recorded actions +func (rl *RateLimiter) Reset() { + rl.mu.Lock() + defer rl.mu.Unlock() + rl.actionTimes = make([]time.Time, 0) +} diff --git a/internal/services/tools/computer_use_wayland.go b/internal/services/tools/computer_use_wayland.go new file mode 100644 index 00000000..c1e625f7 --- /dev/null +++ b/internal/services/tools/computer_use_wayland.go @@ -0,0 +1,269 @@ +package tools + +import ( + "context" + "fmt" + "os/exec" + "strconv" + "strings" + "time" +) + +// WaylandClient provides Wayland screen control operations using command-line tools +type WaylandClient struct { + display string +} + +// NewWaylandClient creates a new Wayland client +func NewWaylandClient(display string) (*WaylandClient, error) { + if err := checkWaylandTools(); err != nil { + return nil, err + } + + return &WaylandClient{ + display: display, + }, nil +} + +// checkWaylandTools checks if required Wayland tools are available +func checkWaylandTools() error { + tools := []string{"grim"} + + for _, tool := range tools { + if _, err := exec.LookPath(tool); err != nil { + return fmt.Errorf("required tool '%s' not found in PATH (install with: sudo apt install %s)", tool, tool) + } + } + + return nil +} + +// Close closes the Wayland client (no-op for command-line tools) +func (c *WaylandClient) Close() { + // Nothing to close for command-line tools +} + +// CaptureScreenBytes captures a screenshot and returns it as PNG bytes +func (c *WaylandClient) CaptureScreenBytes(x, 
y, width, height int) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var cmd *exec.Cmd + if width == 0 || height == 0 { + cmd = exec.CommandContext(ctx, "grim", "-") + } else { + geometry := fmt.Sprintf("%d,%d %dx%d", x, y, width, height) + cmd = exec.CommandContext(ctx, "grim", "-g", geometry, "-") + } + + output, err := cmd.Output() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return nil, fmt.Errorf("grim failed: %s", string(exitErr.Stderr)) + } + return nil, fmt.Errorf("failed to capture screenshot: %w", err) + } + + return output, nil +} + +// MoveMouse moves the cursor to the specified absolute coordinates +func (c *WaylandClient) MoveMouse(x, y int) error { + if _, err := exec.LookPath("ydotool"); err != nil { + return fmt.Errorf("ydotool not found (install with: sudo apt install ydotool)") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "ydotool", "mousemove", "--absolute", "--", + strconv.Itoa(x), strconv.Itoa(y)) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ydotool mousemove failed: %s", string(output)) + } + + return nil +} + +// ClickMouse performs a mouse click at the current cursor position +func (c *WaylandClient) ClickMouse(button string, clicks int) error { + if _, err := exec.LookPath("ydotool"); err != nil { + return fmt.Errorf("ydotool not found (install with: sudo apt install ydotool)") + } + + var buttonCode string + switch button { + case "left": + buttonCode = "0xC0" + case "middle": + buttonCode = "0xC1" + case "right": + buttonCode = "0xC2" + default: + return fmt.Errorf("invalid button: %s (must be left, middle, or right)", button) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + for i := 0; i < clicks; i++ { + cmd := exec.CommandContext(ctx, "ydotool", "click", buttonCode) + + 
output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ydotool click failed: %s", string(output)) + } + + if i < clicks-1 { + time.Sleep(100 * time.Millisecond) + } + } + + return nil +} + +// TypeText types the given text +func (c *WaylandClient) TypeText(text string) error { + if _, err := exec.LookPath("wtype"); err == nil { + return c.typeTextWithWtype(text) + } + + if _, err := exec.LookPath("ydotool"); err == nil { + return c.typeTextWithYdotool(text) + } + + return fmt.Errorf("no text input tool available (install wtype or ydotool)") +} + +// typeTextWithWtype types text using the wtype command +func (c *WaylandClient) typeTextWithWtype(text string) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "wtype", text) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("wtype failed: %s", string(output)) + } + + return nil +} + +// typeTextWithYdotool types text using the ydotool command +func (c *WaylandClient) typeTextWithYdotool(text string) error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "ydotool", "type", text) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ydotool type failed: %s", string(output)) + } + + return nil +} + +// SendKeyCombo sends a key combination (e.g., "ctrl+c") +func (c *WaylandClient) SendKeyCombo(combo string) error { + if _, err := exec.LookPath("wtype"); err == nil { + return c.sendKeyComboWithWtype(combo) + } + + if _, err := exec.LookPath("ydotool"); err == nil { + return c.sendKeyComboWithYdotool(combo) + } + + return fmt.Errorf("no key combo tool available (install wtype or ydotool)") +} + +// sendKeyComboWithWtype sends a key combination using wtype +func (c *WaylandClient) sendKeyComboWithWtype(combo string) error { + ctx, cancel := context.WithTimeout(context.Background(), 
5*time.Second) + defer cancel() + + parts := strings.Split(combo, "+") + if len(parts) < 2 { + return fmt.Errorf("invalid key combo format: %s (expected format: modifier+key)", combo) + } + + modifiers := parts[:len(parts)-1] + key := parts[len(parts)-1] + + args := []string{} + for _, mod := range modifiers { + args = append(args, "-M", mod) + } + args = append(args, "-P", key) + + cmd := exec.CommandContext(ctx, "wtype", args...) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("wtype key combo failed: %s", string(output)) + } + + return nil +} + +// sendKeyComboWithYdotool sends a key combination using ydotool +func (c *WaylandClient) sendKeyComboWithYdotool(combo string) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // ydotool key combo format: key1:key2 + // Convert "ctrl+c" to "29:46" (keycodes) + // This is a simplified version - proper implementation would need keycode mapping + + cmd := exec.CommandContext(ctx, "ydotool", "key", combo) + + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ydotool key combo failed: %s", string(output)) + } + + return nil +} + +// GetScreenDimensions returns the screen width and height +func (c *WaylandClient) GetScreenDimensions() (int, int, error) { + // Wayland doesn't have a simple command to get screen dimensions + // We can use wlr-randr if available, or return default values + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "wlr-randr") + output, err := cmd.Output() + if err != nil { + return 1920, 1080, nil + } + + lines := strings.Split(string(output), "\n") + for _, line := range lines { + if !strings.Contains(line, "current") { + continue + } + + parts := strings.Fields(line) + if len(parts) == 0 { + continue + } + + dims := strings.Split(parts[0], "x") + if len(dims) != 2 { + continue + } + + width, _ := strconv.Atoi(dims[0]) 
+ height, _ := strconv.Atoi(dims[1]) + if width > 0 && height > 0 { + return width, height, nil + } + } + + return 1920, 1080, nil +} diff --git a/internal/services/tools/computer_use_x11.go b/internal/services/tools/computer_use_x11.go new file mode 100644 index 00000000..42fb6f34 --- /dev/null +++ b/internal/services/tools/computer_use_x11.go @@ -0,0 +1,176 @@ +package tools + +import ( + "bytes" + "fmt" + "image" + "image/png" + "os" + + xgb "github.com/BurntSushi/xgb" + xproto "github.com/BurntSushi/xgb/xproto" + xgbutil "github.com/BurntSushi/xgbutil" + xgraphics "github.com/BurntSushi/xgbutil/xgraphics" + + logger "github.com/inference-gateway/cli/internal/logger" +) + +// X11Client wraps X11 connection and provides screen control operations +type X11Client struct { + xu *xgbutil.XUtil + conn *xgb.Conn + screen *xproto.ScreenInfo + display string +} + +// NewX11Client creates a new X11 client connection +func NewX11Client(display string) (*X11Client, error) { + if display == "" { + display = ":0" + } + + oldStderr := os.Stderr + devNull, devErr := os.OpenFile(os.DevNull, os.O_WRONLY, 0) + if devErr == nil { + os.Stderr = devNull + } + + xu, err := xgbutil.NewConnDisplay(display) + + if devErr == nil { + os.Stderr = oldStderr + _ = devNull.Close() + } + + if err != nil { + logger.Error("Failed to connect to X11 display", "display", display, "error", err) + return nil, fmt.Errorf("failed to connect to X11 display %s: %w", display, err) + } + + logger.Debug("Successfully connected to X11 display", "display", display) + + return &X11Client{ + xu: xu, + conn: xu.Conn(), + screen: xproto.Setup(xu.Conn()).DefaultScreen(xu.Conn()), + display: display, + }, nil +} + +// Close closes the X11 connection +func (c *X11Client) Close() { + if c.conn != nil { + c.conn.Close() + } +} + +// GetScreenDimensions returns the screen width and height +func (c *X11Client) GetScreenDimensions() (int, int) { + return int(c.screen.WidthInPixels), int(c.screen.HeightInPixels) +} + +// 
CaptureScreen captures a screenshot of the entire screen or a region +func (c *X11Client) CaptureScreen(x, y, width, height int) (image.Image, error) { + if width == 0 || height == 0 { + width = int(c.screen.WidthInPixels) + height = int(c.screen.HeightInPixels) + x = 0 + y = 0 + } + + root := c.screen.Root + + ximg, err := xgraphics.NewDrawable(c.xu, xproto.Drawable(root)) + if err != nil { + return nil, fmt.Errorf("failed to create drawable: %w", err) + } + + subImg := ximg.SubImage(image.Rect(x, y, x+width, y+height)) + + return subImg, nil +} + +// CaptureScreenBytes captures a screenshot and returns it as PNG bytes +func (c *X11Client) CaptureScreenBytes(x, y, width, height int) ([]byte, error) { + img, err := c.CaptureScreen(x, y, width, height) + if err != nil { + return nil, err + } + + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + return nil, fmt.Errorf("failed to encode PNG: %w", err) + } + + return buf.Bytes(), nil +} + +// GetCursorPosition returns the current cursor position +func (c *X11Client) GetCursorPosition() (int, int, error) { + root := c.screen.Root + + pointer, err := xproto.QueryPointer(c.conn, root).Reply() + if err != nil { + return 0, 0, fmt.Errorf("failed to query pointer: %w", err) + } + + return int(pointer.RootX), int(pointer.RootY), nil +} + +// MoveMouse moves the cursor to the specified absolute coordinates +func (c *X11Client) MoveMouse(x, y int) error { + root := c.screen.Root + + err := xproto.WarpPointerChecked( + c.conn, + xproto.WindowNone, + root, + 0, 0, + 0, 0, + int16(x), int16(y), + ).Check() + + if err != nil { + return fmt.Errorf("failed to move mouse: %w", err) + } + + c.conn.Sync() + + return nil +} + +// ClickMouse performs a mouse click at the current cursor position +func (c *X11Client) ClickMouse(button string, clicks int) error { + // Note: X11 mouse clicking requires the XTEST extension which is not + // fully implemented in the pure Go xgb library. 
+ // For production use, consider using xdotool as a fallback or implementing + // XTEST extension support. + + return fmt.Errorf("X11 mouse clicking requires xdotool (install with: sudo apt install xdotool). Use Wayland with ydotool for native support, or we can add xdotool fallback") +} + +// TypeText types the given text by sending key events +func (c *X11Client) TypeText(text string) error { + // This is a simplified implementation + // A full implementation would need to: + // 1. Map characters to keycodes using the keyboard mapping + // 2. Handle modifier keys (Shift, Ctrl, etc.) + // 3. Send KeyPress and KeyRelease events for each character + + // For now, return an error indicating this needs proper keysym mapping + return fmt.Errorf("text typing via X11 requires keysym mapping (not yet implemented)") +} + +// SendKeyCombo sends a key combination (e.g., "ctrl+c") +func (c *X11Client) SendKeyCombo(combo string) error { + // This is a simplified implementation + // A full implementation would need to: + // 1. Parse the combo string to extract modifiers and key + // 2. Map key names to keycodes + // 3. Send modifier key presses + // 4. Send the main key press/release + // 5. 
Release modifier keys + + // For now, return an error indicating this needs proper implementation + return fmt.Errorf("key combinations via X11 require keysym mapping (not yet implemented)") +} diff --git a/internal/services/tools/get_latest_screenshot.go b/internal/services/tools/get_latest_screenshot.go new file mode 100644 index 00000000..ef97fe07 --- /dev/null +++ b/internal/services/tools/get_latest_screenshot.go @@ -0,0 +1,179 @@ +package tools + +import ( + "context" + "fmt" + "time" + + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" + sdk "github.com/inference-gateway/sdk" +) + +// GetLatestScreenshotTool retrieves the latest screenshot from the circular buffer +// This tool is used when screenshot streaming is enabled to avoid redundant captures +type GetLatestScreenshotTool struct { + config *config.Config + enabled bool + formatter domain.BaseFormatter + provider domain.ScreenshotProvider + lastCallTime time.Time + minCallInterval time.Duration +} + +// NewGetLatestScreenshotTool creates a new tool that reads from the screenshot buffer +func NewGetLatestScreenshotTool(cfg *config.Config, provider domain.ScreenshotProvider) *GetLatestScreenshotTool { + minInterval := time.Duration(cfg.ComputerUse.Screenshot.CaptureInterval) * time.Second + if minInterval < 2*time.Second { + minInterval = 2 * time.Second + } + + return &GetLatestScreenshotTool{ + config: cfg, + enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.Screenshot.StreamingEnabled, + formatter: domain.NewBaseFormatter("GetLatestScreenshot"), + provider: provider, + minCallInterval: minInterval, + } +} + +// Definition returns the tool definition for the LLM +func (t *GetLatestScreenshotTool) Definition() sdk.ChatCompletionTool { + description := "Retrieves the latest screenshot from the buffer. This is a read-only operation that does NOT require approval. Use this tool to see the current state of the screen. 
Screenshots are automatically captured every few seconds when streaming is enabled." + return sdk.ChatCompletionTool{ + Type: sdk.Function, + Function: sdk.FunctionObject{ + Name: "GetLatestScreenshot", + Description: &description, + Parameters: &sdk.FunctionParameters{ + "type": "object", + "properties": map[string]any{}, + }, + }, + } +} + +// Execute retrieves the latest screenshot from the buffer +func (t *GetLatestScreenshotTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { + start := time.Now() + + if !t.lastCallTime.IsZero() { + timeSinceLastCall := time.Since(t.lastCallTime) + if timeSinceLastCall < t.minCallInterval { + waitTime := t.minCallInterval - timeSinceLastCall + return &domain.ToolExecutionResult{ + ToolName: "GetLatestScreenshot", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("please wait %v before requesting another screenshot (last called %v ago)", waitTime.Round(time.Second), timeSinceLastCall.Round(time.Second)), + }, nil + } + } + + t.lastCallTime = time.Now() + + if t.provider == nil { + return &domain.ToolExecutionResult{ + ToolName: "GetLatestScreenshot", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "screenshot provider not available", + }, nil + } + + screenshot, err := t.provider.GetLatestScreenshot() + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "GetLatestScreenshot", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + mimeType := "image/" + screenshot.Format + imageAttachment := domain.ImageAttachment{ + Data: screenshot.Data, + MimeType: mimeType, + DisplayName: "screenshot-latest", + } + + result := domain.ScreenshotToolResult{ + Display: t.config.ComputerUse.Display, + Region: nil, + Width: screenshot.Width, + Height: screenshot.Height, + Format: screenshot.Format, + Method: screenshot.Method, + } + + return &domain.ToolExecutionResult{ + 
ToolName: "GetLatestScreenshot", + Arguments: args, + Success: true, + Duration: time.Since(start), + Data: result, + Images: []domain.ImageAttachment{imageAttachment}, + }, nil +} + +// Validate checks if the tool arguments are valid +func (t *GetLatestScreenshotTool) Validate(args map[string]any) error { + // No arguments needed + return nil +} + +// IsEnabled returns whether this tool is enabled +func (t *GetLatestScreenshotTool) IsEnabled() bool { + return t.enabled && t.provider != nil +} + +// FormatResult formats tool execution results for different contexts +func (t *GetLatestScreenshotTool) FormatResult(result *domain.ToolExecutionResult, formatType domain.FormatterType) string { + switch formatType { + case domain.FormatterLLM: + return t.FormatForLLM(result) + case domain.FormatterShort: + return t.FormatPreview(result) + default: + return t.FormatForLLM(result) + } +} + +// FormatPreview returns a short preview of the result for UI display +func (t *GetLatestScreenshotTool) FormatPreview(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return "Failed to get latest screenshot" + } + data, ok := result.Data.(domain.ScreenshotToolResult) + if !ok { + return "Latest screenshot retrieved" + } + return fmt.Sprintf("Latest screenshot: %dx%d (%s)", data.Width, data.Height, data.Method) +} + +// FormatForLLM formats the result for LLM consumption +func (t *GetLatestScreenshotTool) FormatForLLM(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return fmt.Sprintf("Error: %s", result.Error) + } + data, ok := result.Data.(domain.ScreenshotToolResult) + if !ok { + return "Latest screenshot retrieved successfully. Image is attached." + } + return fmt.Sprintf("Latest screenshot retrieved successfully (%dx%d, format: %s, method: %s). This screenshot was automatically captured by the streaming system. 
Image is attached.", + data.Width, data.Height, data.Format, data.Method) +} + +// ShouldCollapseArg determines if an argument should be collapsed in display +func (t *GetLatestScreenshotTool) ShouldCollapseArg(key string) bool { + return false +} + +// ShouldAlwaysExpand determines if tool results should always be expanded in UI +func (t *GetLatestScreenshotTool) ShouldAlwaysExpand() bool { + return false +} diff --git a/internal/services/tools/keyboard_type.go b/internal/services/tools/keyboard_type.go new file mode 100644 index 00000000..de216552 --- /dev/null +++ b/internal/services/tools/keyboard_type.go @@ -0,0 +1,280 @@ +package tools + +import ( + "context" + "fmt" + "time" + + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" + sdk "github.com/inference-gateway/sdk" +) + +// KeyboardTypeTool types text or sends key combinations +type KeyboardTypeTool struct { + config *config.Config + enabled bool + formatter domain.BaseFormatter + rateLimiter *RateLimiter +} + +// NewKeyboardTypeTool creates a new keyboard type tool +func NewKeyboardTypeTool(cfg *config.Config, rateLimiter *RateLimiter) *KeyboardTypeTool { + return &KeyboardTypeTool{ + config: cfg, + enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.KeyboardType.Enabled, + formatter: domain.NewBaseFormatter("KeyboardType"), + rateLimiter: rateLimiter, + } +} + +// Definition returns the tool definition for the LLM +func (t *KeyboardTypeTool) Definition() sdk.ChatCompletionTool { + description := "Types text or sends key combinations. Can type regular text or send special key combinations like 'ctrl+c'. Requires user approval unless in auto-accept mode. Note: Exactly one of 'text' or 'key_combo' must be provided." 
+ return sdk.ChatCompletionTool{ + Type: sdk.Function, + Function: sdk.FunctionObject{ + Name: "KeyboardType", + Description: &description, + Parameters: &sdk.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "text": map[string]any{ + "type": "string", + "description": "Text to type. Mutually exclusive with key_combo.", + }, + "key_combo": map[string]any{ + "type": "string", + "description": "Key combination to send (e.g., 'ctrl+c', 'alt+tab', 'shift+enter'). Mutually exclusive with text.", + }, + "display": map[string]any{ + "type": "string", + "description": "Display to use (e.g., ':0'). Defaults to ':0'.", + "default": ":0", + }, + }, + }, + }, + } +} + +// Execute runs the keyboard type tool with given arguments +func (t *KeyboardTypeTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { + start := time.Now() + + if err := t.rateLimiter.CheckAndRecord("KeyboardType"); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + text, hasText := args["text"].(string) + keyCombo, hasKeyCombo := args["key_combo"].(string) + + if !hasText && !hasKeyCombo { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "either 'text' or 'key_combo' must be provided", + }, nil + } + + if hasText && hasKeyCombo { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "only one of 'text' or 'key_combo' can be provided", + }, nil + } + + display := t.config.ComputerUse.Display + if displayArg, ok := args["display"].(string); ok && displayArg != "" { + display = displayArg + } + + displayServer := DetectDisplayServer() + + switch displayServer { + case DisplayServerX11: + client, err := NewX11Client(display) + if err != 
nil { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + defer client.Close() + + if hasText { + err = client.TypeText(text) + } else { + err = client.SendKeyCombo(keyCombo) + } + + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + case DisplayServerWayland: + client, err := NewWaylandClient(display) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + defer client.Close() + + if hasText { + err = client.TypeText(text) + } else { + err = client.SendKeyCombo(keyCombo) + } + + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + default: + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "no display server detected (neither X11 nor Wayland)", + }, nil + } + + result := domain.KeyboardTypeToolResult{ + Text: text, + KeyCombo: keyCombo, + Display: display, + Method: displayServer.String(), + } + + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: true, + Duration: time.Since(start), + Data: result, + }, nil +} + +// Validate checks if the tool arguments are valid +func (t *KeyboardTypeTool) Validate(args map[string]any) error { + text, hasText := args["text"].(string) + keyCombo, hasKeyCombo := args["key_combo"].(string) + + if !hasText && !hasKeyCombo { + return fmt.Errorf("either 'text' or 'key_combo' must be provided") + } + + if hasText && hasKeyCombo { + return fmt.Errorf("only one of 'text' or 'key_combo' can 
be provided") + } + + if hasText { + if len(text) > t.config.ComputerUse.KeyboardType.MaxTextLength { + return fmt.Errorf("text length exceeds maximum of %d characters", t.config.ComputerUse.KeyboardType.MaxTextLength) + } + if len(text) == 0 { + return fmt.Errorf("text cannot be empty") + } + } + + if hasKeyCombo { + if len(keyCombo) == 0 { + return fmt.Errorf("key_combo cannot be empty") + } + } + + return nil +} + +// IsEnabled returns whether this tool is enabled +func (t *KeyboardTypeTool) IsEnabled() bool { + return t.enabled +} + +// FormatResult formats tool execution results for different contexts +func (t *KeyboardTypeTool) FormatResult(result *domain.ToolExecutionResult, formatType domain.FormatterType) string { + switch formatType { + case domain.FormatterLLM: + return t.FormatForLLM(result) + case domain.FormatterShort: + return t.FormatPreview(result) + default: + return t.FormatForLLM(result) + } +} + +// FormatPreview returns a short preview of the result for UI display +func (t *KeyboardTypeTool) FormatPreview(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return "Keyboard input failed" + } + data, ok := result.Data.(domain.KeyboardTypeToolResult) + if !ok { + return "Keyboard input sent" + } + if data.Text != "" { + preview := data.Text + if len(preview) > 30 { + preview = preview[:27] + "..." 
+ } + return fmt.Sprintf("Typed: %s", preview) + } + return fmt.Sprintf("Key combo: %s", data.KeyCombo) +} + +// FormatForLLM formats the result for LLM consumption +func (t *KeyboardTypeTool) FormatForLLM(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return fmt.Sprintf("Error: %s", result.Error) + } + data, ok := result.Data.(domain.KeyboardTypeToolResult) + if !ok { + return "Keyboard input sent successfully" + } + if data.Text != "" { + return fmt.Sprintf("Typed text: '%s' using %s", data.Text, data.Method) + } + return fmt.Sprintf("Sent key combination '%s' using %s", data.KeyCombo, data.Method) +} + +// ShouldCollapseArg determines if an argument should be collapsed in display +func (t *KeyboardTypeTool) ShouldCollapseArg(key string) bool { + return key == "text" +} + +// ShouldAlwaysExpand determines if tool results should always be expanded in UI +func (t *KeyboardTypeTool) ShouldAlwaysExpand() bool { + return false +} diff --git a/internal/services/tools/mouse_click.go b/internal/services/tools/mouse_click.go new file mode 100644 index 00000000..7d4a0bc7 --- /dev/null +++ b/internal/services/tools/mouse_click.go @@ -0,0 +1,325 @@ +package tools + +import ( + "context" + "fmt" + "time" + + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" + sdk "github.com/inference-gateway/sdk" +) + +// MouseClickTool performs mouse clicks +type MouseClickTool struct { + config *config.Config + enabled bool + formatter domain.BaseFormatter + rateLimiter *RateLimiter +} + +// NewMouseClickTool creates a new mouse click tool +func NewMouseClickTool(cfg *config.Config, rateLimiter *RateLimiter) *MouseClickTool { + return &MouseClickTool{ + config: cfg, + enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.MouseClick.Enabled, + formatter: domain.NewBaseFormatter("MouseClick"), + rateLimiter: rateLimiter, + } +} + +// Definition returns the tool definition for the LLM +func (t 
*MouseClickTool) Definition() sdk.ChatCompletionTool { + description := "Performs a mouse click. Can click at current position or move to coordinates first. Supports left, right, and middle buttons. Requires user approval unless in auto-accept mode." + return sdk.ChatCompletionTool{ + Type: sdk.Function, + Function: sdk.FunctionObject{ + Name: "MouseClick", + Description: &description, + Parameters: &sdk.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "button": map[string]any{ + "type": "string", + "description": "Mouse button to click", + "enum": []string{"left", "right", "middle"}, + "default": "left", + }, + "clicks": map[string]any{ + "type": "integer", + "description": "Number of clicks (1=single, 2=double, 3=triple)", + "enum": []int{1, 2, 3}, + "default": 1, + }, + "x": map[string]any{ + "type": "integer", + "description": "Optional: X coordinate to move to before clicking", + }, + "y": map[string]any{ + "type": "integer", + "description": "Optional: Y coordinate to move to before clicking", + }, + "display": map[string]any{ + "type": "string", + "description": "Display to use (e.g., ':0'). 
Defaults to ':0'.", + "default": ":0", + }, + }, + "required": []string{"button"}, + }, + }, + } +} + +// Execute runs the mouse click tool with given arguments +func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { + start := time.Now() + + if err := t.rateLimiter.CheckAndRecord("MouseClick"); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + button, ok := args["button"].(string) + if !ok { + button = "left" + } + + clicks := 1 + if clicksArg, ok := args["clicks"].(float64); ok { + clicks = int(clicksArg) + } + + display := t.config.ComputerUse.Display + if displayArg, ok := args["display"].(string); ok && displayArg != "" { + display = displayArg + } + + var finalX, finalY int + shouldMove := false + + if xArg, xOk := args["x"].(float64); xOk { + if yArg, yOk := args["y"].(float64); yOk { + finalX = int(xArg) + finalY = int(yArg) + shouldMove = true + } + } + + displayServer := DetectDisplayServer() + + switch displayServer { + case DisplayServerX11: + client, err := NewX11Client(display) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + defer client.Close() + + if shouldMove { + if err := client.MoveMouse(finalX, finalY); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("failed to move mouse: %v", err), + }, nil + } + } else { + x, y, _ := client.GetCursorPosition() + finalX, finalY = x, y + } + + if err := client.ClickMouse(button, clicks); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + 
case DisplayServerWayland: + client, err := NewWaylandClient(display) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + defer client.Close() + + if shouldMove { + if err := client.MoveMouse(finalX, finalY); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("failed to move mouse: %v", err), + }, nil + } + } + + if err := client.ClickMouse(button, clicks); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + default: + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "no display server detected (neither X11 nor Wayland)", + }, nil + } + + result := domain.MouseClickToolResult{ + Button: button, + Clicks: clicks, + X: finalX, + Y: finalY, + Display: display, + Method: displayServer.String(), + } + + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: true, + Duration: time.Since(start), + Data: result, + }, nil +} + +// Validate checks if the tool arguments are valid +func (t *MouseClickTool) Validate(args map[string]any) error { + button, ok := args["button"].(string) + if !ok { + return fmt.Errorf("button is required") + } + + if button != "left" && button != "right" && button != "middle" { + return fmt.Errorf("button must be 'left', 'right', or 'middle'") + } + + if clicksArg, ok := args["clicks"].(float64); ok { + clicks := int(clicksArg) + if clicks < 1 || clicks > 3 { + return fmt.Errorf("clicks must be 1, 2, or 3") + } + } + + if xArg, xOk := args["x"].(float64); xOk { + if _, yOk := args["y"].(float64); !yOk { + return fmt.Errorf("both x and y must be 
provided together") + } + if xArg < 0 { + return fmt.Errorf("x coordinate must be >= 0") + } + } + + if yArg, yOk := args["y"].(float64); yOk { + if _, xOk := args["x"].(float64); !xOk { + return fmt.Errorf("both x and y must be provided together") + } + if yArg < 0 { + return fmt.Errorf("y coordinate must be >= 0") + } + } + + return nil +} + +// IsEnabled returns whether this tool is enabled +func (t *MouseClickTool) IsEnabled() bool { + return t.enabled +} + +// FormatResult formats tool execution results for different contexts +func (t *MouseClickTool) FormatResult(result *domain.ToolExecutionResult, formatType domain.FormatterType) string { + switch formatType { + case domain.FormatterLLM: + return t.FormatForLLM(result) + case domain.FormatterShort: + return t.FormatPreview(result) + default: + return t.FormatForLLM(result) + } +} + +// FormatPreview returns a short preview of the result for UI display +func (t *MouseClickTool) FormatPreview(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return "Mouse click failed" + } + data, ok := result.Data.(domain.MouseClickToolResult) + if !ok { + return "Mouse clicked" + } + var clickType string + switch data.Clicks { + case 2: + clickType = "double-click" + case 3: + clickType = "triple-click" + default: + clickType = "click" + } + return fmt.Sprintf("%s %s at (%d, %d)", data.Button, clickType, data.X, data.Y) +} + +// FormatForLLM formats the result for LLM consumption +func (t *MouseClickTool) FormatForLLM(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return fmt.Sprintf("Error: %s", result.Error) + } + data, ok := result.Data.(domain.MouseClickToolResult) + if !ok { + return "Mouse click performed successfully" + } + var clickDesc string + switch data.Clicks { + case 2: + clickDesc = "double-click" + case 3: + clickDesc = "triple-click" + default: + clickDesc = fmt.Sprintf("%d click(s)", data.Clicks) + } + return fmt.Sprintf("Performed 
%s %s at position (%d, %d) using %s", + data.Button, clickDesc, data.X, data.Y, data.Method) +} + +// ShouldCollapseArg determines if an argument should be collapsed in display +func (t *MouseClickTool) ShouldCollapseArg(key string) bool { + return false +} + +// ShouldAlwaysExpand determines if tool results should always be expanded in UI +func (t *MouseClickTool) ShouldAlwaysExpand() bool { + return false +} diff --git a/internal/services/tools/mouse_move.go b/internal/services/tools/mouse_move.go new file mode 100644 index 00000000..99b88ad7 --- /dev/null +++ b/internal/services/tools/mouse_move.go @@ -0,0 +1,247 @@ +package tools + +import ( + "context" + "fmt" + "time" + + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" + sdk "github.com/inference-gateway/sdk" +) + +// MouseMoveTool moves the mouse cursor to specified coordinates +type MouseMoveTool struct { + config *config.Config + enabled bool + formatter domain.BaseFormatter + rateLimiter *RateLimiter +} + +// NewMouseMoveTool creates a new mouse move tool +func NewMouseMoveTool(cfg *config.Config, rateLimiter *RateLimiter) *MouseMoveTool { + return &MouseMoveTool{ + config: cfg, + enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.MouseMove.Enabled, + formatter: domain.NewBaseFormatter("MouseMove"), + rateLimiter: rateLimiter, + } +} + +// Definition returns the tool definition for the LLM +func (t *MouseMoveTool) Definition() sdk.ChatCompletionTool { + description := "Moves the mouse cursor to absolute screen coordinates. Requires user approval unless in auto-accept mode." 
+ return sdk.ChatCompletionTool{ + Type: sdk.Function, + Function: sdk.FunctionObject{ + Name: "MouseMove", + Description: &description, + Parameters: &sdk.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "x": map[string]any{ + "type": "integer", + "description": "X coordinate (absolute position from left edge of screen)", + }, + "y": map[string]any{ + "type": "integer", + "description": "Y coordinate (absolute position from top edge of screen)", + }, + "display": map[string]any{ + "type": "string", + "description": "Display to use (e.g., ':0'). Defaults to ':0'.", + "default": ":0", + }, + }, + "required": []string{"x", "y"}, + }, + }, + } +} + +// Execute runs the mouse move tool with given arguments +func (t *MouseMoveTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { + start := time.Now() + + if err := t.rateLimiter.CheckAndRecord("MouseMove"); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + x, xOk := args["x"].(float64) + y, yOk := args["y"].(float64) + + if !xOk || !yOk { + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "x and y coordinates are required", + }, nil + } + + display := t.config.ComputerUse.Display + if displayArg, ok := args["display"].(string); ok && displayArg != "" { + display = displayArg + } + + var fromX, fromY int + displayServer := DetectDisplayServer() + + switch displayServer { + case DisplayServerX11: + client, err := NewX11Client(display) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + defer client.Close() + + fromX, fromY, _ = client.GetCursorPosition() + + if err := client.MoveMouse(int(x), int(y)); err != 
nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + case DisplayServerWayland: + client, err := NewWaylandClient(display) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + defer client.Close() + + fromX, fromY = 0, 0 + + if err := client.MoveMouse(int(x), int(y)); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + default: + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "no display server detected (neither X11 nor Wayland)", + }, nil + } + + result := domain.MouseMoveToolResult{ + FromX: fromX, + FromY: fromY, + ToX: int(x), + ToY: int(y), + Display: display, + Method: displayServer.String(), + } + + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: true, + Duration: time.Since(start), + Data: result, + }, nil +} + +// Validate checks if the tool arguments are valid +func (t *MouseMoveTool) Validate(args map[string]any) error { + x, xOk := args["x"].(float64) + y, yOk := args["y"].(float64) + + if !xOk { + return fmt.Errorf("x coordinate is required") + } + if !yOk { + return fmt.Errorf("y coordinate is required") + } + if x < 0 { + return fmt.Errorf("x coordinate must be >= 0") + } + if y < 0 { + return fmt.Errorf("y coordinate must be >= 0") + } + + return nil +} + +// IsEnabled returns whether this tool is enabled +func (t *MouseMoveTool) IsEnabled() bool { + return t.enabled +} + +// FormatResult formats tool execution results for different contexts +func (t *MouseMoveTool) FormatResult(result *domain.ToolExecutionResult, formatType 
domain.FormatterType) string { + switch formatType { + case domain.FormatterLLM: + return t.FormatForLLM(result) + case domain.FormatterShort: + return t.FormatPreview(result) + default: + return t.FormatForLLM(result) + } +} + +// FormatPreview returns a short preview of the result for UI display +func (t *MouseMoveTool) FormatPreview(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return "Mouse move failed" + } + data, ok := result.Data.(domain.MouseMoveToolResult) + if !ok { + return "Mouse moved" + } + return fmt.Sprintf("Moved mouse to (%d, %d)", data.ToX, data.ToY) +} + +// FormatForLLM formats the result for LLM consumption +func (t *MouseMoveTool) FormatForLLM(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return fmt.Sprintf("Error: %s", result.Error) + } + data, ok := result.Data.(domain.MouseMoveToolResult) + if !ok { + return "Mouse moved successfully" + } + return fmt.Sprintf("Mouse moved from (%d, %d) to (%d, %d) using %s", + data.FromX, data.FromY, data.ToX, data.ToY, data.Method) +} + +// ShouldCollapseArg determines if an argument should be collapsed in display +func (t *MouseMoveTool) ShouldCollapseArg(key string) bool { + return false +} + +// ShouldAlwaysExpand determines if tool results should always be expanded in UI +func (t *MouseMoveTool) ShouldAlwaysExpand() bool { + return false +} diff --git a/internal/services/tools/registry.go b/internal/services/tools/registry.go index d293adcb..223e574d 100644 --- a/internal/services/tools/registry.go +++ b/internal/services/tools/registry.go @@ -78,6 +78,14 @@ func (r *Registry) registerTools() { r.tools["A2A_SubmitTask"] = NewA2ASubmitTaskTool(r.config, r.taskTracker) } + if r.config.ComputerUse.Enabled { + rateLimiter := NewRateLimiter(r.config.ComputerUse.RateLimit) + r.tools["Screenshot"] = NewScreenshotTool(r.config, r.imageService, rateLimiter) + r.tools["MouseMove"] = NewMouseMoveTool(r.config, rateLimiter) + 
r.tools["MouseClick"] = NewMouseClickTool(r.config, rateLimiter) + r.tools["KeyboardType"] = NewKeyboardTypeTool(r.config, rateLimiter) + } + if r.config.MCP.Enabled && r.mcpManager != nil { r.registerMCPTools() } @@ -226,6 +234,25 @@ func (r *Registry) UnregisterMCPServerTools(serverName string) int { return removedCount } +// SetScreenshotServer dynamically registers the GetLatestScreenshot tool +// This should be called after the screenshot server is started +func (r *Registry) SetScreenshotServer(provider domain.ScreenshotProvider) { + if !r.config.ComputerUse.Enabled || !r.config.ComputerUse.Screenshot.StreamingEnabled { + logger.Debug("Screenshot streaming not enabled, skipping GetLatestScreenshot tool registration") + return + } + + if provider == nil { + logger.Warn("Screenshot provider is nil, cannot register GetLatestScreenshot tool") + return + } + + getLatestTool := NewGetLatestScreenshotTool(r.config, provider) + r.tools["GetLatestScreenshot"] = getLatestTool + + logger.Info("Dynamically registered GetLatestScreenshot tool for streaming mode") +} + // SetReadToolUsed marks that the Read tool has been used func (r *Registry) SetReadToolUsed() { r.readToolUsed = true diff --git a/internal/services/tools/screenshot.go b/internal/services/tools/screenshot.go new file mode 100644 index 00000000..69b75e1f --- /dev/null +++ b/internal/services/tools/screenshot.go @@ -0,0 +1,424 @@ +package tools + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "image" + "image/draw" + "image/jpeg" + "image/png" + "time" + + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" + sdk "github.com/inference-gateway/sdk" + xdraw "golang.org/x/image/draw" +) + +// ScreenshotTool captures screenshots of the display +type ScreenshotTool struct { + config *config.Config + enabled bool + formatter domain.BaseFormatter + imageService domain.ImageService + rateLimiter *RateLimiter +} + +// NewScreenshotTool creates a 
new screenshot tool +func NewScreenshotTool(cfg *config.Config, imageService domain.ImageService, rateLimiter *RateLimiter) *ScreenshotTool { + return &ScreenshotTool{ + config: cfg, + enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.Screenshot.Enabled, + formatter: domain.NewBaseFormatter("Screenshot"), + imageService: imageService, + rateLimiter: rateLimiter, + } +} + +// Definition returns the tool definition for the LLM +func (t *ScreenshotTool) Definition() sdk.ChatCompletionTool { + description := "Captures a screenshot of the display. This is a read-only operation that does NOT require approval. Can capture the entire screen or a specific region." + return sdk.ChatCompletionTool{ + Type: sdk.Function, + Function: sdk.FunctionObject{ + Name: "Screenshot", + Description: &description, + Parameters: &sdk.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "region": map[string]any{ + "type": "object", + "description": "Optional region to capture. If not specified, captures the entire screen.", + "properties": map[string]any{ + "x": map[string]any{ + "type": "integer", + "description": "X coordinate of the top-left corner", + }, + "y": map[string]any{ + "type": "integer", + "description": "Y coordinate of the top-left corner", + }, + "width": map[string]any{ + "type": "integer", + "description": "Width of the region", + }, + "height": map[string]any{ + "type": "integer", + "description": "Height of the region", + }, + }, + }, + "display": map[string]any{ + "type": "string", + "description": "Display to capture from (e.g., ':0'). 
Defaults to ':0'.", + "default": ":0", + }, + }, + }, + }, + } +} + +// Execute runs the screenshot tool with given arguments +func (t *ScreenshotTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { + start := time.Now() + + if t.rateLimiter != nil { + if err := t.rateLimiter.CheckAndRecord("Screenshot"); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "Screenshot", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + } + + display := t.config.ComputerUse.Display + if displayArg, ok := args["display"].(string); ok && displayArg != "" { + display = displayArg + } + + region, x, y, width, height := parseRegionArgs(args) + + displayServer := DetectDisplayServer() + method := displayServer.String() + + var imageBytes []byte + var captureWidth, captureHeight int + var err error + + switch displayServer { + case DisplayServerX11: + imageBytes, captureWidth, captureHeight, err = t.captureX11(display, x, y, width, height) + case DisplayServerWayland: + imageBytes, captureWidth, captureHeight, err = t.captureWayland(display, x, y, width, height) + default: + return &domain.ToolExecutionResult{ + ToolName: "Screenshot", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "no display server detected (neither X11 nor Wayland)", + }, nil + } + + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "Screenshot", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + optimized, err := t.optimizeScreenshot(imageBytes) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "Screenshot", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("failed to optimize screenshot: %v", err), + }, nil + } + + base64Data := base64.StdEncoding.EncodeToString(optimized) + + mimeType := "image/" + t.config.ComputerUse.Screenshot.Format + imageAttachment 
:= domain.ImageAttachment{ + Data: base64Data, + MimeType: mimeType, + DisplayName: fmt.Sprintf("screenshot-%s", display), + } + + result := domain.ScreenshotToolResult{ + Display: display, + Region: region, + Width: captureWidth, + Height: captureHeight, + Format: t.config.ComputerUse.Screenshot.Format, + Method: method, + } + + return &domain.ToolExecutionResult{ + ToolName: "Screenshot", + Arguments: args, + Success: true, + Duration: time.Since(start), + Data: result, + Images: []domain.ImageAttachment{imageAttachment}, + }, nil +} + +// captureX11 captures a screenshot using X11 +// parseRegionArgs extracts region parameters from tool arguments +func parseRegionArgs(args map[string]any) (*domain.ScreenRegion, int, int, int, int) { + var region *domain.ScreenRegion + var x, y, width, height int + + regionArg, ok := args["region"].(map[string]any) + if !ok { + return nil, 0, 0, 0, 0 + } + + region = &domain.ScreenRegion{} + if xVal, ok := regionArg["x"].(float64); ok { + region.X = int(xVal) + x = int(xVal) + } + if yVal, ok := regionArg["y"].(float64); ok { + region.Y = int(yVal) + y = int(yVal) + } + if wVal, ok := regionArg["width"].(float64); ok { + region.Width = int(wVal) + width = int(wVal) + } + if hVal, ok := regionArg["height"].(float64); ok { + region.Height = int(hVal) + height = int(hVal) + } + + return region, x, y, width, height +} + +func (t *ScreenshotTool) captureX11(display string, x, y, width, height int) ([]byte, int, int, error) { + client, err := NewX11Client(display) + if err != nil { + return nil, 0, 0, err + } + defer client.Close() + + if width == 0 || height == 0 { + width, height = client.GetScreenDimensions() + x, y = 0, 0 + } + + imageBytes, err := client.CaptureScreenBytes(x, y, width, height) + if err != nil { + return nil, 0, 0, err + } + + return imageBytes, width, height, nil +} + +// captureWayland captures a screenshot using Wayland tools +func (t *ScreenshotTool) captureWayland(display string, x, y, width, height int) 
([]byte, int, int, error) { + client, err := NewWaylandClient(display) + if err != nil { + return nil, 0, 0, err + } + defer client.Close() + + if width == 0 || height == 0 { + w, h, err := client.GetScreenDimensions() + if err != nil { + w, h = 1920, 1080 + } + width, height = w, h + x, y = 0, 0 + } + + imageBytes, err := client.CaptureScreenBytes(x, y, width, height) + if err != nil { + return nil, 0, 0, err + } + + return imageBytes, width, height, nil +} + +// optimizeScreenshot optimizes the screenshot image by resizing and compressing +func (t *ScreenshotTool) optimizeScreenshot(imageBytes []byte) ([]byte, error) { + img, format, err := image.Decode(bytes.NewReader(imageBytes)) + if err != nil { + return nil, fmt.Errorf("failed to decode image: %w", err) + } + + img = t.resizeIfNeeded(img) + + return t.encodeImage(img, format) +} + +// resizeIfNeeded resizes the image if it exceeds max dimensions +func (t *ScreenshotTool) resizeIfNeeded(img image.Image) image.Image { + bounds := img.Bounds() + width := bounds.Dx() + height := bounds.Dy() + + maxWidth := t.config.ComputerUse.Screenshot.MaxWidth + maxHeight := t.config.ComputerUse.Screenshot.MaxHeight + + if maxWidth <= 0 && maxHeight <= 0 { + return img + } + + needsResize := false + newWidth := width + newHeight := height + + if maxWidth > 0 && width > maxWidth { + needsResize = true + ratio := float64(maxWidth) / float64(width) + newWidth = maxWidth + newHeight = int(float64(height) * ratio) + } + + if maxHeight > 0 && newHeight > maxHeight { + needsResize = true + ratio := float64(maxHeight) / float64(newHeight) + newHeight = maxHeight + newWidth = int(float64(newWidth) * ratio) + } + + if !needsResize { + return img + } + + dst := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight)) + xdraw.NearestNeighbor.Scale(dst, dst.Bounds(), img, bounds, draw.Src, nil) + return dst +} + +// encodeImage encodes the image to bytes based on configuration +func (t *ScreenshotTool) encodeImage(img image.Image, 
originalFormat string) ([]byte, error) { + var buf bytes.Buffer + format := t.config.ComputerUse.Screenshot.Format + quality := t.config.ComputerUse.Screenshot.Quality + + if quality <= 0 || quality > 100 { + quality = 85 + } + + switch format { + case "jpeg", "jpg": + err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}) + if err != nil { + return nil, fmt.Errorf("failed to encode jpeg: %w", err) + } + case "png": + encoder := png.Encoder{ + CompressionLevel: png.DefaultCompression, + } + if err := encoder.Encode(&buf, img); err != nil { + return nil, fmt.Errorf("failed to encode png: %w", err) + } + default: + if err := png.Encode(&buf, img); err != nil { + return nil, fmt.Errorf("failed to encode default png: %w", err) + } + } + + return buf.Bytes(), nil +} + +// Validate checks if the tool arguments are valid +func (t *ScreenshotTool) Validate(args map[string]any) error { + regionArg, ok := args["region"].(map[string]any) + if !ok { + return nil + } + + x, xOk := regionArg["x"].(float64) + y, yOk := regionArg["y"].(float64) + width, wOk := regionArg["width"].(float64) + height, hOk := regionArg["height"].(float64) + + if xOk && x < 0 { + return fmt.Errorf("region x must be >= 0") + } + if yOk && y < 0 { + return fmt.Errorf("region y must be >= 0") + } + if wOk && width <= 0 { + return fmt.Errorf("region width must be > 0") + } + if hOk && height <= 0 { + return fmt.Errorf("region height must be > 0") + } + + return nil +} + +// IsEnabled returns whether this tool is enabled +func (t *ScreenshotTool) IsEnabled() bool { + if t.config.ComputerUse.Screenshot.StreamingEnabled { + return false + } + return t.enabled +} + +// FormatResult formats tool execution results for different contexts +func (t *ScreenshotTool) FormatResult(result *domain.ToolExecutionResult, formatType domain.FormatterType) string { + switch formatType { + case domain.FormatterLLM: + return t.FormatForLLM(result) + case domain.FormatterShort: + return t.FormatPreview(result) + default: 
+ return t.FormatForLLM(result) + } +} + +// FormatPreview returns a short preview of the result for UI display +func (t *ScreenshotTool) FormatPreview(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return "Screenshot capture failed" + } + data, ok := result.Data.(domain.ScreenshotToolResult) + if !ok { + return "Screenshot captured" + } + return fmt.Sprintf("Screenshot captured: %dx%d (%s)", data.Width, data.Height, data.Method) +} + +// FormatForLLM formats the result for LLM consumption +func (t *ScreenshotTool) FormatForLLM(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return fmt.Sprintf("Error: %s", result.Error) + } + data, ok := result.Data.(domain.ScreenshotToolResult) + if !ok { + return "Screenshot captured successfully. Image is attached." + } + regionStr := "full screen" + if data.Region != nil { + regionStr = fmt.Sprintf("region x=%d y=%d w=%d h=%d", data.Region.X, data.Region.Y, data.Region.Width, data.Region.Height) + } + return fmt.Sprintf("Screenshot captured successfully (%s, %dx%d, format: %s, method: %s). 
Image is attached.", + regionStr, data.Width, data.Height, data.Format, data.Method) +} + +// ShouldCollapseArg determines if an argument should be collapsed in display +func (t *ScreenshotTool) ShouldCollapseArg(key string) bool { + return false +} + +// ShouldAlwaysExpand determines if tool results should always be expanded in UI +func (t *ScreenshotTool) ShouldAlwaysExpand() bool { + return false +} diff --git a/internal/ui/autocomplete/autocomplete.go b/internal/ui/autocomplete/autocomplete.go index b7c17fc3..e3092e07 100644 --- a/internal/ui/autocomplete/autocomplete.go +++ b/internal/ui/autocomplete/autocomplete.go @@ -286,8 +286,10 @@ func (a *AutocompleteImpl) generateArgumentTemplate(paramName string, properties switch paramType { case "string": return paramName + "=\"\"" - case "integer", "number": - return "" + case "integer": + return paramName + "=0" + case "number": + return paramName + "=0.0" case "boolean": return paramName + "=false" default: diff --git a/internal/web/pty_manager.go b/internal/web/pty_manager.go index 8e5144f2..6839f86d 100644 --- a/internal/web/pty_manager.go +++ b/internal/web/pty_manager.go @@ -30,9 +30,9 @@ type SessionHandler interface { type Session = SessionHandler // CreateSessionHandler creates either a local PTY session or remote SSH session -func CreateSessionHandler(webCfg *config.WebConfig, serverCfg *config.SSHServerConfig, cfg *config.Config, v *viper.Viper) (SessionHandler, error) { +func CreateSessionHandler(webCfg *config.WebConfig, serverCfg *config.SSHServerConfig, cfg *config.Config, v *viper.Viper, sessionID string, sessionManager *SessionManager) (SessionHandler, error) { if serverCfg != nil { - return createRemoteSSHSession(webCfg, serverCfg, cfg.Gateway.URL) + return createRemoteSSHSession(webCfg, serverCfg, cfg.Gateway.URL, sessionID, sessionManager) } logger.Info("Creating local PTY session") @@ -40,7 +40,7 @@ func CreateSessionHandler(webCfg *config.WebConfig, serverCfg *config.SSHServerC } // 
createRemoteSSHSession creates a remote SSH session with optional auto-install -func createRemoteSSHSession(webCfg *config.WebConfig, serverCfg *config.SSHServerConfig, gatewayURL string) (SessionHandler, error) { +func createRemoteSSHSession(webCfg *config.WebConfig, serverCfg *config.SSHServerConfig, gatewayURL string, sessionID string, sessionManager *SessionManager) (SessionHandler, error) { logger.Info("Creating remote SSH session", "server", serverCfg.Name) client, err := NewSSHClient(&webCfg.SSH, serverCfg) @@ -59,7 +59,11 @@ func createRemoteSSHSession(webCfg *config.WebConfig, serverCfg *config.SSHServe return nil, err } - session, err := NewSSHSession(client, serverCfg, gatewayURL) + if err := ensureRemoteConfig(client, serverCfg, gatewayURL); err != nil { + logger.Warn("Failed to ensure remote config, continuing anyway", "error", err) + } + + session, err := NewSSHSession(client, serverCfg, gatewayURL, sessionID, sessionManager) if err != nil { if closeErr := client.Close(); closeErr != nil { logger.Warn("Failed to close SSH client after session error", "error", closeErr) @@ -89,6 +93,52 @@ func ensureRemoteBinary(client *SSHClient, webCfg *config.WebConfig, serverCfg * return nil } +// ensureRemoteConfig ensures infer config exists on remote server +// Runs infer init --userspace if ~/.infer/config.yaml doesn't exist +func ensureRemoteConfig(client *SSHClient, serverCfg *config.SSHServerConfig, gatewayURL string) error { + commandPath := serverCfg.CommandPath + if commandPath == "" { + commandPath = "infer" + } + + logger.Info("Checking if infer config exists on remote server", "server", serverCfg.Name) + + session, err := client.NewSession() + if err != nil { + return fmt.Errorf("failed to create SSH session: %w", err) + } + defer func() { _ = session.Close() }() + + checkCmd := "test -f ~/.infer/config.yaml && echo 'exists' || echo 'missing'" + output, err := session.CombinedOutput(checkCmd) + if err != nil { + return fmt.Errorf("failed to check 
config file: %w", err) + } + + outputStr := string(output) + if len(outputStr) > 0 && outputStr[0] == 'e' { + logger.Info("Infer config already exists on remote server", "server", serverCfg.Name) + return nil + } + + logger.Info("Infer config not found, running init...", "server", serverCfg.Name) + + session2, err := client.NewSession() + if err != nil { + return fmt.Errorf("failed to create SSH session for init: %w", err) + } + defer func() { _ = session2.Close() }() + + initCmd := fmt.Sprintf("%s init --userspace", commandPath) + initOutput, err := session2.CombinedOutput(initCmd) + if err != nil { + return fmt.Errorf("failed to initialize config: %w\nOutput: %s", err, string(initOutput)) + } + + logger.Info("Infer config initialized", "server", serverCfg.Name, "output", string(initOutput)) + return nil +} + // LocalPTYSession represents a single local terminal session type LocalPTYSession struct { cfg *config.Config diff --git a/internal/web/server.go b/internal/web/server.go index 3f1dc61c..f08ed808 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "html/template" + "io" "io/fs" "net/http" "os" @@ -69,8 +70,12 @@ func (s *WebTerminalServer) Start() error { mux.HandleFunc("/", s.handleIndex) mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS)))) mux.HandleFunc("/api/servers", s.handleServers) + mux.HandleFunc("/api/screenshots/", s.handleScreenshotProxy) mux.HandleFunc("/ws", s.handleWebSocket) + logger.Info("HTTP routes registered", + "routes", []string{"/", "/static/", "/api/servers", "/api/screenshots/", "/ws"}) + addr := fmt.Sprintf("%s:%d", s.cfg.Web.Host, s.cfg.Web.Port) s.server = &http.Server{ Addr: addr, @@ -121,8 +126,6 @@ func (s *WebTerminalServer) handleServers(w http.ResponseWriter, r *http.Request } servers := []ServerInfo{} - - // Add local mode option servers = append(servers, ServerInfo{ ID: "local", Name: "Local", @@ -130,7 +133,6 @@ func (s 
*WebTerminalServer) handleServers(w http.ResponseWriter, r *http.Request Tags: []string{"local"}, }) - // Add configured remote servers for _, srv := range s.cfg.Web.Servers { servers = append(servers, ServerInfo{ ID: srv.ID, @@ -150,60 +152,163 @@ func (s *WebTerminalServer) handleServers(w http.ResponseWriter, r *http.Request } } -func (s *WebTerminalServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { - conn, err := s.upgrader.Upgrade(w, r, nil) +func (s *WebTerminalServer) handleScreenshotProxy(w http.ResponseWriter, r *http.Request) { + logger.Info("Screenshot proxy handler called", + "method", r.Method, + "path", r.URL.Path, + "query", r.URL.RawQuery) + + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" { + sessionID = r.Header.Get("X-Session-ID") + } + + if sessionID == "" { + logger.Warn("Screenshot proxy: missing session_id", "path", r.URL.Path) + http.Error(w, "Missing session_id", http.StatusBadRequest) + return + } + + port, ok := s.sessionManager.GetScreenshotPort(sessionID) + if !ok { + http.Error(w, "No screenshot port for session", http.StatusNotFound) + return + } + + targetURL := fmt.Sprintf("http://localhost:%d%s", port, r.URL.Path) + if r.URL.RawQuery != "" { + targetURL += "?" 
+ r.URL.RawQuery + } + + logger.Info("Proxying screenshot request", "session_id", sessionID, "target", targetURL) + + proxyReq, err := http.NewRequest(r.Method, targetURL, r.Body) if err != nil { - logger.Error("WebSocket upgrade failed", "error", err) + logger.Error("Failed to create proxy request", "error", err) + http.Error(w, "Proxy error", http.StatusInternalServerError) + return + } + + for name, values := range r.Header { + for _, value := range values { + proxyReq.Header.Add(name, value) + } + } + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(proxyReq) + if err != nil { + logger.Error("Proxy request failed", "error", err) + http.Error(w, "Failed to fetch screenshot", http.StatusBadGateway) return } defer func() { - if err := conn.Close(); err != nil { - logger.Warn("Failed to close WebSocket connection", "error", err) + if err := resp.Body.Close(); err != nil { + logger.Warn("Failed to close response body", "error", err) } }() - sessionID := uuid.New().String() - logger.Info("WebSocket connected", "remote", r.RemoteAddr, "session_id", sessionID) + for name, values := range resp.Header { + for _, value := range values { + w.Header().Add(name, value) + } + } - // Wait for initial connection message with server selection - cols, rows := 80, 24 - serverID := "local" // Default to local mode + w.WriteHeader(resp.StatusCode) + + if _, err := io.Copy(w, resp.Body); err != nil { + logger.Warn("Failed to copy response body", "error", err) + } +} + +// handleInitMessage reads and processes the initial WebSocket message +func (s *WebTerminalServer) handleInitMessage(conn *websocket.Conn, sessionID string) (cols, rows int, serverID string) { + cols, rows = 80, 24 + serverID = "local" if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { logger.Warn("Failed to set read deadline", "session_id", sessionID, "error", err) } + msgType, data, err := conn.ReadMessage() + if err := conn.SetReadDeadline(time.Time{}); err != 
nil { logger.Warn("Failed to clear read deadline", "session_id", sessionID, "error", err) } - if err == nil && msgType == websocket.TextMessage { - var msg struct { - Type string `json:"type"` - ServerID string `json:"server_id"` - Cols int `json:"cols"` - Rows int `json:"rows"` - } - if json.Unmarshal(data, &msg) == nil && msg.Type == "init" { - cols, rows = msg.Cols, msg.Rows - serverID = msg.ServerID - logger.Info("Session initialized", - "session_id", sessionID, - "server_id", serverID, - "cols", cols, - "rows", rows) - } - } else if err != nil { - logger.Warn("Failed to read init message, using defaults", - "session_id", sessionID, "error", err) + if err != nil { + logger.Warn("Failed to read init message, using defaults", "session_id", sessionID, "error", err) + return } + if msgType != websocket.TextMessage { + return + } + + var msg struct { + Type string `json:"type"` + ServerID string `json:"server_id"` + Cols int `json:"cols"` + Rows int `json:"rows"` + } + + if err := json.Unmarshal(data, &msg); err != nil { + return + } + + if msg.Type != "init" { + return + } + + cols, rows = msg.Cols, msg.Rows + serverID = msg.ServerID + + logger.Info("Session initialized", + "session_id", sessionID, + "server_id", serverID, + "cols", cols, + "rows", rows) + + initResp := map[string]any{ + "type": "init_response", + "session_id": sessionID, + } + + respData, err := json.Marshal(initResp) + if err != nil { + return + } + + if writeErr := conn.WriteMessage(websocket.TextMessage, respData); writeErr != nil { + logger.Warn("Failed to send init response", "error", writeErr) + } + + return +} + +func (s *WebTerminalServer) handleWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Error("WebSocket upgrade failed", "error", err) + return + } + defer func() { + if err := conn.Close(); err != nil { + logger.Warn("Failed to close WebSocket connection", "error", err) + } + }() + + sessionID := 
uuid.New().String() + logger.Info("WebSocket connected", "remote", r.RemoteAddr, "session_id", sessionID) + + cols, rows, serverID := s.handleInitMessage(conn, sessionID) + serverCfg, ok := s.findServerConfig(serverID, sessionID, conn) if !ok { return } - handler, err := CreateSessionHandler(&s.cfg.Web, serverCfg, s.cfg, s.viper) + handler, err := CreateSessionHandler(&s.cfg.Web, serverCfg, s.cfg, s.viper, sessionID, s.sessionManager) if err != nil { logger.Error("Failed to create session", "error", err, @@ -214,7 +319,11 @@ func (s *WebTerminalServer) handleWebSocket(w http.ResponseWriter, r *http.Reque } return } + + s.sessionManager.RegisterSession(sessionID, handler) + defer func() { + s.sessionManager.RemoveSession(sessionID) if closeErr := handler.Close(); closeErr != nil { logger.Warn("Failed to close session handler", "session_id", sessionID, "error", closeErr) } @@ -231,7 +340,6 @@ func (s *WebTerminalServer) handleWebSocket(w http.ResponseWriter, r *http.Reque logger.Info("Session started", "session_id", sessionID, "server_id", serverID) - // Handle I/O if err := handler.HandleConnection(conn); err != nil { logger.Error("Connection error", "session_id", sessionID, "error", err) } diff --git a/internal/web/session_manager.go b/internal/web/session_manager.go index c652b562..c414cb3f 100644 --- a/internal/web/session_manager.go +++ b/internal/web/session_manager.go @@ -22,9 +22,10 @@ type SessionManager struct { // SessionEntry tracks a session and its activity type SessionEntry struct { - session Session - lastActive time.Time - mu sync.Mutex + session Session + lastActive time.Time + screenshotPort int // Local forwarded port for screenshot streaming + mu sync.Mutex } func NewSessionManager(cfg *config.Config, v *viper.Viper) *SessionManager { @@ -140,6 +141,70 @@ func (sm *SessionManager) ActiveSessionCount() int { return len(sm.sessions) } +// RegisterSession registers an existing session with the manager +func (sm *SessionManager) 
RegisterSession(sessionID string, session Session) { + sm.mu.Lock() + defer sm.mu.Unlock() + + entry := &SessionEntry{ + session: session, + lastActive: time.Now(), + } + + sm.sessions[sessionID] = entry + logger.Info("Session registered", "id", sessionID, "total", len(sm.sessions)) +} + +// SetScreenshotPort sets the local screenshot port for a session +func (sm *SessionManager) SetScreenshotPort(sessionID string, port int) { + sm.mu.RLock() + entry, exists := sm.sessions[sessionID] + sm.mu.RUnlock() + + if exists { + entry.mu.Lock() + entry.screenshotPort = port + entry.mu.Unlock() + logger.Info("Screenshot port set for session", + "session_id", sessionID, + "port", port) + } else { + logger.Warn("Cannot set screenshot port: session not found", + "session_id", sessionID, + "port", port) + } +} + +// GetScreenshotPort retrieves the local screenshot port for a session +func (sm *SessionManager) GetScreenshotPort(sessionID string) (int, bool) { + sm.mu.RLock() + entry, exists := sm.sessions[sessionID] + sm.mu.RUnlock() + + if !exists { + logger.Warn("Cannot get screenshot port: session not found", + "session_id", sessionID, + "total_sessions", len(sm.sessions)) + return 0, false + } + + entry.mu.Lock() + port := entry.screenshotPort + entry.mu.Unlock() + + if port == 0 { + logger.Warn("Screenshot port not set for session", + "session_id", sessionID) + return 0, false + } + + logger.Info("Retrieved screenshot port for session", + "session_id", sessionID, + "port", port) + + return port, true +} + // Shutdown stops all sessions and the cleanup goroutine func (sm *SessionManager) Shutdown() { close(sm.done) diff --git a/internal/web/ssh_session.go b/internal/web/ssh_session.go index 257042b7..f6b7f356 100644 --- a/internal/web/ssh_session.go +++ b/internal/web/ssh_session.go @@ -1,10 +1,13 @@ package web import ( + "bytes" "context" "encoding/json" "fmt" "io" + "net" + "strconv" "strings" "sync" @@ -17,21 +20,30 @@ import ( // SSHSession wraps an SSH session with PTY 
for remote terminal access type SSHSession struct { - sshClient *SSHClient - server *config.SSHServerConfig - gatewayURL string - session *ssh.Session - stdin io.WriteCloser - stdout io.Reader - stderr io.Reader - mu sync.Mutex - running bool - ctx context.Context - cancel context.CancelFunc + sshClient *SSHClient + server *config.SSHServerConfig + gatewayURL string + session *ssh.Session + stdin io.WriteCloser + stdout io.Reader + stderr io.Reader + mu sync.Mutex + running bool + ctx context.Context + cancel context.CancelFunc + ws *websocket.Conn + tunnelListener net.Listener + tunnelCtx context.Context + tunnelCancel context.CancelFunc + tunnelWg sync.WaitGroup // Track active forwarding goroutines + screenshotPort int + localScreenshotPort int + sessionID string + sessionManager *SessionManager } // NewSSHSession creates a new SSH session with PTY -func NewSSHSession(client *SSHClient, server *config.SSHServerConfig, gatewayURL string) (*SSHSession, error) { +func NewSSHSession(client *SSHClient, server *config.SSHServerConfig, gatewayURL string, sessionID string, sessionManager *SessionManager) (*SSHSession, error) { if client == nil { return nil, fmt.Errorf("SSH client is required") } @@ -42,11 +54,13 @@ func NewSSHSession(client *SSHClient, server *config.SSHServerConfig, gatewayURL ctx, cancel := context.WithCancel(context.Background()) return &SSHSession{ - sshClient: client, - server: server, - gatewayURL: gatewayURL, - ctx: ctx, - cancel: cancel, + sshClient: client, + server: server, + gatewayURL: gatewayURL, + ctx: ctx, + cancel: cancel, + sessionID: sessionID, + sessionManager: sessionManager, }, nil } @@ -111,7 +125,13 @@ func (s *SSHSession) Start(cols, rows int) error { } cmdArgs := append([]string{"chat"}, s.server.CommandArgs...) 
- cmd := fmt.Sprintf("INFER_GATEWAY_URL=%s INFER_GATEWAY_MODE=remote %s %s", + + // Source /etc/environment to pick up docker-compose environment variables, then run infer + // Redirect stderr to /dev/null to suppress X11 library warnings + cmd := fmt.Sprintf( + "sh -c 'set -a; test -f /etc/environment && . /etc/environment; set +a; "+ + "INFER_GATEWAY_URL=%s INFER_GATEWAY_MODE=remote "+ + "%s %s 2>/dev/null'", s.gatewayURL, commandPath, strings.Join(cmdArgs, " ")) logger.Info("Starting remote command", @@ -165,6 +185,11 @@ func (s *SSHSession) Resize(cols, rows int) error { // HandleConnection bridges WebSocket and SSH session I/O func (s *SSHSession) HandleConnection(conn *websocket.Conn) error { + // Store WebSocket connection for later use + s.mu.Lock() + s.ws = conn + s.mu.Unlock() + var wg sync.WaitGroup errChan := make(chan error, 2) @@ -279,6 +304,8 @@ func (s *SSHSession) handleSSHOutput(conn *websocket.Conn) error { } if n > 0 { + s.handleScreenshotPortEscape(buf[:n]) + if err := conn.WriteMessage(websocket.BinaryMessage, buf[:n]); err != nil { return fmt.Errorf("failed to write to websocket: %w", err) } @@ -286,6 +313,189 @@ func (s *SSHSession) handleSSHOutput(conn *websocket.Conn) error { } } +// handleScreenshotPortEscape checks for and handles screenshot port escape sequences +func (s *SSHSession) handleScreenshotPortEscape(data []byte) { + port, found := extractPortFromEscape(data) + if !found { + return + } + + logger.Info("Detected screenshot port from remote CLI", + "port", port, + "server", s.server.Name) + + localPort, err := s.SetupPortForwarding(port) + if err != nil { + logger.Error("Failed to set up port forwarding", "error", err) + return + } + + s.notifyWebUI(localPort) +} + +// SetupPortForwarding sets up SSH port forwarding from local to remote port +func (s *SSHSession) SetupPortForwarding(remotePort int) (int, error) { + s.mu.Lock() + defer s.mu.Unlock() + + listener, err := net.Listen("tcp", "localhost:0") + if err != nil { + return 
0, fmt.Errorf("failed to listen on local port: %w", err) + } + + localPort := listener.Addr().(*net.TCPAddr).Port + s.tunnelListener = listener + s.screenshotPort = remotePort + s.localScreenshotPort = localPort + + // Create context for tunnel management + s.tunnelCtx, s.tunnelCancel = context.WithCancel(context.Background()) + + logger.Info("Setting up port forwarding", + "local_port", localPort, + "remote_port", remotePort, + "server", s.server.Name, + "session_id", s.sessionID) + + if s.sessionManager != nil { + s.sessionManager.SetScreenshotPort(s.sessionID, localPort) + logger.Info("Registered screenshot port with session manager", + "session_id", s.sessionID, + "local_port", localPort) + } + + go s.forwardConnections(listener, remotePort) + + return localPort, nil +} + +// forwardConnections handles port forwarding for incoming connections +func (s *SSHSession) forwardConnections(listener net.Listener, remotePort int) { + defer func() { + if err := listener.Close(); err != nil { + logger.Warn("Failed to close listener", "error", err) + } + }() + + for { + select { + case <-s.tunnelCtx.Done(): + logger.Info("Port forwarding stopped", "server", s.server.Name) + return + default: + } + + localConn, err := listener.Accept() + if err != nil { + select { + case <-s.tunnelCtx.Done(): + return + default: + logger.Error("Failed to accept local connection", "error", err) + continue + } + } + + s.tunnelWg.Add(1) + go s.handleForwardedConnection(localConn, remotePort) + } +} + +// handleForwardedConnection forwards a single connection through SSH +func (s *SSHSession) handleForwardedConnection(localConn net.Conn, remotePort int) { + defer s.tunnelWg.Done() + defer func() { + if err := localConn.Close(); err != nil { + logger.Debug("Failed to close local connection", "error", err) + } + }() + + remoteAddr := fmt.Sprintf("localhost:%d", remotePort) + remoteConn, err := s.sshClient.client.Dial("tcp", remoteAddr) + if err != nil { + logger.Error("Failed to dial remote 
address through SSH", + "remote_addr", remoteAddr, + "error", err) + return + } + defer func() { + if err := remoteConn.Close(); err != nil { + logger.Debug("Failed to close remote connection", "error", err) + } + }() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + if _, err := io.Copy(remoteConn, localConn); err != nil { + logger.Debug("Copy from local to remote ended", "error", err) + } + }() + + go func() { + defer wg.Done() + if _, err := io.Copy(localConn, remoteConn); err != nil { + logger.Debug("Copy from remote to local ended", "error", err) + } + }() + + wg.Wait() +} + +// notifyWebUI sends a message to the WebSocket with the local screenshot port +func (s *SSHSession) notifyWebUI(localPort int) { + if s.ws == nil { + logger.Warn("WebSocket connection not available for notification") + return + } + + msg := map[string]interface{}{ + "type": "screenshot_port", + "port": localPort, + } + + data, err := json.Marshal(msg) + if err != nil { + logger.Error("Failed to marshal screenshot port message", "error", err) + return + } + + if err := s.ws.WriteMessage(websocket.TextMessage, data); err != nil { + logger.Error("Failed to send screenshot port to WebSocket", "error", err) + return + } + + logger.Info("Notified WebUI of screenshot port", "local_port", localPort) +} + +// extractPortFromEscape extracts port number from escape sequence +func extractPortFromEscape(data []byte) (int, bool) { + startSeq := []byte("\x1b]5555;screenshot_port=") + endSeq := []byte("\x07") + + startIdx := bytes.Index(data, startSeq) + if startIdx == -1 { + return 0, false + } + + searchStart := startIdx + len(startSeq) + endIdx := bytes.Index(data[searchStart:], endSeq) + if endIdx == -1 { + return 0, false + } + + portStr := string(data[searchStart : searchStart+endIdx]) + port, err := strconv.Atoi(portStr) + if err != nil { + logger.Warn("Failed to parse port from escape sequence", "port_str", portStr, "error", err) + return 0, false + } + + return port, 
true +} + // Close terminates the SSH session func (s *SSHSession) Close() error { s.mu.Lock() @@ -297,6 +507,22 @@ func (s *SSHSession) Close() error { var errors []error + if s.tunnelCancel != nil { + s.tunnelCancel() + s.tunnelCancel = nil + } + + if s.tunnelListener != nil { + if err := s.tunnelListener.Close(); err != nil { + errors = append(errors, fmt.Errorf("failed to close tunnel listener: %w", err)) + } + s.tunnelListener = nil + } + + logger.Info("Waiting for active port forwarding connections to close", "server", s.server.Name) + s.tunnelWg.Wait() + logger.Info("All port forwarding connections closed", "server", s.server.Name) + if s.session != nil { if err := s.session.Close(); err != nil { errors = append(errors, fmt.Errorf("failed to close session: %w", err)) diff --git a/internal/web/static/app.js b/internal/web/static/app.js index 70e988ab..c11f8ef2 100644 --- a/internal/web/static/app.js +++ b/internal/web/static/app.js @@ -8,6 +8,7 @@ class TerminalManager { this.newTabBtn = document.getElementById('new-tab-btn'); this.serverSelector = document.getElementById('server-selector'); this.welcomeMessage = document.getElementById('welcome-message'); + this.screenshotToggleBtn = document.getElementById('screenshot-toggle-btn'); this.servers = []; this.currentServerID = 'local'; @@ -16,6 +17,7 @@ class TerminalManager { this.serverSelector.addEventListener('change', (e) => { this.currentServerID = e.target.value; }); + this.screenshotToggleBtn.addEventListener('click', () => this.toggleScreenshot()); } async loadServers() { @@ -75,6 +77,39 @@ class TerminalManager { const newTab = this.tabs.get(tabId); if (newTab) { newTab.activate(); + this.updateScreenshotButton(newTab); + } + } + + updateScreenshotButton(tab) { + if (tab && tab.screenshotOverlay) { + // Show button for tabs with screenshot overlay + this.screenshotToggleBtn.classList.remove('hidden'); + // Update active state + if (tab.screenshotOverlay.enabled) { + 
this.screenshotToggleBtn.classList.add('active'); + } else { + this.screenshotToggleBtn.classList.remove('active'); + } + } else { + // Hide button for tabs without screenshot overlay + this.screenshotToggleBtn.classList.add('hidden'); + this.screenshotToggleBtn.classList.remove('active'); + } + } + + toggleScreenshot() { + if (this.activeTabId === null) return; + + const activeTab = this.tabs.get(this.activeTabId); + if (activeTab && activeTab.screenshotOverlay) { + activeTab.screenshotOverlay.toggle(); + // Update button state + if (activeTab.screenshotOverlay.enabled) { + this.screenshotToggleBtn.classList.add('active'); + } else { + this.screenshotToggleBtn.classList.remove('active'); + } } } @@ -108,9 +143,11 @@ class TerminalTab { this.tabElement = null; this.containerElement = null; this.connected = false; + this.screenshotOverlay = null; this.createUI(); this.createTerminal(); + this.createScreenshotOverlay(); this.connect(); } @@ -188,6 +225,14 @@ class TerminalTab { }); } + createScreenshotOverlay() { + // Only create overlay for remote SSH sessions + if (this.serverID !== 'local' && typeof ScreenshotOverlay !== 'undefined') { + this.screenshotOverlay = new ScreenshotOverlay(this); + console.log(`Tab ${this.id}: Screenshot overlay created for ${this.serverID}`); + } + } + connect() { const protocol = window.location.protocol === 'https:' ? 
'wss:' : 'ws:'; const wsUrl = `${protocol}//${window.location.host}/ws`; @@ -219,6 +264,23 @@ class TerminalTab { this.term.write(new Uint8Array(buffer)); }); } else { + try { + const msg = JSON.parse(event.data); + if (msg.type === 'init_response') { + this.sessionID = msg.session_id; + console.log(`Tab ${this.id}: Session initialized with ID: ${this.sessionID}`); + return; + } + if (msg.type === 'screenshot_port') { + console.log(`Tab ${this.id}: Received screenshot port: ${msg.port}, session: ${this.sessionID}`); + if (this.screenshotOverlay && this.sessionID) { + this.screenshotOverlay.startPolling(this.sessionID); + } + return; + } + } catch (e) { + // Not JSON, treat as terminal data + } this.term.write(event.data); } }; @@ -261,6 +323,10 @@ class TerminalTab { } destroy() { + if (this.screenshotOverlay) { + this.screenshotOverlay.destroy(); + this.screenshotOverlay = null; + } if (this.socket) { this.socket.close(); } diff --git a/internal/web/static/screenshot-overlay.js b/internal/web/static/screenshot-overlay.js new file mode 100644 index 00000000..c5729095 --- /dev/null +++ b/internal/web/static/screenshot-overlay.js @@ -0,0 +1,227 @@ +/** + * ScreenshotOverlay manages the screenshot streaming overlay UI + */ +class ScreenshotOverlay { + constructor(terminalTab) { + this.terminalTab = terminalTab; + this.enabled = false; + this.screenshotPort = null; + this.pollInterval = null; + this.pollingFrequency = 2000; // 2 seconds + this.overlayElement = null; + this.imageElement = null; + this.timestampElement = null; + this.dimensionsElement = null; + this.statusElement = null; + this.errorElement = null; + + this.createUI(); + } + + /** + * Creates the overlay DOM structure + */ + createUI() { + // Create overlay container + this.overlayElement = document.createElement('div'); + this.overlayElement.className = 'screenshot-overlay hidden'; + this.overlayElement.innerHTML = ` +
+

Live Screenshot

+ +
+
+
Connecting...
+ + Remote screenshot +
+
+
+
+
+ `; + + // Cache DOM elements + this.imageElement = this.overlayElement.querySelector('.screenshot-image'); + this.timestampElement = this.overlayElement.querySelector('.screenshot-timestamp'); + this.dimensionsElement = this.overlayElement.querySelector('.screenshot-dimensions'); + this.statusElement = this.overlayElement.querySelector('.screenshot-status'); + this.errorElement = this.overlayElement.querySelector('.screenshot-error'); + + // Bind close button + const closeBtn = this.overlayElement.querySelector('.screenshot-close'); + closeBtn.addEventListener('click', () => this.hide()); + + // Append to terminal container + if (this.terminalTab && this.terminalTab.containerElement) { + this.terminalTab.containerElement.appendChild(this.overlayElement); + } else { + document.body.appendChild(this.overlayElement); + } + } + + /** + * Starts screenshot polling with the given session ID + * @param {string} sessionID - Session ID for this terminal tab + */ + startPolling(sessionID) { + if (this.pollInterval) { + clearInterval(this.pollInterval); + } + + this.sessionID = sessionID; + console.log(`Screenshot overlay: starting polling for session ${sessionID}`); + + // Show status + this.updateStatus('Loading screenshots...'); + this.hideError(); + + // Fetch immediately + this.fetchLatest(); + + // Start polling + this.pollInterval = setInterval(() => { + this.fetchLatest(); + }, this.pollingFrequency); + } + + /** + * Stops screenshot polling + */ + stopPolling() { + if (this.pollInterval) { + clearInterval(this.pollInterval); + this.pollInterval = null; + } + this.updateStatus('Disconnected'); + } + + /** + * Fetches the latest screenshot from the API + */ + async fetchLatest() { + if (!this.sessionID) { + this.showError('Session ID not configured'); + return; + } + + try { + const url = `/api/screenshots/latest?session_id=${encodeURIComponent(this.sessionID)}`; + const response = await fetch(url); + + if (!response.ok) { + throw new Error(`HTTP ${response.status}: 
${response.statusText}`); + } + + const screenshot = await response.json(); + this.updateDisplay(screenshot); + this.hideError(); + this.hideStatus(); + + } catch (error) { + console.error('Failed to fetch screenshot:', error); + this.showError(`Failed to fetch screenshot: ${error.message}`); + this.updateStatus('Connection error'); + } + } + + /** + * Updates the overlay display with screenshot data + * @param {Object} screenshot - Screenshot data with id, timestamp, data, width, height, format, method + */ + updateDisplay(screenshot) { + if (!screenshot || !screenshot.data) { + this.showError('Invalid screenshot data'); + return; + } + + // Update image + this.imageElement.src = `data:image/${screenshot.format};base64,${screenshot.data}`; + this.imageElement.style.display = 'block'; + + // Update timestamp + const timestamp = new Date(screenshot.timestamp); + this.timestampElement.textContent = timestamp.toLocaleTimeString(); + + // Update dimensions + this.dimensionsElement.textContent = `${screenshot.width}×${screenshot.height}`; + + // Update status + this.updateStatus(`Live (${screenshot.method})`); + } + + /** + * Shows the overlay + */ + show() { + this.enabled = true; + this.overlayElement.classList.remove('hidden'); + + // Resume polling if port is configured + if (this.screenshotPort && !this.pollInterval) { + this.startPolling(this.screenshotPort); + } + } + + /** + * Hides the overlay + */ + hide() { + this.enabled = false; + this.overlayElement.classList.add('hidden'); + this.stopPolling(); + } + + /** + * Toggles the overlay visibility + */ + toggle() { + if (this.enabled) { + this.hide(); + } else { + this.show(); + } + } + + /** + * Updates the status message + * @param {string} message - Status message + */ + updateStatus(message) { + this.statusElement.textContent = message; + this.statusElement.classList.remove('hidden'); + } + + /** + * Hides the status message + */ + hideStatus() { + this.statusElement.classList.add('hidden'); + } + + /** + * 
Shows an error message + * @param {string} message - Error message + */ + showError(message) { + this.errorElement.textContent = message; + this.errorElement.classList.remove('hidden'); + } + + /** + * Hides the error message + */ + hideError() { + this.errorElement.classList.add('hidden'); + } + + /** + * Cleans up resources + */ + destroy() { + this.stopPolling(); + if (this.overlayElement && this.overlayElement.parentNode) { + this.overlayElement.parentNode.removeChild(this.overlayElement); + } + } +} diff --git a/internal/web/templates/index.html b/internal/web/templates/index.html index 0a4bec39..aebb0421 100644 --- a/internal/web/templates/index.html +++ b/internal/web/templates/index.html @@ -115,6 +115,34 @@ #new-tab-btn:hover { background: #414868; } + #screenshot-toggle-btn { + display: flex; + align-items: center; + justify-content: center; + padding: 4px 12px; + margin-left: auto; + margin-right: 8px; + background: #24283b; + border: 1px solid #414868; + border-radius: 4px; + color: #a9b1d6; + font-size: 13px; + cursor: pointer; + transition: all 0.15s; + gap: 6px; + } + #screenshot-toggle-btn:hover { + background: #414868; + border-color: #7aa2f7; + } + #screenshot-toggle-btn.active { + background: #7aa2f7; + color: #1a1b26; + border-color: #7aa2f7; + } + #screenshot-toggle-btn.hidden { + display: none; + } #terminal-area { flex: 1; position: relative; @@ -190,6 +218,106 @@ font-family: 'Menlo', 'Monaco', 'Courier New', monospace; font-size: 13px; } + /* Screenshot overlay styles */ + .screenshot-overlay { + position: absolute; + top: 20px; + right: 20px; + width: 400px; + max-height: 80vh; + background: rgba(22, 22, 30, 0.98); + border: 1px solid #414868; + border-radius: 8px; + box-shadow: 0 8px 32px rgba(0, 0, 0, 0.5); + z-index: 1000; + display: flex; + flex-direction: column; + overflow: hidden; + backdrop-filter: blur(10px); + } + .screenshot-overlay.hidden { + display: none; + } + .screenshot-header { + display: flex; + align-items: center; + 
justify-content: space-between; + padding: 12px 16px; + background: #1a1b26; + border-bottom: 1px solid #414868; + } + .screenshot-header h3 { + font-size: 14px; + font-weight: 600; + color: #7aa2f7; + margin: 0; + } + .screenshot-close { + background: transparent; + border: none; + color: #a9b1d6; + font-size: 24px; + line-height: 1; + cursor: pointer; + padding: 0; + width: 24px; + height: 24px; + display: flex; + align-items: center; + justify-content: center; + border-radius: 4px; + transition: all 0.15s; + } + .screenshot-close:hover { + background: #f7768e; + color: #1a1b26; + } + .screenshot-content { + padding: 16px; + overflow-y: auto; + flex: 1; + } + .screenshot-status { + font-size: 12px; + color: #9ece6a; + margin-bottom: 12px; + padding: 6px 10px; + background: rgba(158, 206, 106, 0.1); + border-radius: 4px; + text-align: center; + } + .screenshot-error { + font-size: 12px; + color: #f7768e; + margin-bottom: 12px; + padding: 6px 10px; + background: rgba(247, 118, 142, 0.1); + border-radius: 4px; + text-align: center; + } + .screenshot-error.hidden { + display: none; + } + .screenshot-image { + width: 100%; + height: auto; + border-radius: 4px; + border: 1px solid #414868; + display: block; + margin-bottom: 12px; + } + .screenshot-info { + display: flex; + justify-content: space-between; + font-size: 11px; + color: #565f89; + padding-top: 8px; + border-top: 1px solid #414868; + } + .screenshot-timestamp, + .screenshot-dimensions { + font-family: 'Menlo', 'Monaco', 'Courier New', monospace; + } @@ -201,6 +329,9 @@ +
@@ -222,6 +353,7 @@

Getting Started

+ From 841bb09444c59987cbcb4390c2cecf6bc6e62452 Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sat, 3 Jan 2026 12:16:29 +0200 Subject: [PATCH 02/14] refactor(web): Rename screenshot overlay to preview MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename screenshot streaming UI components to use "Preview" terminology: - Rename screenshot-overlay.js to preview-overlay.js - Remove emoji from button (📷 Screenshots → Preview) - Update overlay title from "Live Screenshot" to "Live Preview" - Update user-facing messages to use "Preview" instead of "Screenshot" This improves clarity and consistency in the web UI while keeping internal implementation details (CSS classes, API endpoints) unchanged. Signed-off-by: Eden Reich --- .../{screenshot-overlay.js => preview-overlay.js} | 10 +++++----- internal/web/templates/index.html | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) rename internal/web/static/{screenshot-overlay.js => preview-overlay.js} (95%) diff --git a/internal/web/static/screenshot-overlay.js b/internal/web/static/preview-overlay.js similarity index 95% rename from internal/web/static/screenshot-overlay.js rename to internal/web/static/preview-overlay.js index c5729095..4c4c41d2 100644 --- a/internal/web/static/screenshot-overlay.js +++ b/internal/web/static/preview-overlay.js @@ -27,13 +27,13 @@ class ScreenshotOverlay { this.overlayElement.className = 'screenshot-overlay hidden'; this.overlayElement.innerHTML = `
-

Live Screenshot

+

Live Preview

Connecting...
- Remote screenshot + Remote preview
@@ -73,7 +73,7 @@ class ScreenshotOverlay { console.log(`Screenshot overlay: starting polling for session ${sessionID}`); // Show status - this.updateStatus('Loading screenshots...'); + this.updateStatus('Loading preview...'); this.hideError(); // Fetch immediately @@ -120,7 +120,7 @@ class ScreenshotOverlay { } catch (error) { console.error('Failed to fetch screenshot:', error); - this.showError(`Failed to fetch screenshot: ${error.message}`); + this.showError(`Failed to fetch preview: ${error.message}`); this.updateStatus('Connection error'); } } @@ -131,7 +131,7 @@ class ScreenshotOverlay { */ updateDisplay(screenshot) { if (!screenshot || !screenshot.data) { - this.showError('Invalid screenshot data'); + this.showError('Invalid preview data'); return; } diff --git a/internal/web/templates/index.html b/internal/web/templates/index.html index aebb0421..4a2c1344 100644 --- a/internal/web/templates/index.html +++ b/internal/web/templates/index.html @@ -329,8 +329,8 @@
-
@@ -353,7 +353,7 @@

Getting Started

- + From 5bd552050c8bcd350f4f12b7bd98e21e7e51215f Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sat, 3 Jan 2026 16:07:45 +0200 Subject: [PATCH 03/14] feat: Add typing delay and improve X11 keyboard support --- config/config.go | 2 + examples/computer-use/Dockerfile.ubuntu-gui | 80 +++++- examples/computer-use/docker-compose.yml | 2 - internal/services/agent.go | 17 +- .../services/tools/computer_use_wayland.go | 48 ++-- internal/services/tools/computer_use_x11.go | 240 ++++++++++++++++-- internal/services/tools/keyboard_type.go | 4 +- internal/services/tools/keyboard_type_test.go | 233 +++++++++++++++++ internal/web/templates/index.html | 4 +- 9 files changed, 578 insertions(+), 52 deletions(-) create mode 100644 internal/services/tools/keyboard_type_test.go diff --git a/config/config.go b/config/config.go index 88575e87..0ed19d0f 100644 --- a/config/config.go +++ b/config/config.go @@ -289,6 +289,7 @@ type MouseClickToolConfig struct { type KeyboardTypeToolConfig struct { Enabled bool `yaml:"enabled" mapstructure:"enabled"` MaxTextLength int `yaml:"max_text_length" mapstructure:"max_text_length"` + TypingDelayMs int `yaml:"typing_delay_ms" mapstructure:"typing_delay_ms"` RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` } @@ -1055,6 +1056,7 @@ Write the AGENTS.md file to the project root when you have gathered enough infor KeyboardType: KeyboardTypeToolConfig{ Enabled: true, MaxTextLength: 1000, + TypingDelayMs: 200, RequireApproval: &[]bool{true}[0], }, RateLimit: RateLimitConfig{ diff --git a/examples/computer-use/Dockerfile.ubuntu-gui b/examples/computer-use/Dockerfile.ubuntu-gui index 3c73c139..5b7f5d6e 100644 --- a/examples/computer-use/Dockerfile.ubuntu-gui +++ b/examples/computer-use/Dockerfile.ubuntu-gui @@ -5,7 +5,7 @@ ENV DEBIAN_FRONTEND=noninteractive ENV TZ=UTC RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone -RUN apt-get update && apt-get install -y \ +RUN apt-get update 
&& apt-get install -y --no-install-recommends \ xfce4 \ xfce4-terminal \ xfce4-goodies \ @@ -14,7 +14,6 @@ RUN apt-get update && apt-get install -y \ x11-apps \ xdotool \ xterm \ - firefox \ mousepad \ thunar \ git \ @@ -28,6 +27,18 @@ RUN apt-get update && apt-get install -y \ sudo \ pm-utils \ openssh-server \ + xdg-utils \ + gnupg \ + ca-certificates \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +RUN curl -fsSL https://ftp-master.debian.org/keys/archive-key-12.asc | gpg --dearmor -o /usr/share/keyrings/debian-archive-keyring.gpg \ + && echo "deb [signed-by=/usr/share/keyrings/debian-archive-keyring.gpg] http://deb.debian.org/debian bookworm main" > /etc/apt/sources.list.d/debian.list \ + && echo "deb [signed-by=/usr/share/keyrings/debian-archive-keyring.gpg] http://deb.debian.org/debian bookworm-updates main" >> /etc/apt/sources.list.d/debian.list \ + && apt-get update \ + && apt-get install -y --no-install-recommends -t bookworm chromium chromium-common chromium-sandbox \ + && rm /etc/apt/sources.list.d/debian.list \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* @@ -36,6 +47,26 @@ RUN useradd -m -s /bin/bash ubuntu 2>/dev/null || true && \ usermod -aG sudo ubuntu && \ echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers +RUN printf '#!/bin/bash\nexec /usr/bin/chromium --no-sandbox --disable-dev-shm-usage --disable-gpu --disable-software-rasterizer "$@"\n' > /usr/local/bin/chromium && \ + chmod +x /usr/local/bin/chromium + +RUN update-alternatives --install /usr/bin/x-www-browser x-www-browser /usr/local/bin/chromium 100 && \ + update-alternatives --install /usr/bin/gnome-www-browser gnome-www-browser /usr/local/bin/chromium 100 && \ + update-alternatives --set x-www-browser /usr/local/bin/chromium && \ + update-alternatives --set gnome-www-browser /usr/local/bin/chromium + +RUN mkdir -p /usr/share/applications && \ + printf '[Desktop Entry]\nVersion=1.0\nName=Chromium Web Browser\nComment=Access the Internet\nExec=/usr/local/bin/chromium 
%%U\nTerminal=false\nType=Application\nIcon=chromium-browser\nCategories=Network;WebBrowser;\nMimeType=text/html;text/xml;application/xhtml+xml;x-scheme-handler/http;x-scheme-handler/https;\nStartupNotify=true\n' > /usr/share/applications/chromium.desktop && \ + chmod 644 /usr/share/applications/chromium.desktop + +RUN mkdir -p /usr/share/xfce4/helpers && \ + printf '[Desktop Entry]\nVersion=1.0\nType=X-XFCE-Helper\nX-XFCE-Category=WebBrowser\nX-XFCE-Commands=/usr/local/bin/chromium\nX-XFCE-CommandsWithParameter=/usr/local/bin/chromium %%s\nIcon=chromium-browser\nName=Chromium\nX-XFCE-Binaries=chromium\n' > /usr/share/xfce4/helpers/chromium.desktop && \ + chmod 644 /usr/share/xfce4/helpers/chromium.desktop + +RUN mkdir -p /etc/xdg && \ + printf '[Default Applications]\ntext/html=chromium.desktop\ntext/xml=chromium.desktop\napplication/xhtml+xml=chromium.desktop\nx-scheme-handler/http=chromium.desktop\nx-scheme-handler/https=chromium.desktop\nx-scheme-handler/about=chromium.desktop\nx-scheme-handler/unknown=chromium.desktop\n\n[Added Associations]\ntext/html=chromium.desktop\ntext/xml=chromium.desktop\napplication/xhtml+xml=chromium.desktop\nx-scheme-handler/http=chromium.desktop\nx-scheme-handler/https=chromium.desktop\n' > /etc/xdg/mimeapps.list && \ + printf '[Default Applications]\ntext/html=chromium.desktop\ntext/xml=chromium.desktop\napplication/xhtml+xml=chromium.desktop\nx-scheme-handler/http=chromium.desktop\nx-scheme-handler/https=chromium.desktop\nx-scheme-handler/about=chromium.desktop\nx-scheme-handler/unknown=chromium.desktop\n\n[Added Associations]\ntext/html=chromium.desktop\ntext/xml=chromium.desktop\napplication/xhtml+xml=chromium.desktop\nx-scheme-handler/http=chromium.desktop\nx-scheme-handler/https=chromium.desktop\n' > /etc/xdg/xfce-mimeapps.list + RUN mkdir -p /var/run/sshd && \ mkdir -p /home/ubuntu/.ssh && \ chmod 700 /home/ubuntu/.ssh && \ @@ -76,12 +107,57 @@ sleep 2 export DISPLAY=:1 +echo "Configuring default browser for ubuntu user..." 
+mkdir -p /home/ubuntu/.config/xfce4 \ + /home/ubuntu/.local/share/applications \ + /home/ubuntu/.local/share/xfce4/helpers + +# Create user-level mimeapps.list +cat > /home/ubuntu/.config/mimeapps.list << 'MIMEAPPS' +[Default Applications] +text/html=chromium.desktop +text/xml=chromium.desktop +application/xhtml+xml=chromium.desktop +x-scheme-handler/http=chromium.desktop +x-scheme-handler/https=chromium.desktop +x-scheme-handler/about=chromium.desktop +x-scheme-handler/unknown=chromium.desktop + +[Added Associations] +text/html=chromium.desktop +text/xml=chromium.desktop +application/xhtml+xml=chromium.desktop +x-scheme-handler/http=chromium.desktop +x-scheme-handler/https=chromium.desktop +MIMEAPPS + +# Create user-level XFCE helper file +cat > /home/ubuntu/.local/share/xfce4/helpers/chromium.desktop << 'HELPER' +[Desktop Entry] +Version=1.0 +Type=X-XFCE-Helper +X-XFCE-Category=WebBrowser +X-XFCE-Commands=/usr/local/bin/chromium +X-XFCE-CommandsWithParameter=/usr/local/bin/chromium %s +Icon=chromium-browser +Name=Chromium +X-XFCE-Binaries=chromium +HELPER + +# Create XFCE helpers.rc pointing to our chromium helper +cat > /home/ubuntu/.config/xfce4/helpers.rc << 'HELPERS' +WebBrowser=chromium +HELPERS + +chown -R ubuntu:ubuntu /home/ubuntu/.config /home/ubuntu/.local + echo "Starting XFCE desktop environment with D-Bus session..." runuser -u ubuntu -- env DISPLAY=:1 dbus-launch --exit-with-session startxfce4 & XFCE_PID=$! 
echo "✓ Headless X11 server running on display :1 (${SCREEN_RESOLUTION}x${SCREEN_DEPTH})" echo "✓ XFCE desktop started" +echo "✓ Chromium browser configured" echo "✓ Ready for computer use tools" while kill -0 $XVFB_PID 2>/dev/null && kill -0 $XFCE_PID 2>/dev/null; do diff --git a/examples/computer-use/docker-compose.yml b/examples/computer-use/docker-compose.yml index 6e1eb24e..e723a186 100644 --- a/examples/computer-use/docker-compose.yml +++ b/examples/computer-use/docker-compose.yml @@ -105,8 +105,6 @@ services: ANTHROPIC_API_KEY: ${ANTHROPIC_API_KEY:-} OPENAI_API_KEY: ${OPENAI_API_KEY:-} GOOGLE_API_KEY: ${GOOGLE_API_KEY:-} - ports: - - "8080:8080" networks: - computer-use-network restart: unless-stopped diff --git a/internal/services/agent.go b/internal/services/agent.go index 866ca938..16d4f9a4 100644 --- a/internal/services/agent.go +++ b/internal/services/agent.go @@ -444,7 +444,7 @@ func (s *AgentServiceImpl) RunWithStream(ctx context.Context, req *domain.AgentR time.Sleep(constants.AgentIterationDelay) } - requestCtx, requestCancel := context.WithCancel(ctx) + requestCtx, requestCancel := context.WithTimeout(ctx, time.Duration(s.timeoutSeconds)*time.Second) s.requestsMux.Lock() s.activeRequests[req.RequestID] = requestCancel @@ -514,6 +514,21 @@ func (s *AgentServiceImpl) RunWithStream(ctx context.Context, req *domain.AgentR var streamUsage *sdk.CompletionUsage ////// STREAM ITERATION START for event := range events { + select { + case <-requestCtx.Done(): + if requestCtx.Err() == context.DeadlineExceeded { + logger.Error("stream timeout", "error", requestCtx.Err()) + eventPublisher.chatEvents <- domain.ChatErrorEvent{ + RequestID: req.RequestID, + Timestamp: time.Now(), + Error: fmt.Errorf("stream timed out after %d seconds", s.timeoutSeconds), + } + return + } + return + default: + } + if event.Event == nil { logger.Error("event is nil") continue diff --git a/internal/services/tools/computer_use_wayland.go 
b/internal/services/tools/computer_use_wayland.go index c1e625f7..19208aa5 100644 --- a/internal/services/tools/computer_use_wayland.go +++ b/internal/services/tools/computer_use_wayland.go @@ -124,44 +124,56 @@ func (c *WaylandClient) ClickMouse(button string, clicks int) error { return nil } -// TypeText types the given text -func (c *WaylandClient) TypeText(text string) error { +// TypeText types the given text with a configurable delay between keystrokes (in milliseconds) +func (c *WaylandClient) TypeText(text string, delayMs int) error { if _, err := exec.LookPath("wtype"); err == nil { - return c.typeTextWithWtype(text) + return c.typeTextWithWtype(text, delayMs) } if _, err := exec.LookPath("ydotool"); err == nil { - return c.typeTextWithYdotool(text) + return c.typeTextWithYdotool(text, delayMs) } return fmt.Errorf("no text input tool available (install wtype or ydotool)") } // typeTextWithWtype types text using the wtype command -func (c *WaylandClient) typeTextWithWtype(text string) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() +func (c *WaylandClient) typeTextWithWtype(text string, delayMs int) error { + for _, char := range text { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - cmd := exec.CommandContext(ctx, "wtype", text) + cmd := exec.CommandContext(ctx, "wtype", string(char)) + output, err := cmd.CombinedOutput() + cancel() - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("wtype failed: %s", string(output)) + if err != nil { + return fmt.Errorf("wtype failed: %s", string(output)) + } + + if delayMs > 0 { + time.Sleep(time.Duration(delayMs) * time.Millisecond) + } } return nil } // typeTextWithYdotool types text using the ydotool command -func (c *WaylandClient) typeTextWithYdotool(text string) error { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() +func (c *WaylandClient) typeTextWithYdotool(text 
string, delayMs int) error { + for _, char := range text { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - cmd := exec.CommandContext(ctx, "ydotool", "type", text) + cmd := exec.CommandContext(ctx, "ydotool", "type", string(char)) + output, err := cmd.CombinedOutput() + cancel() - output, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ydotool type failed: %s", string(output)) + if err != nil { + return fmt.Errorf("ydotool failed: %s", string(output)) + } + + if delayMs > 0 { + time.Sleep(time.Duration(delayMs) * time.Millisecond) + } } return nil diff --git a/internal/services/tools/computer_use_x11.go b/internal/services/tools/computer_use_x11.go index 42fb6f34..ad4064c3 100644 --- a/internal/services/tools/computer_use_x11.go +++ b/internal/services/tools/computer_use_x11.go @@ -6,10 +6,14 @@ import ( "image" "image/png" "os" + "strings" + "time" xgb "github.com/BurntSushi/xgb" xproto "github.com/BurntSushi/xgb/xproto" + xtest "github.com/BurntSushi/xgb/xtest" xgbutil "github.com/BurntSushi/xgbutil" + keybind "github.com/BurntSushi/xgbutil/keybind" xgraphics "github.com/BurntSushi/xgbutil/xgraphics" logger "github.com/inference-gateway/cli/internal/logger" @@ -23,6 +27,24 @@ type X11Client struct { display string } +// Character mapping tables for X11 key names +var ( + shiftChars = map[rune]string{ + '!': "exclam", '@': "at", '#': "numbersign", '$': "dollar", + '%': "percent", '^': "asciicircum", '&': "ampersand", '*': "asterisk", + '(': "parenleft", ')': "parenright", '_': "underscore", '+': "plus", + '{': "braceleft", '}': "braceright", '|': "bar", ':': "colon", + '"': "quotedbl", '<': "less", '>': "greater", '?': "question", + '~': "asciitilde", + } + + punctuationChars = map[rune]string{ + '.': "period", ',': "comma", ';': "semicolon", '\'': "apostrophe", + '/': "slash", '\\': "backslash", '-': "minus", '=': "equal", + '[': "bracketleft", ']': "bracketright", '`': "grave", + } +) + // NewX11Client creates a new 
X11 client connection func NewX11Client(display string) (*X11Client, error) { if display == "" { @@ -47,6 +69,13 @@ func NewX11Client(display string) (*X11Client, error) { return nil, fmt.Errorf("failed to connect to X11 display %s: %w", display, err) } + if err := xtest.Init(xu.Conn()); err != nil { + logger.Error("Failed to initialize XTEST extension", "error", err) + return nil, fmt.Errorf("failed to initialize XTEST extension: %w", err) + } + + keybind.Initialize(xu) + logger.Debug("Successfully connected to X11 display", "display", display) return &X11Client{ @@ -141,36 +170,197 @@ func (c *X11Client) MoveMouse(x, y int) error { // ClickMouse performs a mouse click at the current cursor position func (c *X11Client) ClickMouse(button string, clicks int) error { - // Note: X11 mouse clicking requires the XTEST extension which is not - // fully implemented in the pure Go xgb library. - // For production use, consider using xdotool as a fallback or implementing - // XTEST extension support. 
+ root := c.screen.Root + + var buttonCode byte + switch button { + case "left": + buttonCode = 1 + case "middle": + buttonCode = 2 + case "right": + buttonCode = 3 + default: + return fmt.Errorf("invalid button: %s (must be 'left', 'middle', or 'right')", button) + } + + for i := 0; i < clicks; i++ { + cookie := xtest.FakeInputChecked(c.conn, xproto.ButtonPress, buttonCode, 0, root, 0, 0, 0) + if err := cookie.Check(); err != nil { + return fmt.Errorf("failed to send button press: %w", err) + } + time.Sleep(50 * time.Millisecond) + + cookie = xtest.FakeInputChecked(c.conn, xproto.ButtonRelease, buttonCode, 0, root, 0, 0, 0) + if err := cookie.Check(); err != nil { + return fmt.Errorf("failed to send button release: %w", err) + } + + if i < clicks-1 { + time.Sleep(100 * time.Millisecond) + } + } + + c.conn.Sync() + return nil +} + +// charToKeyInfo maps a character to its X11 key string and shift requirement +type charToKeyInfo struct { + keyStr string + needsShift bool +} + +// mapCharToKey converts a character to its X11 key name and shift requirement +func mapCharToKey(char rune) charToKeyInfo { + if char >= 'A' && char <= 'Z' { + return charToKeyInfo{ + keyStr: strings.ToLower(string(char)), + needsShift: true, + } + } + + if shiftChar, ok := shiftChars[char]; ok { + return charToKeyInfo{ + keyStr: shiftChar, + needsShift: true, + } + } + + if punctChar, ok := punctuationChars[char]; ok { + return charToKeyInfo{ + keyStr: punctChar, + needsShift: false, + } + } + + switch char { + case '\n': + return charToKeyInfo{keyStr: "Return", needsShift: false} + case '\t': + return charToKeyInfo{keyStr: "Tab", needsShift: false} + case ' ': + return charToKeyInfo{keyStr: "space", needsShift: false} + default: + return charToKeyInfo{keyStr: string(char), needsShift: false} + } +} + +// TypeText types the given text with a configurable delay between keystrokes (in milliseconds) +func (c *X11Client) TypeText(text string, delayMs int) error { + root := c.screen.Root + 
baseDelay := time.Duration(delayMs) * time.Millisecond + + for _, char := range text { + keyInfo := mapCharToKey(char) - return fmt.Errorf("X11 mouse clicking requires xdotool (install with: sudo apt install xdotool). Use Wayland with ydotool for native support, or we can add xdotool fallback") + keycodes := keybind.StrToKeycodes(c.xu, keyInfo.keyStr) + if len(keycodes) == 0 { + logger.Warn("No keycode found for character", "char", string(char), "keyStr", keyInfo.keyStr) + continue + } + + keycode := keycodes[0] + + if err := c.typeKeyWithShift(root, keycode, keyInfo.needsShift, baseDelay); err != nil { + return err + } + } + + c.conn.Sync() + return nil } -// TypeText types the given text by sending key events -func (c *X11Client) TypeText(text string) error { - // This is a simplified implementation - // A full implementation would need to: - // 1. Map characters to keycodes using the keyboard mapping - // 2. Handle modifier keys (Shift, Ctrl, etc.) - // 3. Send KeyPress and KeyRelease events for each character +// typeKeyWithShift types a single key, optionally with shift modifier +func (c *X11Client) typeKeyWithShift(root xproto.Window, keycode xproto.Keycode, needsShift bool, delay time.Duration) error { + if needsShift { + shiftKeycodes := keybind.StrToKeycodes(c.xu, "Shift_L") + if len(shiftKeycodes) > 0 { + _ = xtest.FakeInput(c.conn, xproto.KeyPress, byte(shiftKeycodes[0]), 0, root, 0, 0, 0) + time.Sleep(delay) + } + } + + _ = xtest.FakeInput(c.conn, xproto.KeyPress, byte(keycode), 0, root, 0, 0, 0) + time.Sleep(delay) - // For now, return an error indicating this needs proper keysym mapping - return fmt.Errorf("text typing via X11 requires keysym mapping (not yet implemented)") + _ = xtest.FakeInput(c.conn, xproto.KeyRelease, byte(keycode), 0, root, 0, 0, 0) + time.Sleep(delay) + + if needsShift { + shiftKeycodes := keybind.StrToKeycodes(c.xu, "Shift_L") + if len(shiftKeycodes) > 0 { + _ = xtest.FakeInput(c.conn, xproto.KeyRelease, byte(shiftKeycodes[0]), 
0, root, 0, 0, 0) + time.Sleep(delay) + } + } + + return nil } -// SendKeyCombo sends a key combination (e.g., "ctrl+c") +// SendKeyCombo sends a key combination (e.g., "ctrl+c", "super+l") func (c *X11Client) SendKeyCombo(combo string) error { - // This is a simplified implementation - // A full implementation would need to: - // 1. Parse the combo string to extract modifiers and key - // 2. Map key names to keycodes - // 3. Send modifier key presses - // 4. Send the main key press/release - // 5. Release modifier keys - - // For now, return an error indicating this needs proper implementation - return fmt.Errorf("key combinations via X11 require keysym mapping (not yet implemented)") + root := c.screen.Root + + combo = strings.ReplaceAll(combo, "-", "+") + parts := strings.Split(combo, "+") + + if len(parts) == 0 { + return fmt.Errorf("invalid key combination: %s", combo) + } + + modifiers := parts[:len(parts)-1] + mainKey := parts[len(parts)-1] + + modifierMap := map[string]string{ + "ctrl": "Control_L", + "control": "Control_L", + "alt": "Alt_L", + "shift": "Shift_L", + "super": "Super_L", + "meta": "Meta_L", + "win": "Super_L", + "cmd": "Super_L", + } + + var modKeycodes []xproto.Keycode + for _, mod := range modifiers { + modName := strings.ToLower(strings.TrimSpace(mod)) + xModName, ok := modifierMap[modName] + if !ok { + xModName = mod + } + + keycodes := keybind.StrToKeycodes(c.xu, xModName) + if len(keycodes) == 0 { + return fmt.Errorf("no keycode found for modifier: %s", mod) + } + modKeycodes = append(modKeycodes, keycodes[0]) + } + + mainKey = strings.TrimSpace(mainKey) + mainKeycodes := keybind.StrToKeycodes(c.xu, mainKey) + if len(mainKeycodes) == 0 { + return fmt.Errorf("no keycode found for key: %s", mainKey) + } + mainKeycode := mainKeycodes[0] + + for _, keycode := range modKeycodes { + _ = xtest.FakeInput(c.conn, xproto.KeyPress, byte(keycode), 0, root, 0, 0, 0) + time.Sleep(10 * time.Millisecond) + } + + _ = xtest.FakeInput(c.conn, 
xproto.KeyPress, byte(mainKeycode), 0, root, 0, 0, 0) + time.Sleep(50 * time.Millisecond) + + _ = xtest.FakeInput(c.conn, xproto.KeyRelease, byte(mainKeycode), 0, root, 0, 0, 0) + time.Sleep(10 * time.Millisecond) + + for i := len(modKeycodes) - 1; i >= 0; i-- { + _ = xtest.FakeInput(c.conn, xproto.KeyRelease, byte(modKeycodes[i]), 0, root, 0, 0, 0) + time.Sleep(10 * time.Millisecond) + } + + c.conn.Sync() + return nil } diff --git a/internal/services/tools/keyboard_type.go b/internal/services/tools/keyboard_type.go index de216552..93784c18 100644 --- a/internal/services/tools/keyboard_type.go +++ b/internal/services/tools/keyboard_type.go @@ -117,7 +117,7 @@ func (t *KeyboardTypeTool) Execute(ctx context.Context, args map[string]any) (*d defer client.Close() if hasText { - err = client.TypeText(text) + err = client.TypeText(text, t.config.ComputerUse.KeyboardType.TypingDelayMs) } else { err = client.SendKeyCombo(keyCombo) } @@ -146,7 +146,7 @@ func (t *KeyboardTypeTool) Execute(ctx context.Context, args map[string]any) (*d defer client.Close() if hasText { - err = client.TypeText(text) + err = client.TypeText(text, t.config.ComputerUse.KeyboardType.TypingDelayMs) } else { err = client.SendKeyCombo(keyCombo) } diff --git a/internal/services/tools/keyboard_type_test.go b/internal/services/tools/keyboard_type_test.go new file mode 100644 index 00000000..3e3f068b --- /dev/null +++ b/internal/services/tools/keyboard_type_test.go @@ -0,0 +1,233 @@ +package tools + +import ( + "testing" + "time" + + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" +) + +func TestKeyboardTypeTool_TypingDelay(t *testing.T) { + tests := []struct { + name string + text string + delayMs int + expectedMinMs int + skipExecution bool + }{ + { + name: "fast typing with short delay", + text: "hi", + delayMs: 50, + expectedMinMs: 100, + skipExecution: true, + }, + { + name: "slow typing with long delay", + text: "hello", + delayMs: 200, + 
expectedMinMs: 1000, + skipExecution: true, + }, + { + name: "zero delay should still work", + text: "test", + delayMs: 0, + expectedMinMs: 0, + skipExecution: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &config.Config{ + ComputerUse: config.ComputerUseConfig{ + Enabled: true, + KeyboardType: config.KeyboardTypeToolConfig{ + Enabled: true, + MaxTextLength: 1000, + TypingDelayMs: tt.delayMs, + }, + RateLimit: config.RateLimitConfig{ + Enabled: true, + MaxActionsPerMinute: 60, + WindowSeconds: 60, + }, + }, + } + + tool := NewKeyboardTypeTool(cfg, NewRateLimiter(cfg.ComputerUse.RateLimit)) + + if tool.config.ComputerUse.KeyboardType.TypingDelayMs != tt.delayMs { + t.Errorf("Expected delay %d ms, got %d ms", tt.delayMs, tool.config.ComputerUse.KeyboardType.TypingDelayMs) + } + }) + } +} + +func TestKeyboardTypeTool_ConfigDefault(t *testing.T) { + cfg := config.DefaultConfig() + + expectedDelay := 200 + actualDelay := cfg.ComputerUse.KeyboardType.TypingDelayMs + + if actualDelay != expectedDelay { + t.Errorf("Expected default typing delay %d ms, got %d ms", expectedDelay, actualDelay) + } +} + +func TestKeyboardTypeTool_Validation(t *testing.T) { + cfg := &config.Config{ + ComputerUse: config.ComputerUseConfig{ + Enabled: true, + KeyboardType: config.KeyboardTypeToolConfig{ + Enabled: true, + MaxTextLength: 100, + TypingDelayMs: 200, + }, + RateLimit: config.RateLimitConfig{ + Enabled: true, + MaxActionsPerMinute: 60, + WindowSeconds: 60, + }, + }, + } + + tool := NewKeyboardTypeTool(cfg, NewRateLimiter(cfg.ComputerUse.RateLimit)) + + tests := []struct { + name string + args map[string]any + wantErr bool + }{ + { + name: "valid text input", + args: map[string]any{ + "text": "hello world", + }, + wantErr: false, + }, + { + name: "text exceeds max length", + args: map[string]any{ + "text": string(make([]byte, 101)), + }, + wantErr: true, + }, + { + name: "empty text", + args: map[string]any{ + "text": "", + }, + wantErr: 
true, + }, + { + name: "valid key combo", + args: map[string]any{ + "key_combo": "ctrl+c", + }, + wantErr: false, + }, + { + name: "neither text nor key_combo", + args: map[string]any{}, + wantErr: true, + }, + { + name: "both text and key_combo", + args: map[string]any{ + "text": "hello", + "key_combo": "ctrl+c", + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tool.Validate(tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// Mock timing test to verify delay is applied (doesn't require X11) +func TestX11Client_TypeTextTiming(t *testing.T) { + tests := []struct { + name string + text string + delayMs int + minExpectedMs int + }{ + { + name: "short text with 100ms delay", + text: "abc", + delayMs: 100, + minExpectedMs: 300, + }, + { + name: "longer text with 50ms delay", + text: "hello", + delayMs: 50, + minExpectedMs: 250, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + charCount := len([]rune(tt.text)) + expectedMin := time.Duration(charCount*tt.delayMs) * time.Millisecond + + if expectedMin < time.Duration(tt.minExpectedMs)*time.Millisecond { + t.Errorf("Expected minimum %v ms, calculated %v ms", tt.minExpectedMs, expectedMin.Milliseconds()) + } + }) + } +} + +func TestKeyboardTypeTool_FormatResult(t *testing.T) { + cfg := &config.Config{ + ComputerUse: config.ComputerUseConfig{ + Enabled: true, + KeyboardType: config.KeyboardTypeToolConfig{ + Enabled: true, + MaxTextLength: 1000, + TypingDelayMs: 200, + }, + RateLimit: config.RateLimitConfig{ + Enabled: true, + MaxActionsPerMinute: 60, + WindowSeconds: 60, + }, + }, + } + + tool := NewKeyboardTypeTool(cfg, NewRateLimiter(cfg.ComputerUse.RateLimit)) + + result := &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: map[string]any{"text": "www.google.com"}, + Success: true, + Duration: time.Second, + Data: 
domain.KeyboardTypeToolResult{ + Text: "www.google.com", + Display: ":0", + Method: "x11", + }, + } + + formatted := tool.FormatForLLM(result) + expected := "Typed text: 'www.google.com' using x11" + if formatted != expected { + t.Errorf("Expected formatted result %q, got %q", expected, formatted) + } + + preview := tool.FormatPreview(result) + expectedPreview := "Typed: www.google.com" + if preview != expectedPreview { + t.Errorf("Expected preview %q, got %q", expectedPreview, preview) + } +} diff --git a/internal/web/templates/index.html b/internal/web/templates/index.html index 4a2c1344..1e49c851 100644 --- a/internal/web/templates/index.html +++ b/internal/web/templates/index.html @@ -223,8 +223,8 @@ position: absolute; top: 20px; right: 20px; - width: 400px; - max-height: 80vh; + width: 600px; + max-height: 85vh; background: rgba(22, 22, 30, 0.98); border: 1px solid #414868; border-radius: 8px; From 9c326499634e00de3609c62ee4eb9f0224dbb0d2 Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sat, 3 Jan 2026 19:01:25 +0200 Subject: [PATCH 04/14] refactor: Refactor computer use tools with display protocol abstraction Signed-off-by: Eden Reich --- .infer/config.yaml | 4 +- cmd/agents.go | 4 +- cmd/root.go | 1 + config/config.go | 6 +- cspell.yaml | 27 ++ examples/computer-use/.infer/config.yaml | 2 - internal/display/interface.go | 93 ++++ internal/display/macos/controller_darwin.go | 113 +++++ internal/display/macos/controller_stub.go | 82 ++++ internal/display/registry.go | 72 +++ .../wayland/client.go} | 9 +- internal/display/wayland/controller.go | 135 ++++++ .../x11/client.go} | 4 +- internal/display/x11/controller.go | 112 +++++ internal/domain/interfaces.go | 10 + internal/infra/storage/jsonl.go | 10 +- internal/services/agent.go | 2 +- internal/services/screenshot_server.go | 104 ++--- internal/services/tools/keyboard_type.go | 120 ++--- internal/services/tools/keyboard_type_test.go | 7 +- internal/services/tools/mouse_click.go | 130 +++--- 
internal/services/tools/mouse_move.go | 114 ++--- internal/services/tools/registry.go | 19 +- internal/services/tools/screenshot.go | 424 ------------------ .../ratelimiter.go} | 62 +-- internal/web/pty_manager.go | 2 +- internal/web/server.go | 2 +- internal/web/ssh_session.go | 8 +- internal/web/static/preview-overlay.js | 13 +- internal/web/templates/index.html | 40 ++ 30 files changed, 947 insertions(+), 784 deletions(-) create mode 100644 internal/display/interface.go create mode 100644 internal/display/macos/controller_darwin.go create mode 100644 internal/display/macos/controller_stub.go create mode 100644 internal/display/registry.go rename internal/{services/tools/computer_use_wayland.go => display/wayland/client.go} (95%) create mode 100644 internal/display/wayland/controller.go rename internal/{services/tools/computer_use_x11.go => display/x11/client.go} (98%) create mode 100644 internal/display/x11/controller.go delete mode 100644 internal/services/tools/screenshot.go rename internal/{services/tools/computer_use_common.go => utils/ratelimiter.go} (55%) diff --git a/.infer/config.yaml b/.infer/config.yaml index 090a283b..f09404a2 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -651,7 +651,8 @@ computer_use: streaming_enabled: false capture_interval: 3 buffer_size: 30 - temp_dir: /tmp/infer-screenshots + temp_dir: "" + log_captures: false mouse_move: enabled: true require_approval: true @@ -661,6 +662,7 @@ computer_use: keyboard_type: enabled: true max_text_length: 1000 + typing_delay_ms: 200 require_approval: true rate_limit: enabled: true diff --git a/cmd/agents.go b/cmd/agents.go index ecf2e0d3..d8c79098 100644 --- a/cmd/agents.go +++ b/cmd/agents.go @@ -242,7 +242,7 @@ type ExternalAgent struct { } // getConfig loads the configuration from viper -func getConfig(cmd *cobra.Command) (*config.Config, error) { +func getConfig(_ *cobra.Command) (*config.Config, error) { cfg, err := getConfigFromViper() if err != nil { return nil, 
fmt.Errorf("failed to load config: %w", err) @@ -425,7 +425,7 @@ func listAgents(cmd *cobra.Command, args []string) error { format, _ := cmd.Flags().GetString("format") if format == "json" { - combinedOutput := map[string]interface{}{ + combinedOutput := map[string]any{ "local": localAgents, "external": externalAgents, "total": totalAgents, diff --git a/cmd/root.go b/cmd/root.go index 11322a0a..9fbb706b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -103,6 +103,7 @@ func initConfig() { // nolint:funlen v.SetDefault("computer_use.screenshot.capture_interval", defaults.ComputerUse.Screenshot.CaptureInterval) v.SetDefault("computer_use.screenshot.buffer_size", defaults.ComputerUse.Screenshot.BufferSize) v.SetDefault("computer_use.screenshot.temp_dir", defaults.ComputerUse.Screenshot.TempDir) + v.SetDefault("computer_use.screenshot.log_captures", defaults.ComputerUse.Screenshot.LogCaptures) v.SetDefault("computer_use.mouse_move.enabled", defaults.ComputerUse.MouseMove.Enabled) v.SetDefault("computer_use.mouse_click.enabled", defaults.ComputerUse.MouseClick.Enabled) v.SetDefault("computer_use.keyboard_type.enabled", defaults.ComputerUse.KeyboardType.Enabled) diff --git a/config/config.go b/config/config.go index 0ed19d0f..bca0ed3c 100644 --- a/config/config.go +++ b/config/config.go @@ -271,6 +271,7 @@ type ScreenshotToolConfig struct { CaptureInterval int `yaml:"capture_interval" mapstructure:"capture_interval"` // seconds BufferSize int `yaml:"buffer_size" mapstructure:"buffer_size"` // number of screenshots TempDir string `yaml:"temp_dir" mapstructure:"temp_dir"` // path for disk storage + LogCaptures bool `yaml:"log_captures" mapstructure:"log_captures"` // log every capture (debug) } // MouseMoveToolConfig contains mouse move-specific tool settings @@ -1031,7 +1032,7 @@ Write the AGENTS.md file to the project root when you have gathered enough infor Servers: []SSHServerConfig{}, }, ComputerUse: ComputerUseConfig{ - Enabled: false, // Security: disabled by default + 
Enabled: false, Display: ":0", Screenshot: ScreenshotToolConfig{ Enabled: true, @@ -1043,7 +1044,8 @@ Write the AGENTS.md file to the project root when you have gathered enough infor StreamingEnabled: false, CaptureInterval: 3, BufferSize: 30, - TempDir: "/tmp/infer-screenshots", + TempDir: "", + LogCaptures: false, }, MouseMove: MouseMoveToolConfig{ Enabled: true, diff --git a/cspell.yaml b/cspell.yaml index 9c858d93..061b10f8 100644 --- a/cspell.yaml +++ b/cspell.yaml @@ -5,10 +5,16 @@ words: - alecthomas - alrvd - apapsch + - asciicircum + - asciitilde - aymanbagabas - aymerick - bahlo - bluemonday + - braceleft + - braceright + - bracketleft + - bracketright - bubbletea - buger - cancelreader @@ -21,6 +27,7 @@ words: - colorprofile - coninput - creack + - cyclop - davecgh - deepseek - DEEPSEEK @@ -30,9 +37,11 @@ words: - duckduckgo - easyjson - erikgeiser + - exclam - fsnotify - funlen - gjson + - gocognit - gocyclo - goldmark - gopkg @@ -43,9 +52,11 @@ words: - inconshreveable - invopop - isatty + - ISPEED - jsonmerge - jsonparser - jsonschema + - keybind - keygen - kimi - ledongthuc @@ -67,21 +78,28 @@ words: - myshortcuts - nolint - NOPASSWD + - numbersign - oapi - ollama - OLLAMA - oneline + - OSPEED + - parenleft + - parenright - pflag - pgdn + - pgrep - pgup - playai - pmezard + - quotedbl - qwen - resty - retryable - rivo - runewidth - sabhiram + - sagents - sagikazarmark - SHTTP - sjson @@ -93,6 +111,15 @@ words: - termenv - terminfo - tidwall + - tmpl - uniseg + - upgrader + - userspace + - Winsize + - xdraw + - xgbutil + - xgraphics + - ximg + - xtest - xxhash - yuin diff --git a/examples/computer-use/.infer/config.yaml b/examples/computer-use/.infer/config.yaml index d668c8b8..8374e34a 100644 --- a/examples/computer-use/.infer/config.yaml +++ b/examples/computer-use/.infer/config.yaml @@ -177,11 +177,9 @@ computer_use: format: "jpeg" quality: 80 require_approval: false - # Screenshot streaming for web UI streaming_enabled: true capture_interval: 3 
buffer_size: 30 - temp_dir: "/tmp/infer-screenshots" mouse_move: enabled: true require_approval: true diff --git a/internal/display/interface.go b/internal/display/interface.go new file mode 100644 index 00000000..575350ad --- /dev/null +++ b/internal/display/interface.go @@ -0,0 +1,93 @@ +package display + +import ( + "context" + "image" +) + +// DisplayController abstracts display server-specific operations (X11, Wayland, macOS Quartz) +type DisplayController interface { + // Screen operations + CaptureScreenBytes(ctx context.Context, region *Region) ([]byte, error) + CaptureScreen(ctx context.Context, region *Region) (image.Image, error) + GetScreenDimensions(ctx context.Context) (width, height int, err error) + + // Mouse operations + GetCursorPosition(ctx context.Context) (x, y int, err error) + MoveMouse(ctx context.Context, x, y int) error + ClickMouse(ctx context.Context, button MouseButton, clicks int) error + + // Keyboard operations + TypeText(ctx context.Context, text string, delayMs int) error + SendKeyCombo(ctx context.Context, combo string) error + + // Lifecycle + Close() error +} + +// Region represents a rectangular area on the screen +type Region struct { + X int + Y int + Width int + Height int +} + +// MouseButton represents a mouse button +type MouseButton int + +const ( + MouseButtonLeft MouseButton = iota + MouseButtonMiddle + MouseButtonRight +) + +// String returns the string representation of a mouse button +func (b MouseButton) String() string { + switch b { + case MouseButtonLeft: + return "left" + case MouseButtonMiddle: + return "middle" + case MouseButtonRight: + return "right" + default: + return "unknown" + } +} + +// ParseMouseButton parses a string into a MouseButton +func ParseMouseButton(s string) MouseButton { + switch s { + case "left": + return MouseButtonLeft + case "middle": + return MouseButtonMiddle + case "right": + return MouseButtonRight + default: + return MouseButtonLeft + } +} + +// Provider creates 
DisplayController instances for a specific display server/protocol +type Provider interface { + // GetController creates a new DisplayController for the specified display + GetController(display string) (DisplayController, error) + + // GetDisplayInfo returns information about the display server/protocol + GetDisplayInfo() DisplayInfo + + // IsAvailable returns true if this display server is available on the current system + IsAvailable() bool +} + +// DisplayInfo contains metadata about a display server or protocol +type DisplayInfo struct { + Name string // "x11", "wayland", "macos" + SupportsRegions bool + SupportsMouse bool + SupportsKeyboard bool + MaxTextLength int + RequiresElevation bool +} diff --git a/internal/display/macos/controller_darwin.go b/internal/display/macos/controller_darwin.go new file mode 100644 index 00000000..1391c0fd --- /dev/null +++ b/internal/display/macos/controller_darwin.go @@ -0,0 +1,113 @@ +//go:build darwin + +package macos + +import ( + "context" + "fmt" + "image" + "runtime" + + display "github.com/inference-gateway/cli/internal/display" +) + +// Controller implements display.DisplayController for macOS +// This is a placeholder for future CGO implementation using: +// - CGDisplayCreateImage for screenshots +// - CGEventPost for mouse/keyboard control +// - Accessibility API for permissions +type Controller struct{} + +var _ display.DisplayController = (*Controller)(nil) + +func (c *Controller) CaptureScreenBytes(ctx context.Context, region *display.Region) ([]byte, error) { + // TODO: Implement using CGDisplayCreateImage + CGO + // Sample code structure: + // /* + // #cgo LDFLAGS: -framework CoreGraphics -framework CoreFoundation + // #include + // CGImageRef CGDisplayCreateImage(CGDirectDisplayID displayID); + // */ + // import "C" + return nil, fmt.Errorf("macOS screenshot not yet implemented (requires CGO)") +} + +func (c *Controller) CaptureScreen(ctx context.Context, region *display.Region) (image.Image, error) { + 
return nil, fmt.Errorf("macOS screenshot not yet implemented (requires CGO)") +} + +func (c *Controller) GetScreenDimensions(ctx context.Context) (width, height int, err error) { + // TODO: Implement using CGDisplayBounds + return 0, 0, fmt.Errorf("macOS screen dimensions not yet implemented (requires CGO)") +} + +func (c *Controller) GetCursorPosition(ctx context.Context) (x, y int, err error) { + // TODO: Implement using CGEventGetLocation + return 0, 0, fmt.Errorf("macOS cursor position not yet implemented (requires CGO)") +} + +func (c *Controller) MoveMouse(ctx context.Context, x, y int) error { + // TODO: Implement using CGEventCreateMouseEvent + CGEventPost + return fmt.Errorf("macOS mouse move not yet implemented (requires CGO)") +} + +func (c *Controller) ClickMouse(ctx context.Context, button display.MouseButton, clicks int) error { + // TODO: Implement using CGEventCreateMouseEvent + CGEventPost + return fmt.Errorf("macOS mouse click not yet implemented (requires CGO)") +} + +func (c *Controller) TypeText(ctx context.Context, text string, delayMs int) error { + // TODO: Implement using CGEventCreateKeyboardEvent + CGEventPost + return fmt.Errorf("macOS keyboard type not yet implemented (requires CGO)") +} + +func (c *Controller) SendKeyCombo(ctx context.Context, combo string) error { + // TODO: Implement using CGEventCreateKeyboardEvent with modifiers + return fmt.Errorf("macOS key combo not yet implemented (requires CGO)") +} + +func (c *Controller) Close() error { + return nil +} + +// Provider implements the display.Provider interface for macOS +type Provider struct{} + +var _ display.Provider = (*Provider)(nil) + +func NewProvider() *Provider { + return &Provider{} +} + +func (p *Provider) GetController(display string) (display.DisplayController, error) { + // TODO: Check Accessibility permissions + // Sample code: + // AXIsProcessTrustedWithOptions() + return nil, fmt.Errorf("macOS provider not yet implemented (requires CGO)") +} + +func (p 
*Provider) GetDisplayInfo() display.DisplayInfo { + return display.DisplayInfo{ + Name: "macos", + SupportsRegions: true, + SupportsMouse: true, + SupportsKeyboard: true, + MaxTextLength: 0, + RequiresElevation: true, + } +} + +func (p *Provider) IsAvailable() bool { + // Only available on macOS (darwin) + // TODO: Also check Accessibility permissions when implemented + return runtime.GOOS == "darwin" +} + +// Register the macOS provider in the global registry (darwin only) +func init() { + // TODO: Uncomment when implementation is ready + // display.Register(NewProvider()) + + // For now, don't register to avoid false positives + // The stub implementation will prevent compilation errors +} diff --git a/internal/display/macos/controller_stub.go b/internal/display/macos/controller_stub.go new file mode 100644 index 00000000..0c262bb7 --- /dev/null +++ b/internal/display/macos/controller_stub.go @@ -0,0 +1,82 @@ +//go:build !darwin + +package macos + +import ( + "context" + "fmt" + "image" + + display "github.com/inference-gateway/cli/internal/display" +) + +// Controller is a stub implementation for non-macOS platforms +type Controller struct{} + +var _ display.DisplayController = (*Controller)(nil) + +func (c *Controller) CaptureScreenBytes(ctx context.Context, region *display.Region) ([]byte, error) { + return nil, fmt.Errorf("macOS platform not available on this system") +} + +func (c *Controller) CaptureScreen(ctx context.Context, region *display.Region) (image.Image, error) { + return nil, fmt.Errorf("macOS platform not available on this system") +} + +func (c *Controller) GetScreenDimensions(ctx context.Context) (width, height int, err error) { + return 0, 0, fmt.Errorf("macOS platform not available on this system") +} + +func (c *Controller) GetCursorPosition(ctx context.Context) (x, y int, err error) { + return 0, 0, fmt.Errorf("macOS platform not available on this system") +} + +func (c *Controller) MoveMouse(ctx context.Context, x, y int) error { + return 
fmt.Errorf("macOS platform not available on this system") +} + +func (c *Controller) ClickMouse(ctx context.Context, button display.MouseButton, clicks int) error { + return fmt.Errorf("macOS platform not available on this system") +} + +func (c *Controller) TypeText(ctx context.Context, text string, delayMs int) error { + return fmt.Errorf("macOS platform not available on this system") +} + +func (c *Controller) SendKeyCombo(ctx context.Context, combo string) error { + return fmt.Errorf("macOS platform not available on this system") +} + +func (c *Controller) Close() error { + return nil +} + +// Provider is a stub implementation for non-macOS platforms +type Provider struct{} + +var _ display.Provider = (*Provider)(nil) + +func NewProvider() *Provider { + return &Provider{} +} + +func (p *Provider) GetController(display string) (display.DisplayController, error) { + return nil, fmt.Errorf("macOS platform not available on this system") +} + +func (p *Provider) GetDisplayInfo() display.DisplayInfo { + return display.DisplayInfo{ + Name: "macos", + SupportsRegions: false, + SupportsMouse: false, + SupportsKeyboard: false, + MaxTextLength: 0, + RequiresElevation: false, + } +} + +func (p *Provider) IsAvailable() bool { + return false // Always false on non-macOS systems +} + +// No init() - don't register on non-macOS systems diff --git a/internal/display/registry.go b/internal/display/registry.go new file mode 100644 index 00000000..7173b007 --- /dev/null +++ b/internal/display/registry.go @@ -0,0 +1,72 @@ +package display + +import ( + "fmt" + "sync" +) + +// Registry manages display server providers and handles display detection +type Registry struct { + providers []Provider + mu sync.RWMutex +} + +var ( + globalRegistry = &Registry{ + providers: make([]Provider, 0), + } +) + +// Register adds a display server provider to the global registry +// This is typically called from init() functions in display-specific packages +func Register(provider Provider) { + 
globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + globalRegistry.providers = append(globalRegistry.providers, provider) +} + +// DetectDisplay returns the first available display server provider +// Priority is determined by registration order (first registered has highest priority) +func DetectDisplay() (Provider, error) { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + for _, p := range globalRegistry.providers { + if p.IsAvailable() { + return p, nil + } + } + + return nil, fmt.Errorf("no compatible display server detected (tried %d providers)", len(globalRegistry.providers)) +} + +// GetAllProviders returns all registered providers +func GetAllProviders() []Provider { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + providers := make([]Provider, len(globalRegistry.providers)) + copy(providers, globalRegistry.providers) + return providers +} + +// GetProvider returns a specific provider by display server name, or nil if not found +func GetProvider(displayName string) Provider { + globalRegistry.mu.RLock() + defer globalRegistry.mu.RUnlock() + + for _, p := range globalRegistry.providers { + if p.GetDisplayInfo().Name == displayName { + return p + } + } + + return nil +} + +// ClearProviders removes all registered providers (primarily for testing) +func ClearProviders() { + globalRegistry.mu.Lock() + defer globalRegistry.mu.Unlock() + globalRegistry.providers = make([]Provider, 0) +} diff --git a/internal/services/tools/computer_use_wayland.go b/internal/display/wayland/client.go similarity index 95% rename from internal/services/tools/computer_use_wayland.go rename to internal/display/wayland/client.go index 19208aa5..502b4ab1 100644 --- a/internal/services/tools/computer_use_wayland.go +++ b/internal/display/wayland/client.go @@ -1,4 +1,4 @@ -package tools +package wayland import ( "context" @@ -226,10 +226,6 @@ func (c *WaylandClient) sendKeyComboWithYdotool(combo string) error { ctx, cancel := 
context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // ydotool key combo format: key1:key2 - // Convert "ctrl+c" to "29:46" (keycodes) - // This is a simplified version - proper implementation would need keycode mapping - cmd := exec.CommandContext(ctx, "ydotool", "key", combo) output, err := cmd.CombinedOutput() @@ -242,9 +238,6 @@ func (c *WaylandClient) sendKeyComboWithYdotool(combo string) error { // GetScreenDimensions returns the screen width and height func (c *WaylandClient) GetScreenDimensions() (int, int, error) { - // Wayland doesn't have a simple command to get screen dimensions - // We can use wlr-randr if available, or return default values - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/internal/display/wayland/controller.go b/internal/display/wayland/controller.go new file mode 100644 index 00000000..fe61224c --- /dev/null +++ b/internal/display/wayland/controller.go @@ -0,0 +1,135 @@ +package wayland + +import ( + "bytes" + "context" + "fmt" + "image" + "image/png" + "os" + + display "github.com/inference-gateway/cli/internal/display" +) + +// Controller wraps the existing WaylandClient to implement the display.DisplayController interface +type Controller struct { + client *WaylandClient +} + +var _ display.DisplayController = (*Controller)(nil) + +// CaptureScreenBytes captures a screenshot and returns PNG bytes +func (c *Controller) CaptureScreenBytes(ctx context.Context, region *display.Region) ([]byte, error) { + if region == nil { + return c.client.CaptureScreenBytes(0, 0, 0, 0) + } + return c.client.CaptureScreenBytes(region.X, region.Y, region.Width, region.Height) +} + +// CaptureScreen captures a screenshot and returns an image.Image +func (c *Controller) CaptureScreen(ctx context.Context, region *display.Region) (image.Image, error) { + // WaylandClient only returns bytes, so we need to decode them + var imgBytes []byte + var err error + + if region == nil { + 
imgBytes, err = c.client.CaptureScreenBytes(0, 0, 0, 0) + } else { + imgBytes, err = c.client.CaptureScreenBytes(region.X, region.Y, region.Width, region.Height) + } + + if err != nil { + return nil, err + } + + // Decode PNG bytes to image.Image + img, err := png.Decode(bytes.NewReader(imgBytes)) + if err != nil { + return nil, fmt.Errorf("failed to decode screenshot: %w", err) + } + + return img, nil +} + +// GetScreenDimensions returns the screen width and height +func (c *Controller) GetScreenDimensions(ctx context.Context) (width, height int, err error) { + return c.client.GetScreenDimensions() +} + +// GetCursorPosition returns the current cursor position +// Note: Wayland doesn't provide a standard way to get cursor position +func (c *Controller) GetCursorPosition(ctx context.Context) (x, y int, err error) { + // Wayland doesn't expose cursor position for security reasons + // Return an error indicating this is not supported + return 0, 0, fmt.Errorf("getting cursor position is not supported on Wayland") +} + +// MoveMouse moves the cursor to the specified coordinates +func (c *Controller) MoveMouse(ctx context.Context, x, y int) error { + return c.client.MoveMouse(x, y) +} + +// ClickMouse clicks the specified mouse button +func (c *Controller) ClickMouse(ctx context.Context, button display.MouseButton, clicks int) error { + return c.client.ClickMouse(button.String(), clicks) +} + +// TypeText types the given text with the specified delay between keystrokes +func (c *Controller) TypeText(ctx context.Context, text string, delayMs int) error { + return c.client.TypeText(text, delayMs) +} + +// SendKeyCombo sends a key combination (e.g., "ctrl+c") +func (c *Controller) SendKeyCombo(ctx context.Context, combo string) error { + return c.client.SendKeyCombo(combo) +} + +// Close closes the Wayland client +func (c *Controller) Close() error { + c.client.Close() + return nil +} + +// Provider implements the display.Provider interface for Wayland +type Provider 
struct{} + +var _ display.Provider = (*Provider)(nil) + +// NewProvider creates a new Wayland provider +func NewProvider() *Provider { + return &Provider{} +} + +// GetController creates a new DisplayController for the specified display +func (p *Provider) GetController(display string) (display.DisplayController, error) { + client, err := NewWaylandClient(display) + if err != nil { + return nil, err + } + return &Controller{client: client}, nil +} + +// GetDisplayInfo returns information about the Wayland platform +func (p *Provider) GetDisplayInfo() display.DisplayInfo { + return display.DisplayInfo{ + Name: "wayland", + SupportsRegions: true, + SupportsMouse: true, + SupportsKeyboard: true, + MaxTextLength: 0, + RequiresElevation: false, + } +} + +// IsAvailable returns true if Wayland is available on the current system +func (p *Provider) IsAvailable() bool { + // Wayland is available if the WAYLAND_DISPLAY environment variable is set + // Wayland takes priority over X11 + return os.Getenv("WAYLAND_DISPLAY") != "" +} + +// Register the Wayland provider in the global registry +// Note: init() runs before X11's init() due to alphabetical ordering of package names +func init() { + display.Register(NewProvider()) +} diff --git a/internal/services/tools/computer_use_x11.go b/internal/display/x11/client.go similarity index 98% rename from internal/services/tools/computer_use_x11.go rename to internal/display/x11/client.go index ad4064c3..e98b5462 100644 --- a/internal/services/tools/computer_use_x11.go +++ b/internal/display/x11/client.go @@ -1,4 +1,4 @@ -package tools +package x11 import ( "bytes" @@ -256,7 +256,7 @@ func (c *X11Client) TypeText(text string, delayMs int) error { keycodes := keybind.StrToKeycodes(c.xu, keyInfo.keyStr) if len(keycodes) == 0 { - logger.Warn("No keycode found for character", "char", string(char), "keyStr", keyInfo.keyStr) + logger.Debug("No keycode found for character", "char", string(char), "keyStr", keyInfo.keyStr) continue } diff 
--git a/internal/display/x11/controller.go b/internal/display/x11/controller.go new file mode 100644 index 00000000..187a29a3 --- /dev/null +++ b/internal/display/x11/controller.go @@ -0,0 +1,112 @@ +package x11 + +import ( + "context" + "image" + "os" + + display "github.com/inference-gateway/cli/internal/display" +) + +// Controller wraps the existing X11Client to implement the display.DisplayController interface +type Controller struct { + client *X11Client +} + +var _ display.DisplayController = (*Controller)(nil) + +// CaptureScreenBytes captures a screenshot and returns PNG bytes +func (c *Controller) CaptureScreenBytes(ctx context.Context, region *display.Region) ([]byte, error) { + if region == nil { + return c.client.CaptureScreenBytes(0, 0, 0, 0) + } + return c.client.CaptureScreenBytes(region.X, region.Y, region.Width, region.Height) +} + +// CaptureScreen captures a screenshot and returns an image.Image +func (c *Controller) CaptureScreen(ctx context.Context, region *display.Region) (image.Image, error) { + if region == nil { + return c.client.CaptureScreen(0, 0, 0, 0) + } + return c.client.CaptureScreen(region.X, region.Y, region.Width, region.Height) +} + +// GetScreenDimensions returns the screen width and height +func (c *Controller) GetScreenDimensions(ctx context.Context) (width, height int, err error) { + w, h := c.client.GetScreenDimensions() + return w, h, nil +} + +// GetCursorPosition returns the current cursor position +func (c *Controller) GetCursorPosition(ctx context.Context) (x, y int, err error) { + return c.client.GetCursorPosition() +} + +// MoveMouse moves the cursor to the specified coordinates +func (c *Controller) MoveMouse(ctx context.Context, x, y int) error { + return c.client.MoveMouse(x, y) +} + +// ClickMouse clicks the specified mouse button +func (c *Controller) ClickMouse(ctx context.Context, button display.MouseButton, clicks int) error { + return c.client.ClickMouse(button.String(), clicks) +} + +// TypeText types the 
given text with the specified delay between keystrokes +func (c *Controller) TypeText(ctx context.Context, text string, delayMs int) error { + return c.client.TypeText(text, delayMs) +} + +// SendKeyCombo sends a key combination (e.g., "ctrl+c", "super+l") +func (c *Controller) SendKeyCombo(ctx context.Context, combo string) error { + return c.client.SendKeyCombo(combo) +} + +// Close closes the X11 connection +func (c *Controller) Close() error { + c.client.Close() + return nil +} + +// Provider implements the display.Provider interface for X11 +type Provider struct{} + +var _ display.Provider = (*Provider)(nil) + +// NewProvider creates a new X11 provider +func NewProvider() *Provider { + return &Provider{} +} + +// GetController creates a new DisplayController for the specified display +func (p *Provider) GetController(display string) (display.DisplayController, error) { + client, err := NewX11Client(display) + if err != nil { + return nil, err + } + return &Controller{client: client}, nil +} + +// GetDisplayInfo returns information about the X11 platform +func (p *Provider) GetDisplayInfo() display.DisplayInfo { + return display.DisplayInfo{ + Name: "x11", + SupportsRegions: true, + SupportsMouse: true, + SupportsKeyboard: true, + MaxTextLength: 0, + RequiresElevation: false, + } +} + +// IsAvailable returns true if X11 is available on the current system +func (p *Provider) IsAvailable() bool { + // X11 is available if the DISPLAY environment variable is set + // and WAYLAND_DISPLAY is not set (Wayland takes priority) + return os.Getenv("DISPLAY") != "" && os.Getenv("WAYLAND_DISPLAY") == "" +} + +// Register the X11 provider in the global registry +func init() { + display.Register(NewProvider()) +} diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 353243b4..8039d592 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -54,6 +54,16 @@ type ScreenshotProvider interface { GetLatestScreenshot() 
(*Screenshot, error) } +// RateLimiter defines the interface for rate limiting computer use actions +type RateLimiter interface { + // CheckAndRecord checks if the action is within rate limits and records it + CheckAndRecord(toolName string) error + // GetCurrentCount returns the number of actions in the current window + GetCurrentCount() int + // Reset clears all recorded actions + Reset() +} + // ScreenshotToolResult represents the result of a screenshot capture type ScreenshotToolResult struct { Display string `json:"display"` diff --git a/internal/infra/storage/jsonl.go b/internal/infra/storage/jsonl.go index 1af3f918..02951fac 100644 --- a/internal/infra/storage/jsonl.go +++ b/internal/infra/storage/jsonl.go @@ -52,13 +52,13 @@ func (s *JsonlStorage) conversationFilePath(conversationID string) string { // saveConversationUnlocked saves a conversation without acquiring the lock // Caller must hold the lock before calling this method -func (s *JsonlStorage) saveConversationUnlocked(ctx context.Context, conversationID string, entries []domain.ConversationEntry, metadata ConversationMetadata) error { - metadataJSON, err := json.Marshal(map[string]interface{}{"metadata": metadata}) +func (s *JsonlStorage) saveConversationUnlocked(_ context.Context, conversationID string, entries []domain.ConversationEntry, metadata ConversationMetadata) error { + metadataJSON, err := json.Marshal(map[string]any{"metadata": metadata}) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } - entriesJSON, err := json.Marshal(map[string]interface{}{"entries": entries}) + entriesJSON, err := json.Marshal(map[string]any{"entries": entries}) if err != nil { return fmt.Errorf("failed to marshal entries: %w", err) } @@ -308,12 +308,12 @@ func (s *JsonlStorage) UpdateConversationMetadata(ctx context.Context, conversat return fmt.Errorf("failed to unmarshal entries: %w", err) } - metadataJSON, err := json.Marshal(map[string]interface{}{"metadata": metadata}) + 
metadataJSON, err := json.Marshal(map[string]any{"metadata": metadata}) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } - entriesJSON, err := json.Marshal(map[string]interface{}{"entries": entriesWrapper.Entries}) + entriesJSON, err := json.Marshal(map[string]any{"entries": entriesWrapper.Entries}) if err != nil { return fmt.Errorf("failed to marshal entries: %w", err) } diff --git a/internal/services/agent.go b/internal/services/agent.go index 16d4f9a4..b21ce844 100644 --- a/internal/services/agent.go +++ b/internal/services/agent.go @@ -767,7 +767,7 @@ func (s *AgentServiceImpl) storeIterationMetrics( } } -func (s *AgentServiceImpl) optimizeConversation(ctx context.Context, req *domain.AgentRequest, conversation []sdk.Message, eventPublisher *eventPublisher) []sdk.Message { +func (s *AgentServiceImpl) optimizeConversation(_ context.Context, req *domain.AgentRequest, conversation []sdk.Message, eventPublisher *eventPublisher) []sdk.Message { if s.optimizer == nil { return conversation } diff --git a/internal/services/screenshot_server.go b/internal/services/screenshot_server.go index cb484809..c47ab365 100644 --- a/internal/services/screenshot_server.go +++ b/internal/services/screenshot_server.go @@ -6,14 +6,18 @@ import ( "fmt" "net" "net/http" + "path/filepath" "strconv" "sync" "time" config "github.com/inference-gateway/cli/config" + display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" - tools "github.com/inference-gateway/cli/internal/services/tools" + + _ "github.com/inference-gateway/cli/internal/display/wayland" + _ "github.com/inference-gateway/cli/internal/display/x11" ) // ScreenshotServer provides an HTTP API for screenshot streaming @@ -49,37 +53,34 @@ func (s *ScreenshotServer) Start() error { return fmt.Errorf("screenshot server already running") } - logger.Info("Starting screenshot server", 
"session_id", s.sessionID) - - // Create circular buffer bufferSize := s.cfg.ComputerUse.Screenshot.BufferSize if bufferSize <= 0 { - bufferSize = 30 // default + bufferSize = 30 } tempDir := s.cfg.ComputerUse.Screenshot.TempDir if tempDir == "" { - tempDir = "/tmp/infer-screenshots" + tempDir = filepath.Join(s.cfg.GetConfigDir(), "tmp", "screenshots") } - logger.Info("Creating screenshot buffer", "buffer_size", bufferSize, "temp_dir", tempDir) + absTempDir, err := filepath.Abs(tempDir) + if err != nil { + return fmt.Errorf("failed to resolve temp directory path: %w", err) + } - buffer, err := NewCircularScreenshotBuffer(bufferSize, tempDir, s.sessionID) + buffer, err := NewCircularScreenshotBuffer(bufferSize, absTempDir, s.sessionID) if err != nil { return fmt.Errorf("failed to create screenshot buffer: %w", err) } s.buffer = buffer - // Listen on random port listener, err := net.Listen("tcp", "localhost:0") if err != nil { return fmt.Errorf("failed to listen: %w", err) } s.port = listener.Addr().(*net.TCPAddr).Port - logger.Info("Screenshot server listening", "port", s.port) - // Create HTTP server mux := http.NewServeMux() mux.HandleFunc("/api/screenshots/latest", s.handleGetLatest) mux.HandleFunc("/api/screenshots", s.handleGetRecent) @@ -89,20 +90,27 @@ func (s *ScreenshotServer) Start() error { Handler: mux, } - // Start HTTP server in goroutine go func() { - logger.Info("Screenshot HTTP server started", "port", s.port) if err := s.server.Serve(listener); err != nil && err != http.ErrServerClosed { logger.Error("Screenshot server error", "error", err) } }() - // Start capture loop s.captureCtx, s.captureStop = context.WithCancel(context.Background()) go s.startCaptureLoop() s.running = true - logger.Info("Screenshot server fully initialized", "port", s.port, "capture_interval", s.cfg.ComputerUse.Screenshot.CaptureInterval) + + interval := s.cfg.ComputerUse.Screenshot.CaptureInterval + if interval <= 0 { + interval = 3 + } + logger.Info("Screenshot server 
started", + "session_id", s.sessionID, + "port", s.port, + "buffer_size", bufferSize, + "temp_dir", absTempDir, + "capture_interval", interval) return nil } @@ -116,12 +124,10 @@ func (s *ScreenshotServer) Stop() error { return nil } - // Stop capture loop if s.captureStop != nil { s.captureStop() } - // Shutdown HTTP server ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -131,7 +137,6 @@ func (s *ScreenshotServer) Stop() error { } } - // Cleanup buffer if s.buffer != nil { if err := s.buffer.Cleanup(); err != nil { logger.Warn("Failed to cleanup buffer", "error", err) @@ -154,25 +159,21 @@ func (s *ScreenshotServer) Port() int { func (s *ScreenshotServer) startCaptureLoop() { interval := s.cfg.ComputerUse.Screenshot.CaptureInterval if interval <= 0 { - interval = 3 // default: 3 seconds + interval = 3 } - logger.Info("Screenshot capture loop started", "interval_seconds", interval) - ticker := time.NewTicker(time.Duration(interval) * time.Second) defer ticker.Stop() for { select { case <-s.captureCtx.Done(): - logger.Info("Screenshot capture loop stopped") return case <-ticker.C: - logger.Info("Attempting screenshot capture") if err := s.captureScreenshot(); err != nil { logger.Warn("Screenshot capture failed", "error", err) - } else { - logger.Info("Screenshot captured successfully") + } else if s.cfg.ComputerUse.Screenshot.LogCaptures { + logger.Debug("Screenshot captured") } } } @@ -180,43 +181,45 @@ func (s *ScreenshotServer) startCaptureLoop() { // captureScreenshot captures a screenshot and adds it to the buffer func (s *ScreenshotServer) captureScreenshot() error { - // Use the screenshot tool to capture - tool := tools.NewScreenshotTool(s.cfg, s.imageSvc, nil) // No rate limiter for auto-capture - - // Execute with default args (full screen) - result, err := tool.Execute(s.captureCtx, map[string]any{}) + displayProvider, err := display.DetectDisplay() if err != nil { - return err + return fmt.Errorf("no compatible 
display platform detected: %w", err) } - if !result.Success { - return fmt.Errorf("screenshot capture failed: %s", result.Error) + controller, err := displayProvider.GetController(s.cfg.ComputerUse.Display) + if err != nil { + return fmt.Errorf("failed to get platform controller: %w", err) } + defer func() { + if closeErr := controller.Close(); closeErr != nil { + logger.Warn("Failed to close controller", "error", closeErr) + } + }() - // Extract screenshot data - toolResult, ok := result.Data.(domain.ScreenshotToolResult) - if !ok { - return fmt.Errorf("unexpected result type") + width, height, err := controller.GetScreenDimensions(s.captureCtx) + if err != nil { + return fmt.Errorf("failed to get screen dimensions: %w", err) } - // Get image attachment - if len(result.Images) == 0 { - return fmt.Errorf("no image in result") + imageBytes, err := controller.CaptureScreenBytes(s.captureCtx, nil) + if err != nil { + return fmt.Errorf("failed to capture screenshot: %w", err) } - imageAttachment := result.Images[0] + imageAttachment, err := s.imageSvc.ReadImageFromBinary(imageBytes, "screenshot.png") + if err != nil { + return fmt.Errorf("failed to process image: %w", err) + } - // Create Screenshot object screenshot := &domain.Screenshot{ Timestamp: time.Now(), Data: imageAttachment.Data, - Width: toolResult.Width, - Height: toolResult.Height, - Format: toolResult.Format, - Method: toolResult.Method, + Width: width, + Height: height, + Format: "png", + Method: displayProvider.GetDisplayInfo().Name, } - // Add to buffer return s.buffer.Add(screenshot) } @@ -246,8 +249,7 @@ func (s *ScreenshotServer) handleGetRecent(w http.ResponseWriter, r *http.Reques return } - // Parse limit parameter - limit := 30 // default + limit := 30 if limitStr := r.URL.Query().Get("limit"); limitStr != "" { if parsedLimit, err := strconv.Atoi(limitStr); err == nil { if parsedLimit > 0 && parsedLimit <= 100 { @@ -258,7 +260,7 @@ func (s *ScreenshotServer) handleGetRecent(w 
http.ResponseWriter, r *http.Reques screenshots := s.buffer.GetRecent(limit) - response := map[string]interface{}{ + response := map[string]any{ "screenshots": screenshots, "count": len(screenshots), } @@ -279,7 +281,7 @@ func (s *ScreenshotServer) handleGetStatus(w http.ResponseWriter, r *http.Reques s.mu.RLock() defer s.mu.RUnlock() - status := map[string]interface{}{ + status := map[string]any{ "running": s.running, "count": s.buffer.Count(), "interval_sec": s.cfg.ComputerUse.Screenshot.CaptureInterval, diff --git a/internal/services/tools/keyboard_type.go b/internal/services/tools/keyboard_type.go index 93784c18..5524e28f 100644 --- a/internal/services/tools/keyboard_type.go +++ b/internal/services/tools/keyboard_type.go @@ -6,25 +6,29 @@ import ( "time" config "github.com/inference-gateway/cli/config" + display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" sdk "github.com/inference-gateway/sdk" ) // KeyboardTypeTool types text or sends key combinations type KeyboardTypeTool struct { - config *config.Config - enabled bool - formatter domain.BaseFormatter - rateLimiter *RateLimiter + config *config.Config + enabled bool + formatter domain.BaseFormatter + rateLimiter domain.RateLimiter + displayProvider display.Provider } // NewKeyboardTypeTool creates a new keyboard type tool -func NewKeyboardTypeTool(cfg *config.Config, rateLimiter *RateLimiter) *KeyboardTypeTool { +func NewKeyboardTypeTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider) *KeyboardTypeTool { return &KeyboardTypeTool{ - config: cfg, - enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.KeyboardType.Enabled, - formatter: domain.NewBaseFormatter("KeyboardType"), - rateLimiter: rateLimiter, + config: cfg, + enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.KeyboardType.Enabled, + formatter: domain.NewBaseFormatter("KeyboardType"), + 
rateLimiter: rateLimiter, + displayProvider: displayProvider, } } @@ -95,87 +99,59 @@ func (t *KeyboardTypeTool) Execute(ctx context.Context, args map[string]any) (*d }, nil } - display := t.config.ComputerUse.Display + displayName := t.config.ComputerUse.Display if displayArg, ok := args["display"].(string); ok && displayArg != "" { - display = displayArg + displayName = displayArg } - displayServer := DetectDisplayServer() - - switch displayServer { - case DisplayServerX11: - client, err := NewX11Client(display) - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "KeyboardType", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - defer client.Close() - - if hasText { - err = client.TypeText(text, t.config.ComputerUse.KeyboardType.TypingDelayMs) - } else { - err = client.SendKeyCombo(keyCombo) - } - - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "KeyboardType", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - - case DisplayServerWayland: - client, err := NewWaylandClient(display) - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "KeyboardType", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - defer client.Close() + if t.displayProvider == nil { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "no compatible display platform detected", + }, nil + } - if hasText { - err = client.TypeText(text, t.config.ComputerUse.KeyboardType.TypingDelayMs) - } else { - err = client.SendKeyCombo(keyCombo) + controller, err := t.displayProvider.GetController(displayName) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "KeyboardType", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("failed to get platform 
controller: %v", err), + }, nil + } + defer func() { + if closeErr := controller.Close(); closeErr != nil { + logger.Warn("Failed to close controller", "error", closeErr) } + }() - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "KeyboardType", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } + var execErr error + if hasText { + execErr = controller.TypeText(ctx, text, t.config.ComputerUse.KeyboardType.TypingDelayMs) + } else { + execErr = controller.SendKeyCombo(ctx, keyCombo) + } - default: + if execErr != nil { return &domain.ToolExecutionResult{ ToolName: "KeyboardType", Arguments: args, Success: false, Duration: time.Since(start), - Error: "no display server detected (neither X11 nor Wayland)", + Error: fmt.Sprintf("keyboard action failed: %v", execErr), }, nil } result := domain.KeyboardTypeToolResult{ Text: text, KeyCombo: keyCombo, - Display: display, - Method: displayServer.String(), + Display: displayName, + Method: t.displayProvider.GetDisplayInfo().Name, } return &domain.ToolExecutionResult{ diff --git a/internal/services/tools/keyboard_type_test.go b/internal/services/tools/keyboard_type_test.go index 3e3f068b..0b7f0df1 100644 --- a/internal/services/tools/keyboard_type_test.go +++ b/internal/services/tools/keyboard_type_test.go @@ -6,6 +6,7 @@ import ( config "github.com/inference-gateway/cli/config" domain "github.com/inference-gateway/cli/internal/domain" + utils "github.com/inference-gateway/cli/internal/utils" ) func TestKeyboardTypeTool_TypingDelay(t *testing.T) { @@ -57,7 +58,7 @@ func TestKeyboardTypeTool_TypingDelay(t *testing.T) { }, } - tool := NewKeyboardTypeTool(cfg, NewRateLimiter(cfg.ComputerUse.RateLimit)) + tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil) if tool.config.ComputerUse.KeyboardType.TypingDelayMs != tt.delayMs { t.Errorf("Expected delay %d ms, got %d ms", tt.delayMs, 
tool.config.ComputerUse.KeyboardType.TypingDelayMs) @@ -94,7 +95,7 @@ func TestKeyboardTypeTool_Validation(t *testing.T) { }, } - tool := NewKeyboardTypeTool(cfg, NewRateLimiter(cfg.ComputerUse.RateLimit)) + tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil) tests := []struct { name string @@ -205,7 +206,7 @@ func TestKeyboardTypeTool_FormatResult(t *testing.T) { }, } - tool := NewKeyboardTypeTool(cfg, NewRateLimiter(cfg.ComputerUse.RateLimit)) + tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil) result := &domain.ToolExecutionResult{ ToolName: "KeyboardType", diff --git a/internal/services/tools/mouse_click.go b/internal/services/tools/mouse_click.go index 7d4a0bc7..5ccf5a05 100644 --- a/internal/services/tools/mouse_click.go +++ b/internal/services/tools/mouse_click.go @@ -6,25 +6,29 @@ import ( "time" config "github.com/inference-gateway/cli/config" + display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" sdk "github.com/inference-gateway/sdk" ) // MouseClickTool performs mouse clicks type MouseClickTool struct { - config *config.Config - enabled bool - formatter domain.BaseFormatter - rateLimiter *RateLimiter + config *config.Config + enabled bool + formatter domain.BaseFormatter + rateLimiter domain.RateLimiter + displayProvider display.Provider } // NewMouseClickTool creates a new mouse click tool -func NewMouseClickTool(cfg *config.Config, rateLimiter *RateLimiter) *MouseClickTool { +func NewMouseClickTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider) *MouseClickTool { return &MouseClickTool{ - config: cfg, - enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.MouseClick.Enabled, - formatter: domain.NewBaseFormatter("MouseClick"), - rateLimiter: rateLimiter, + config: cfg, + enabled: cfg.ComputerUse.Enabled && 
cfg.ComputerUse.MouseClick.Enabled, + formatter: domain.NewBaseFormatter("MouseClick"), + rateLimiter: rateLimiter, + displayProvider: displayProvider, } } @@ -95,9 +99,9 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom clicks = int(clicksArg) } - display := t.config.ComputerUse.Display + displayName := t.config.ComputerUse.Display if displayArg, ok := args["display"].(string); ok && displayArg != "" { - display = displayArg + displayName = displayArg } var finalX, finalY int @@ -111,89 +115,55 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom } } - displayServer := DetectDisplayServer() - - switch displayServer { - case DisplayServerX11: - client, err := NewX11Client(display) - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - defer client.Close() - - if shouldMove { - if err := client.MoveMouse(finalX, finalY); err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: fmt.Sprintf("failed to move mouse: %v", err), - }, nil - } - } else { - x, y, _ := client.GetCursorPosition() - finalX, finalY = x, y - } - - if err := client.ClickMouse(button, clicks); err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - - case DisplayServerWayland: - client, err := NewWaylandClient(display) - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - defer client.Close() + if t.displayProvider == nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + 
Error: "no compatible display platform detected", + }, nil + } - if shouldMove { - if err := client.MoveMouse(finalX, finalY); err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: fmt.Sprintf("failed to move mouse: %v", err), - }, nil - } + controller, err := t.displayProvider.GetController(displayName) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("failed to get platform controller: %v", err), + }, nil + } + defer func() { + if closeErr := controller.Close(); closeErr != nil { + logger.Warn("Failed to close controller", "error", closeErr) } + }() - if err := client.ClickMouse(button, clicks); err != nil { + if shouldMove { + if err := controller.MoveMouse(ctx, finalX, finalY); err != nil { return &domain.ToolExecutionResult{ ToolName: "MouseClick", Arguments: args, Success: false, Duration: time.Since(start), - Error: err.Error(), + Error: fmt.Sprintf("failed to move mouse: %v", err), }, nil } + } else { + x, y, _ := controller.GetCursorPosition(ctx) + finalX, finalY = x, y + } - default: + mouseButton := display.ParseMouseButton(button) + if err := controller.ClickMouse(ctx, mouseButton, clicks); err != nil { return &domain.ToolExecutionResult{ ToolName: "MouseClick", Arguments: args, Success: false, Duration: time.Since(start), - Error: "no display server detected (neither X11 nor Wayland)", + Error: fmt.Sprintf("failed to click mouse: %v", err), }, nil } @@ -202,8 +172,8 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom Clicks: clicks, X: finalX, Y: finalY, - Display: display, - Method: displayServer.String(), + Display: displayName, + Method: t.displayProvider.GetDisplayInfo().Name, } return &domain.ToolExecutionResult{ diff --git a/internal/services/tools/mouse_move.go 
b/internal/services/tools/mouse_move.go index 99b88ad7..2210b077 100644 --- a/internal/services/tools/mouse_move.go +++ b/internal/services/tools/mouse_move.go @@ -6,25 +6,29 @@ import ( "time" config "github.com/inference-gateway/cli/config" + display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" sdk "github.com/inference-gateway/sdk" ) // MouseMoveTool moves the mouse cursor to specified coordinates type MouseMoveTool struct { - config *config.Config - enabled bool - formatter domain.BaseFormatter - rateLimiter *RateLimiter + config *config.Config + enabled bool + formatter domain.BaseFormatter + rateLimiter domain.RateLimiter + displayProvider display.Provider } // NewMouseMoveTool creates a new mouse move tool -func NewMouseMoveTool(cfg *config.Config, rateLimiter *RateLimiter) *MouseMoveTool { +func NewMouseMoveTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider) *MouseMoveTool { return &MouseMoveTool{ - config: cfg, - enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.MouseMove.Enabled, - formatter: domain.NewBaseFormatter("MouseMove"), - rateLimiter: rateLimiter, + config: cfg, + enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.MouseMove.Enabled, + formatter: domain.NewBaseFormatter("MouseMove"), + rateLimiter: rateLimiter, + displayProvider: displayProvider, } } @@ -86,72 +90,46 @@ func (t *MouseMoveTool) Execute(ctx context.Context, args map[string]any) (*doma }, nil } - display := t.config.ComputerUse.Display + displayName := t.config.ComputerUse.Display if displayArg, ok := args["display"].(string); ok && displayArg != "" { - display = displayArg - } - - var fromX, fromY int - displayServer := DetectDisplayServer() - - switch displayServer { - case DisplayServerX11: - client, err := NewX11Client(display) - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseMove", - 
Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - defer client.Close() - - fromX, fromY, _ = client.GetCursorPosition() - - if err := client.MoveMouse(int(x), int(y)); err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseMove", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } + displayName = displayArg + } - case DisplayServerWayland: - client, err := NewWaylandClient(display) - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseMove", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - defer client.Close() - - fromX, fromY = 0, 0 - - if err := client.MoveMouse(int(x), int(y)); err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseMove", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil + if t.displayProvider == nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "no compatible display platform detected", + }, nil + } + + controller, err := t.displayProvider.GetController(displayName) + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseMove", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("failed to get platform controller: %v", err), + }, nil + } + defer func() { + if closeErr := controller.Close(); closeErr != nil { + logger.Warn("Failed to close controller", "error", closeErr) } + }() - default: + fromX, fromY, _ := controller.GetCursorPosition(ctx) + + if err := controller.MoveMouse(ctx, int(x), int(y)); err != nil { return &domain.ToolExecutionResult{ ToolName: "MouseMove", Arguments: args, Success: false, Duration: time.Since(start), - Error: "no display server detected (neither X11 nor Wayland)", + Error: fmt.Sprintf("failed to move mouse: %v", 
err), }, nil } @@ -160,8 +138,8 @@ func (t *MouseMoveTool) Execute(ctx context.Context, args map[string]any) (*doma FromY: fromY, ToX: int(x), ToY: int(y), - Display: display, - Method: displayServer.String(), + Display: displayName, + Method: t.displayProvider.GetDisplayInfo().Name, } return &domain.ToolExecutionResult{ diff --git a/internal/services/tools/registry.go b/internal/services/tools/registry.go index 223e574d..e2d15352 100644 --- a/internal/services/tools/registry.go +++ b/internal/services/tools/registry.go @@ -7,10 +7,15 @@ import ( "time" config "github.com/inference-gateway/cli/config" + display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" utils "github.com/inference-gateway/cli/internal/utils" sdk "github.com/inference-gateway/sdk" + + _ "github.com/inference-gateway/cli/internal/display/macos" + _ "github.com/inference-gateway/cli/internal/display/wayland" + _ "github.com/inference-gateway/cli/internal/display/x11" ) // Registry manages all available tools @@ -79,11 +84,15 @@ func (r *Registry) registerTools() { } if r.config.ComputerUse.Enabled { - rateLimiter := NewRateLimiter(r.config.ComputerUse.RateLimit) - r.tools["Screenshot"] = NewScreenshotTool(r.config, r.imageService, rateLimiter) - r.tools["MouseMove"] = NewMouseMoveTool(r.config, rateLimiter) - r.tools["MouseClick"] = NewMouseClickTool(r.config, rateLimiter) - r.tools["KeyboardType"] = NewKeyboardTypeTool(r.config, rateLimiter) + displayProvider, err := display.DetectDisplay() + if err != nil { + logger.Warn("No compatible display platform detected, computer use tools will be disabled", "error", err) + } else { + rateLimiter := utils.NewRateLimiter(r.config.ComputerUse.RateLimit) + r.tools["MouseMove"] = NewMouseMoveTool(r.config, rateLimiter, displayProvider) + r.tools["MouseClick"] = NewMouseClickTool(r.config, rateLimiter, displayProvider) + 
r.tools["KeyboardType"] = NewKeyboardTypeTool(r.config, rateLimiter, displayProvider) + } } if r.config.MCP.Enabled && r.mcpManager != nil { diff --git a/internal/services/tools/screenshot.go b/internal/services/tools/screenshot.go deleted file mode 100644 index 69b75e1f..00000000 --- a/internal/services/tools/screenshot.go +++ /dev/null @@ -1,424 +0,0 @@ -package tools - -import ( - "bytes" - "context" - "encoding/base64" - "fmt" - "image" - "image/draw" - "image/jpeg" - "image/png" - "time" - - config "github.com/inference-gateway/cli/config" - domain "github.com/inference-gateway/cli/internal/domain" - sdk "github.com/inference-gateway/sdk" - xdraw "golang.org/x/image/draw" -) - -// ScreenshotTool captures screenshots of the display -type ScreenshotTool struct { - config *config.Config - enabled bool - formatter domain.BaseFormatter - imageService domain.ImageService - rateLimiter *RateLimiter -} - -// NewScreenshotTool creates a new screenshot tool -func NewScreenshotTool(cfg *config.Config, imageService domain.ImageService, rateLimiter *RateLimiter) *ScreenshotTool { - return &ScreenshotTool{ - config: cfg, - enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.Screenshot.Enabled, - formatter: domain.NewBaseFormatter("Screenshot"), - imageService: imageService, - rateLimiter: rateLimiter, - } -} - -// Definition returns the tool definition for the LLM -func (t *ScreenshotTool) Definition() sdk.ChatCompletionTool { - description := "Captures a screenshot of the display. This is a read-only operation that does NOT require approval. Can capture the entire screen or a specific region." - return sdk.ChatCompletionTool{ - Type: sdk.Function, - Function: sdk.FunctionObject{ - Name: "Screenshot", - Description: &description, - Parameters: &sdk.FunctionParameters{ - "type": "object", - "properties": map[string]any{ - "region": map[string]any{ - "type": "object", - "description": "Optional region to capture. 
If not specified, captures the entire screen.", - "properties": map[string]any{ - "x": map[string]any{ - "type": "integer", - "description": "X coordinate of the top-left corner", - }, - "y": map[string]any{ - "type": "integer", - "description": "Y coordinate of the top-left corner", - }, - "width": map[string]any{ - "type": "integer", - "description": "Width of the region", - }, - "height": map[string]any{ - "type": "integer", - "description": "Height of the region", - }, - }, - }, - "display": map[string]any{ - "type": "string", - "description": "Display to capture from (e.g., ':0'). Defaults to ':0'.", - "default": ":0", - }, - }, - }, - }, - } -} - -// Execute runs the screenshot tool with given arguments -func (t *ScreenshotTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { - start := time.Now() - - if t.rateLimiter != nil { - if err := t.rateLimiter.CheckAndRecord("Screenshot"); err != nil { - return &domain.ToolExecutionResult{ - ToolName: "Screenshot", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - } - - display := t.config.ComputerUse.Display - if displayArg, ok := args["display"].(string); ok && displayArg != "" { - display = displayArg - } - - region, x, y, width, height := parseRegionArgs(args) - - displayServer := DetectDisplayServer() - method := displayServer.String() - - var imageBytes []byte - var captureWidth, captureHeight int - var err error - - switch displayServer { - case DisplayServerX11: - imageBytes, captureWidth, captureHeight, err = t.captureX11(display, x, y, width, height) - case DisplayServerWayland: - imageBytes, captureWidth, captureHeight, err = t.captureWayland(display, x, y, width, height) - default: - return &domain.ToolExecutionResult{ - ToolName: "Screenshot", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: "no display server detected (neither X11 nor Wayland)", - }, nil - } - - if err != nil { - 
return &domain.ToolExecutionResult{ - ToolName: "Screenshot", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil - } - - optimized, err := t.optimizeScreenshot(imageBytes) - if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "Screenshot", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: fmt.Sprintf("failed to optimize screenshot: %v", err), - }, nil - } - - base64Data := base64.StdEncoding.EncodeToString(optimized) - - mimeType := "image/" + t.config.ComputerUse.Screenshot.Format - imageAttachment := domain.ImageAttachment{ - Data: base64Data, - MimeType: mimeType, - DisplayName: fmt.Sprintf("screenshot-%s", display), - } - - result := domain.ScreenshotToolResult{ - Display: display, - Region: region, - Width: captureWidth, - Height: captureHeight, - Format: t.config.ComputerUse.Screenshot.Format, - Method: method, - } - - return &domain.ToolExecutionResult{ - ToolName: "Screenshot", - Arguments: args, - Success: true, - Duration: time.Since(start), - Data: result, - Images: []domain.ImageAttachment{imageAttachment}, - }, nil -} - -// captureX11 captures a screenshot using X11 -// parseRegionArgs extracts region parameters from tool arguments -func parseRegionArgs(args map[string]any) (*domain.ScreenRegion, int, int, int, int) { - var region *domain.ScreenRegion - var x, y, width, height int - - regionArg, ok := args["region"].(map[string]any) - if !ok { - return nil, 0, 0, 0, 0 - } - - region = &domain.ScreenRegion{} - if xVal, ok := regionArg["x"].(float64); ok { - region.X = int(xVal) - x = int(xVal) - } - if yVal, ok := regionArg["y"].(float64); ok { - region.Y = int(yVal) - y = int(yVal) - } - if wVal, ok := regionArg["width"].(float64); ok { - region.Width = int(wVal) - width = int(wVal) - } - if hVal, ok := regionArg["height"].(float64); ok { - region.Height = int(hVal) - height = int(hVal) - } - - return region, x, y, width, height -} - -func (t *ScreenshotTool) 
captureX11(display string, x, y, width, height int) ([]byte, int, int, error) { - client, err := NewX11Client(display) - if err != nil { - return nil, 0, 0, err - } - defer client.Close() - - if width == 0 || height == 0 { - width, height = client.GetScreenDimensions() - x, y = 0, 0 - } - - imageBytes, err := client.CaptureScreenBytes(x, y, width, height) - if err != nil { - return nil, 0, 0, err - } - - return imageBytes, width, height, nil -} - -// captureWayland captures a screenshot using Wayland tools -func (t *ScreenshotTool) captureWayland(display string, x, y, width, height int) ([]byte, int, int, error) { - client, err := NewWaylandClient(display) - if err != nil { - return nil, 0, 0, err - } - defer client.Close() - - if width == 0 || height == 0 { - w, h, err := client.GetScreenDimensions() - if err != nil { - w, h = 1920, 1080 - } - width, height = w, h - x, y = 0, 0 - } - - imageBytes, err := client.CaptureScreenBytes(x, y, width, height) - if err != nil { - return nil, 0, 0, err - } - - return imageBytes, width, height, nil -} - -// optimizeScreenshot optimizes the screenshot image by resizing and compressing -func (t *ScreenshotTool) optimizeScreenshot(imageBytes []byte) ([]byte, error) { - img, format, err := image.Decode(bytes.NewReader(imageBytes)) - if err != nil { - return nil, fmt.Errorf("failed to decode image: %w", err) - } - - img = t.resizeIfNeeded(img) - - return t.encodeImage(img, format) -} - -// resizeIfNeeded resizes the image if it exceeds max dimensions -func (t *ScreenshotTool) resizeIfNeeded(img image.Image) image.Image { - bounds := img.Bounds() - width := bounds.Dx() - height := bounds.Dy() - - maxWidth := t.config.ComputerUse.Screenshot.MaxWidth - maxHeight := t.config.ComputerUse.Screenshot.MaxHeight - - if maxWidth <= 0 && maxHeight <= 0 { - return img - } - - needsResize := false - newWidth := width - newHeight := height - - if maxWidth > 0 && width > maxWidth { - needsResize = true - ratio := float64(maxWidth) / 
float64(width) - newWidth = maxWidth - newHeight = int(float64(height) * ratio) - } - - if maxHeight > 0 && newHeight > maxHeight { - needsResize = true - ratio := float64(maxHeight) / float64(newHeight) - newHeight = maxHeight - newWidth = int(float64(newWidth) * ratio) - } - - if !needsResize { - return img - } - - dst := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight)) - xdraw.NearestNeighbor.Scale(dst, dst.Bounds(), img, bounds, draw.Src, nil) - return dst -} - -// encodeImage encodes the image to bytes based on configuration -func (t *ScreenshotTool) encodeImage(img image.Image, originalFormat string) ([]byte, error) { - var buf bytes.Buffer - format := t.config.ComputerUse.Screenshot.Format - quality := t.config.ComputerUse.Screenshot.Quality - - if quality <= 0 || quality > 100 { - quality = 85 - } - - switch format { - case "jpeg", "jpg": - err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}) - if err != nil { - return nil, fmt.Errorf("failed to encode jpeg: %w", err) - } - case "png": - encoder := png.Encoder{ - CompressionLevel: png.DefaultCompression, - } - if err := encoder.Encode(&buf, img); err != nil { - return nil, fmt.Errorf("failed to encode png: %w", err) - } - default: - if err := png.Encode(&buf, img); err != nil { - return nil, fmt.Errorf("failed to encode default png: %w", err) - } - } - - return buf.Bytes(), nil -} - -// Validate checks if the tool arguments are valid -func (t *ScreenshotTool) Validate(args map[string]any) error { - regionArg, ok := args["region"].(map[string]any) - if !ok { - return nil - } - - x, xOk := regionArg["x"].(float64) - y, yOk := regionArg["y"].(float64) - width, wOk := regionArg["width"].(float64) - height, hOk := regionArg["height"].(float64) - - if xOk && x < 0 { - return fmt.Errorf("region x must be >= 0") - } - if yOk && y < 0 { - return fmt.Errorf("region y must be >= 0") - } - if wOk && width <= 0 { - return fmt.Errorf("region width must be > 0") - } - if hOk && height <= 0 { - return 
fmt.Errorf("region height must be > 0") - } - - return nil -} - -// IsEnabled returns whether this tool is enabled -func (t *ScreenshotTool) IsEnabled() bool { - if t.config.ComputerUse.Screenshot.StreamingEnabled { - return false - } - return t.enabled -} - -// FormatResult formats tool execution results for different contexts -func (t *ScreenshotTool) FormatResult(result *domain.ToolExecutionResult, formatType domain.FormatterType) string { - switch formatType { - case domain.FormatterLLM: - return t.FormatForLLM(result) - case domain.FormatterShort: - return t.FormatPreview(result) - default: - return t.FormatForLLM(result) - } -} - -// FormatPreview returns a short preview of the result for UI display -func (t *ScreenshotTool) FormatPreview(result *domain.ToolExecutionResult) string { - if result == nil || !result.Success { - return "Screenshot capture failed" - } - data, ok := result.Data.(domain.ScreenshotToolResult) - if !ok { - return "Screenshot captured" - } - return fmt.Sprintf("Screenshot captured: %dx%d (%s)", data.Width, data.Height, data.Method) -} - -// FormatForLLM formats the result for LLM consumption -func (t *ScreenshotTool) FormatForLLM(result *domain.ToolExecutionResult) string { - if result == nil || !result.Success { - return fmt.Sprintf("Error: %s", result.Error) - } - data, ok := result.Data.(domain.ScreenshotToolResult) - if !ok { - return "Screenshot captured successfully. Image is attached." - } - regionStr := "full screen" - if data.Region != nil { - regionStr = fmt.Sprintf("region x=%d y=%d w=%d h=%d", data.Region.X, data.Region.Y, data.Region.Width, data.Region.Height) - } - return fmt.Sprintf("Screenshot captured successfully (%s, %dx%d, format: %s, method: %s). 
Image is attached.", - regionStr, data.Width, data.Height, data.Format, data.Method) -} - -// ShouldCollapseArg determines if an argument should be collapsed in display -func (t *ScreenshotTool) ShouldCollapseArg(key string) bool { - return false -} - -// ShouldAlwaysExpand determines if tool results should always be expanded in UI -func (t *ScreenshotTool) ShouldAlwaysExpand() bool { - return false -} diff --git a/internal/services/tools/computer_use_common.go b/internal/utils/ratelimiter.go similarity index 55% rename from internal/services/tools/computer_use_common.go rename to internal/utils/ratelimiter.go index aca53ce1..dc8ab857 100644 --- a/internal/services/tools/computer_use_common.go +++ b/internal/utils/ratelimiter.go @@ -1,60 +1,27 @@ -package tools +package utils import ( "fmt" - "os" "sync" "time" config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" ) -// DisplayServer represents the type of display server -type DisplayServer int - -const ( - DisplayServerX11 DisplayServer = iota - DisplayServerWayland - DisplayServerUnknown -) - -// DetectDisplayServer detects which display server is running -func DetectDisplayServer() DisplayServer { - // Check for Wayland first (more modern) - if os.Getenv("WAYLAND_DISPLAY") != "" { - return DisplayServerWayland - } - - // Check for X11 - if os.Getenv("DISPLAY") != "" { - return DisplayServerX11 - } - - return DisplayServerUnknown -} - -// GetDisplayName returns a string name for the display server -func (ds DisplayServer) String() string { - switch ds { - case DisplayServerX11: - return "x11" - case DisplayServerWayland: - return "wayland" - default: - return "unknown" - } -} - -// RateLimiter implements token bucket rate limiting for computer use actions -type RateLimiter struct { +// TokenBucketRateLimiter implements token bucket rate limiting for computer use actions +type TokenBucketRateLimiter struct { cfg *config.RateLimitConfig actionTimes []time.Time 
mu sync.Mutex } +// Ensure TokenBucketRateLimiter implements domain.RateLimiter +var _ domain.RateLimiter = (*TokenBucketRateLimiter)(nil) + // NewRateLimiter creates a new rate limiter -func NewRateLimiter(cfg config.RateLimitConfig) *RateLimiter { - return &RateLimiter{ +func NewRateLimiter(cfg config.RateLimitConfig) domain.RateLimiter { + return &TokenBucketRateLimiter{ cfg: &cfg, actionTimes: make([]time.Time, 0), } @@ -62,9 +29,9 @@ func NewRateLimiter(cfg config.RateLimitConfig) *RateLimiter { // CheckAndRecord checks if the action is within rate limits and records it // Returns an error if the rate limit is exceeded -func (rl *RateLimiter) CheckAndRecord(toolName string) error { +func (rl *TokenBucketRateLimiter) CheckAndRecord(toolName string) error { if !rl.cfg.Enabled { - return nil // Rate limiting disabled + return nil } rl.mu.Lock() @@ -73,7 +40,6 @@ func (rl *RateLimiter) CheckAndRecord(toolName string) error { now := time.Now() windowStart := now.Add(-time.Duration(rl.cfg.WindowSeconds) * time.Second) - // Remove actions outside the time window validActions := make([]time.Time, 0) for _, t := range rl.actionTimes { if t.After(windowStart) { @@ -82,19 +48,17 @@ func (rl *RateLimiter) CheckAndRecord(toolName string) error { } rl.actionTimes = validActions - // Check if at limit if len(rl.actionTimes) >= rl.cfg.MaxActionsPerMinute { return fmt.Errorf("rate limit exceeded: maximum %d actions per %d seconds (current: %d actions in window)", rl.cfg.MaxActionsPerMinute, rl.cfg.WindowSeconds, len(rl.actionTimes)) } - // Record the new action rl.actionTimes = append(rl.actionTimes, now) return nil } // GetCurrentCount returns the number of actions in the current window -func (rl *RateLimiter) GetCurrentCount() int { +func (rl *TokenBucketRateLimiter) GetCurrentCount() int { if !rl.cfg.Enabled { return 0 } @@ -116,7 +80,7 @@ func (rl *RateLimiter) GetCurrentCount() int { } // Reset clears all recorded actions -func (rl *RateLimiter) Reset() { +func (rl 
*TokenBucketRateLimiter) Reset() { rl.mu.Lock() defer rl.mu.Unlock() rl.actionTimes = make([]time.Time, 0) diff --git a/internal/web/pty_manager.go b/internal/web/pty_manager.go index 6839f86d..80a3b956 100644 --- a/internal/web/pty_manager.go +++ b/internal/web/pty_manager.go @@ -95,7 +95,7 @@ func ensureRemoteBinary(client *SSHClient, webCfg *config.WebConfig, serverCfg * // ensureRemoteConfig ensures infer config exists on remote server // Runs infer init --userspace if ~/.infer/config.yaml doesn't exist -func ensureRemoteConfig(client *SSHClient, serverCfg *config.SSHServerConfig, gatewayURL string) error { +func ensureRemoteConfig(client *SSHClient, serverCfg *config.SSHServerConfig, _ string) error { commandPath := serverCfg.CommandPath if commandPath == "" { commandPath = "infer" diff --git a/internal/web/server.go b/internal/web/server.go index f08ed808..3ebc63fb 100644 --- a/internal/web/server.go +++ b/internal/web/server.go @@ -143,7 +143,7 @@ func (s *WebTerminalServer) handleServers(w http.ResponseWriter, r *http.Request } w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]interface{}{ + if err := json.NewEncoder(w).Encode(map[string]any{ "servers": servers, "ssh_enabled": s.cfg.Web.SSH.Enabled, }); err != nil { diff --git a/internal/web/ssh_session.go b/internal/web/ssh_session.go index f6b7f356..2cc74fa3 100644 --- a/internal/web/ssh_session.go +++ b/internal/web/ssh_session.go @@ -80,9 +80,9 @@ func (s *SSHSession) Start(cols, rows int) error { s.session = session modes := ssh.TerminalModes{ - ssh.ECHO: 1, // Enable echoing - ssh.TTY_OP_ISPEED: 14400, // Input speed = 14.4kbaud - ssh.TTY_OP_OSPEED: 14400, // Output speed = 14.4kbaud + ssh.ECHO: 1, + ssh.TTY_OP_ISPEED: 14400, + ssh.TTY_OP_OSPEED: 14400, } if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil { @@ -451,7 +451,7 @@ func (s *SSHSession) notifyWebUI(localPort int) { return } - msg := map[string]interface{}{ + msg := 
map[string]any{ "type": "screenshot_port", "port": localPort, } diff --git a/internal/web/static/preview-overlay.js b/internal/web/static/preview-overlay.js index 4c4c41d2..f986af60 100644 --- a/internal/web/static/preview-overlay.js +++ b/internal/web/static/preview-overlay.js @@ -9,6 +9,7 @@ class ScreenshotOverlay { this.pollInterval = null; this.pollingFrequency = 2000; // 2 seconds this.overlayElement = null; + this.skeletonElement = null; this.imageElement = null; this.timestampElement = null; this.dimensionsElement = null; @@ -33,7 +34,8 @@ class ScreenshotOverlay {
Connecting...
- Remote preview +
+
@@ -41,7 +43,7 @@ class ScreenshotOverlay {
`; - // Cache DOM elements + this.skeletonElement = this.overlayElement.querySelector('.screenshot-skeleton'); this.imageElement = this.overlayElement.querySelector('.screenshot-image'); this.timestampElement = this.overlayElement.querySelector('.screenshot-timestamp'); this.dimensionsElement = this.overlayElement.querySelector('.screenshot-dimensions'); @@ -72,6 +74,9 @@ class ScreenshotOverlay { this.sessionID = sessionID; console.log(`Screenshot overlay: starting polling for session ${sessionID}`); + this.skeletonElement.classList.remove('hidden'); + this.imageElement.classList.add('hidden'); + // Show status this.updateStatus('Loading preview...'); this.hideError(); @@ -135,9 +140,11 @@ class ScreenshotOverlay { return; } + this.skeletonElement.classList.add('hidden'); + this.imageElement.classList.remove('hidden'); + // Update image this.imageElement.src = `data:image/${screenshot.format};base64,${screenshot.data}`; - this.imageElement.style.display = 'block'; // Update timestamp const timestamp = new Date(screenshot.timestamp); diff --git a/internal/web/templates/index.html b/internal/web/templates/index.html index 1e49c851..61efb28f 100644 --- a/internal/web/templates/index.html +++ b/internal/web/templates/index.html @@ -306,6 +306,46 @@ display: block; margin-bottom: 12px; } + .screenshot-image.hidden { + display: none; + } + .screenshot-skeleton { + width: 100%; + height: 400px; + border-radius: 4px; + border: 1px solid #414868; + margin-bottom: 12px; + background: linear-gradient( + 90deg, + #1a1b26 0%, + #24283b 50%, + #1a1b26 100% + ); + background-size: 200% 100%; + animation: loading 1.5s ease-in-out infinite; + position: relative; + overflow: hidden; + } + .screenshot-skeleton::after { + content: 'Loading preview...'; + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + color: #565f89; + font-size: 14px; + } + .screenshot-skeleton.hidden { + display: none; + } + @keyframes loading { + 0% { + background-position: 200% 
0; + } + 100% { + background-position: -200% 0; + } + } .screenshot-info { display: flex; justify-content: space-between; From ab05f030542534b6e76bda2ecb829948cf66db6b Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sat, 3 Jan 2026 19:20:54 +0200 Subject: [PATCH 05/14] docs: Add computer use tools documentation --- .infer/config.yaml | 61 +++++++++++++++++++++++++++++++++++++++++ config/config.go | 63 ++++++++++++++++++++++++++++++++++++++++++- config/config_test.go | 35 ++++++++++++++++++++++++ cspell.yaml | 3 +++ 4 files changed, 161 insertions(+), 1 deletion(-) diff --git a/.infer/config.yaml b/.infer/config.yaml index f09404a2..789ae401 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -169,6 +169,41 @@ agent: - The system supports up to 5 concurrent tool executions by default - This reduces back-and-forth communication and significantly improves performance + COMPUTER USE TOOLS: + You have TWO ways to interact with the system: + 1. Direct terminal tools (PRIMARY): Bash, Read, Write, Edit, Grep, etc. + 2. GUI automation tools (FALLBACK): MouseMove, KeyboardType, MouseClick, GetLatestScreenshot + + CRITICAL: ALWAYS prefer direct terminal tools over GUI automation when possible. 
+ + When to use DIRECT tools (preferred): + - Reading files: Use Read tool, NOT KeyboardType to open an editor + - Writing files: Use Write/Edit tools, NOT GUI text editor + - Running commands: Use Bash tool, NOT KeyboardType in a terminal window + - Searching code: Use Grep tool, NOT opening files via GUI + - File operations: Use Bash/Read/Write, NOT GUI file manager + + When to use GUI tools (only when necessary): + - Interacting with graphical applications that have no CLI equivalent + - Testing UI behavior or visual elements + - Automating tasks that MUST be done through a GUI + - Taking screenshots to inspect visual state + + Why prefer direct tools: + - 10-100x faster execution (no GUI rendering delays) + - More reliable (no window focus issues, no timing problems) + - Precise output (structured data, not visual interpretation) + - Parallel execution support (batch multiple operations) + - Lower resource usage (no display server overhead) + + Example - WRONG approach: + MouseMove(x=100, y=200) + MouseClick(button="left") + KeyboardType(text="cat file.txt") + + Example - CORRECT approach: + Read(file_path="/path/to/file.txt") + WORKFLOW: When asked to implement features or fix issues: 1. Plan with TodoWrite @@ -255,6 +290,32 @@ agent: FOCUS: System operations, service management, monitoring, diagnostics, and infrastructure tasks. CONTEXT: This is a shared system environment, not a project workspace. Users may be managing servers, containers, services, or general infrastructure. + + COMPUTER USE TOOLS: + You have TWO ways to interact with the system: + 1. Direct terminal tools (PRIMARY): Bash, Read, Write, Edit, Grep, etc. + 2. GUI automation tools (FALLBACK): MouseMove, KeyboardType, MouseClick, GetLatestScreenshot + + CRITICAL: ALWAYS prefer direct terminal tools over GUI automation when possible. 
+ + When to use DIRECT tools (preferred): + - Reading files: Use Read tool, NOT KeyboardType to open an editor + - Writing files: Use Write/Edit tools, NOT GUI text editor + - Running commands: Use Bash tool, NOT KeyboardType in a terminal window + - Searching code: Use Grep tool, NOT opening files via GUI + - System operations: Use Bash for systemctl, journalctl, docker, etc. + + When to use GUI tools (only when necessary): + - Interacting with graphical applications that have no CLI equivalent + - Testing UI behavior or visual elements + - Remote desktop administration tasks that MUST be done through a GUI + + Why prefer direct tools: + - 10-100x faster execution (no GUI rendering delays) + - More reliable (no window focus issues, no timing problems) + - Works over SSH without X11 forwarding + - Precise output (structured data, not visual interpretation) + - Lower resource usage (critical for remote systems) system_reminders: enabled: true interval: 4 diff --git a/config/config.go b/config/config.go index bca0ed3c..1656eb6c 100644 --- a/config/config.go +++ b/config/config.go @@ -829,6 +829,41 @@ PARALLEL TOOL EXECUTION: - The system supports up to 5 concurrent tool executions by default - This reduces back-and-forth communication and significantly improves performance +COMPUTER USE TOOLS: +You have TWO ways to interact with the system: +1. Direct terminal tools (PRIMARY): Bash, Read, Write, Edit, Grep, etc. +2. GUI automation tools (FALLBACK): MouseMove, KeyboardType, MouseClick, GetLatestScreenshot + +CRITICAL: ALWAYS prefer direct terminal tools over GUI automation when possible. 
+ +When to use DIRECT tools (preferred): +- Reading files: Use Read tool, NOT KeyboardType to open an editor +- Writing files: Use Write/Edit tools, NOT GUI text editor +- Running commands: Use Bash tool, NOT KeyboardType in a terminal window +- Searching code: Use Grep tool, NOT opening files via GUI +- File operations: Use Bash/Read/Write, NOT GUI file manager + +When to use GUI tools (only when necessary): +- Interacting with graphical applications that have no CLI equivalent +- Testing UI behavior or visual elements +- Automating tasks that MUST be done through a GUI +- Taking screenshots to inspect visual state + +Why prefer direct tools: +- 10-100x faster execution (no GUI rendering delays) +- More reliable (no window focus issues, no timing problems) +- Precise output (structured data, not visual interpretation) +- Parallel execution support (batch multiple operations) +- Lower resource usage (no display server overhead) + +Example - WRONG approach: +MouseMove(x=100, y=200) +MouseClick(button="left") +KeyboardType(text="cat file.txt") + +Example - CORRECT approach: +Read(file_path="/path/to/file.txt") + WORKFLOW: When asked to implement features or fix issues: 1. Plan with TodoWrite @@ -865,7 +900,33 @@ EXAMPLE: FOCUS: System operations, service management, monitoring, diagnostics, and infrastructure tasks. -CONTEXT: This is a shared system environment, not a project workspace. Users may be managing servers, containers, services, or general infrastructure.`, +CONTEXT: This is a shared system environment, not a project workspace. Users may be managing servers, containers, services, or general infrastructure. + +COMPUTER USE TOOLS: +You have TWO ways to interact with the system: +1. Direct terminal tools (PRIMARY): Bash, Read, Write, Edit, Grep, etc. +2. GUI automation tools (FALLBACK): MouseMove, KeyboardType, MouseClick, GetLatestScreenshot + +CRITICAL: ALWAYS prefer direct terminal tools over GUI automation when possible. 
+ +When to use DIRECT tools (preferred): +- Reading files: Use Read tool, NOT KeyboardType to open an editor +- Writing files: Use Write/Edit tools, NOT GUI text editor +- Running commands: Use Bash tool, NOT KeyboardType in a terminal window +- Searching code: Use Grep tool, NOT opening files via GUI +- System operations: Use Bash for systemctl, journalctl, docker, etc. + +When to use GUI tools (only when necessary): +- Interacting with graphical applications that have no CLI equivalent +- Testing UI behavior or visual elements +- Remote desktop administration tasks that MUST be done through a GUI + +Why prefer direct tools: +- 10-100x faster execution (no GUI rendering delays) +- More reliable (no window focus issues, no timing problems) +- Works over SSH without X11 forwarding +- Precise output (structured data, not visual interpretation) +- Lower resource usage (critical for remote systems)`, SystemReminders: SystemRemindersConfig{ Enabled: true, Interval: 4, diff --git a/config/config_test.go b/config/config_test.go index 2b833cd2..cc378621 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -127,6 +127,41 @@ PARALLEL TOOL EXECUTION: - The system supports up to 5 concurrent tool executions by default - This reduces back-and-forth communication and significantly improves performance +COMPUTER USE TOOLS: +You have TWO ways to interact with the system: +1. Direct terminal tools (PRIMARY): Bash, Read, Write, Edit, Grep, etc. +2. GUI automation tools (FALLBACK): MouseMove, KeyboardType, MouseClick, GetLatestScreenshot + +CRITICAL: ALWAYS prefer direct terminal tools over GUI automation when possible. 
+ +When to use DIRECT tools (preferred): +- Reading files: Use Read tool, NOT KeyboardType to open an editor +- Writing files: Use Write/Edit tools, NOT GUI text editor +- Running commands: Use Bash tool, NOT KeyboardType in a terminal window +- Searching code: Use Grep tool, NOT opening files via GUI +- File operations: Use Bash/Read/Write, NOT GUI file manager + +When to use GUI tools (only when necessary): +- Interacting with graphical applications that have no CLI equivalent +- Testing UI behavior or visual elements +- Automating tasks that MUST be done through a GUI +- Taking screenshots to inspect visual state + +Why prefer direct tools: +- 10-100x faster execution (no GUI rendering delays) +- More reliable (no window focus issues, no timing problems) +- Precise output (structured data, not visual interpretation) +- Parallel execution support (batch multiple operations) +- Lower resource usage (no display server overhead) + +Example - WRONG approach: +MouseMove(x=100, y=200) +MouseClick(button="left") +KeyboardType(text="cat file.txt") + +Example - CORRECT approach: +Read(file_path="/path/to/file.txt") + WORKFLOW: When asked to implement features or fix issues: 1. 
Plan with TodoWrite diff --git a/cspell.yaml b/cspell.yaml index 061b10f8..74b1850f 100644 --- a/cspell.yaml +++ b/cspell.yaml @@ -53,6 +53,7 @@ words: - invopop - isatty - ISPEED + - journalctl - jsonmerge - jsonparser - jsonschema @@ -101,6 +102,7 @@ words: - sabhiram - sagents - sagikazarmark + - SCMPR - SHTTP - sjson - sname @@ -108,6 +110,7 @@ words: - stretchr - subosito - sysinfo + - systemctl - termenv - terminfo - tidwall From 0153abcc623b385565ca9534632170ad9e1395ee Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sun, 4 Jan 2026 02:24:11 +0200 Subject: [PATCH 06/14] feat: Add macOS support and focus management tools --- .infer/config.yaml | 15 +- cmd/chat.go | 13 +- cmd/export.go | 3 +- cmd/root.go | 5 +- config/config.go | 69 +++- go.mod | 24 +- go.sum | 59 ++- internal/app/chat.go | 16 +- internal/container/container.go | 4 +- internal/display/interface.go | 21 +- internal/display/macos/client_darwin.go | 343 ++++++++++++++++++ internal/display/macos/controller_darwin.go | 132 +++++-- internal/display/macos/controller_stub.go | 6 +- internal/display/wayland/client.go | 12 + internal/display/wayland/controller.go | 17 +- internal/display/x11/client.go | 50 ++- internal/display/x11/controller.go | 19 +- internal/domain/config_service.go | 3 + internal/services/agent.go | 140 ++++++- internal/services/agent_utils.go | 28 +- internal/services/config_service.go | 45 ++- internal/services/screenshot_server.go | 3 +- internal/services/tools.go | 9 - internal/services/tools/activate_app.go | 175 +++++++++ internal/services/tools/get_focused_app.go | 188 ++++++++++ .../services/tools/get_latest_screenshot.go | 11 +- internal/services/tools/keyboard_type.go | 19 +- internal/services/tools/mouse_click.go | 23 +- internal/services/tools/mouse_move.go | 23 +- internal/services/tools/mouse_scroll.go | 217 +++++++++++ internal/services/tools/registry.go | 84 +++-- internal/services/tools/registry_test.go | 78 +++- internal/web/pty_manager.go | 5 +- 
tests/mocks/domain/fake_config_service.go | 63 ++++ 34 files changed, 1712 insertions(+), 210 deletions(-) create mode 100644 internal/display/macos/client_darwin.go create mode 100644 internal/services/tools/activate_app.go create mode 100644 internal/services/tools/get_focused_app.go create mode 100644 internal/services/tools/mouse_scroll.go diff --git a/.infer/config.yaml b/.infer/config.yaml index 789ae401..ac406a65 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -701,7 +701,7 @@ web: servers: [] computer_use: enabled: false - display: :0 + restore_focus_on_approval: true screenshot: enabled: true max_width: 1920 @@ -711,20 +711,29 @@ computer_use: require_approval: false streaming_enabled: false capture_interval: 3 - buffer_size: 30 + buffer_size: 5 temp_dir: "" log_captures: false mouse_move: enabled: true - require_approval: true + require_approval: false mouse_click: enabled: true require_approval: true + mouse_scroll: + enabled: true + require_approval: false keyboard_type: enabled: true max_text_length: 1000 typing_delay_ms: 200 require_approval: true + get_focused_app: + enabled: true + require_approval: false + activate_app: + enabled: true + require_approval: false rate_limit: enabled: true max_actions_per_minute: 60 diff --git a/cmd/chat.go b/cmd/chat.go index 24cbeafb..f21fb6c8 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -40,8 +40,16 @@ and have a conversational interface with the inference gateway.`, return fmt.Errorf("failed to load config: %w", err) } + if os.Getenv("INFER_WEB_MODE") == "true" { + cfg.Web.Enabled = true + V.Set("web.enabled", true) + } + webMode, _ := cmd.Flags().GetBool("web") if webMode { + cfg.Web.Enabled = true + V.Set("web.enabled", true) + if cmd.Flags().Changed("port") { cfg.Web.Port, _ = cmd.Flags().GetInt("port") } @@ -49,7 +57,6 @@ and have a conversational interface with the inference gateway.`, cfg.Web.Host, _ = cmd.Flags().GetString("host") } - // SSH remote mode flags if cmd.Flags().Changed("ssh-host") 
{ cfg.Web.SSH.Enabled = true sshHost, _ := cmd.Flags().GetString("ssh-host") @@ -58,7 +65,6 @@ and have a conversational interface with the inference gateway.`, sshCommand, _ := cmd.Flags().GetString("ssh-command") noInstall, _ := cmd.Flags().GetBool("ssh-no-install") - // Create a single server config from CLI flags cfg.Web.Servers = []config.SSHServerConfig{ { Name: "CLI Remote Server", @@ -143,6 +149,7 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { conversationRepo := services.GetConversationRepository() modelService := services.GetModelService() config := services.GetConfig() + configService := services.GetConfigService() toolService := services.GetToolService() fileService := services.GetFileService() imageService := services.GetImageService() @@ -183,7 +190,7 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { conversationRepo, conversationOptimizer, modelService, - config, + configService, toolService, fileService, imageService, diff --git a/cmd/export.go b/cmd/export.go index 9a1a4f9a..0f066c01 100644 --- a/cmd/export.go +++ b/cmd/export.go @@ -46,7 +46,8 @@ func runExport(sessionID string) error { return fmt.Errorf("failed to initialize storage: %w", err) } - toolRegistry := tools.NewRegistry(cfg, nil, nil, nil) + configService := services.NewConfigService(V, cfg) + toolRegistry := tools.NewRegistry(configService, nil, nil, nil) toolFormatterService := services.NewToolFormatterService(toolRegistry) pricingService := services.NewPricingService(&cfg.Pricing) persistentRepo := services.NewPersistentConversationRepository(toolFormatterService, pricingService, storageBackend) diff --git a/cmd/root.go b/cmd/root.go index 9fbb706b..16a6e9cb 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -93,7 +93,7 @@ func initConfig() { // nolint:funlen v.SetDefault("web.servers", defaults.Web.Servers) v.SetDefault("computer_use", defaults.ComputerUse) v.SetDefault("computer_use.enabled", defaults.ComputerUse.Enabled) - 
v.SetDefault("computer_use.display", defaults.ComputerUse.Display) + v.SetDefault("computer_use.restore_focus_on_approval", defaults.ComputerUse.RestoreFocusOnApproval) v.SetDefault("computer_use.screenshot.enabled", defaults.ComputerUse.Screenshot.Enabled) v.SetDefault("computer_use.screenshot.max_width", defaults.ComputerUse.Screenshot.MaxWidth) v.SetDefault("computer_use.screenshot.max_height", defaults.ComputerUse.Screenshot.MaxHeight) @@ -106,8 +106,11 @@ func initConfig() { // nolint:funlen v.SetDefault("computer_use.screenshot.log_captures", defaults.ComputerUse.Screenshot.LogCaptures) v.SetDefault("computer_use.mouse_move.enabled", defaults.ComputerUse.MouseMove.Enabled) v.SetDefault("computer_use.mouse_click.enabled", defaults.ComputerUse.MouseClick.Enabled) + v.SetDefault("computer_use.mouse_scroll.enabled", defaults.ComputerUse.MouseScroll.Enabled) v.SetDefault("computer_use.keyboard_type.enabled", defaults.ComputerUse.KeyboardType.Enabled) v.SetDefault("computer_use.keyboard_type.max_text_length", defaults.ComputerUse.KeyboardType.MaxTextLength) + v.SetDefault("computer_use.get_focused_app.enabled", defaults.ComputerUse.GetFocusedApp.Enabled) + v.SetDefault("computer_use.activate_app.enabled", defaults.ComputerUse.ActivateApp.Enabled) v.SetDefault("computer_use.rate_limit.enabled", defaults.ComputerUse.RateLimit.Enabled) v.SetDefault("computer_use.rate_limit.max_actions_per_minute", defaults.ComputerUse.RateLimit.MaxActionsPerMinute) v.SetDefault("computer_use.rate_limit.window_seconds", defaults.ComputerUse.RateLimit.WindowSeconds) diff --git a/config/config.go b/config/config.go index 1656eb6c..c0eae334 100644 --- a/config/config.go +++ b/config/config.go @@ -250,13 +250,16 @@ type SandboxConfig struct { // ComputerUseConfig contains computer use tool settings type ComputerUseConfig struct { - Enabled bool `yaml:"enabled" mapstructure:"enabled"` - Display string `yaml:"display" mapstructure:"display"` - Screenshot ScreenshotToolConfig 
`yaml:"screenshot" mapstructure:"screenshot"` - MouseMove MouseMoveToolConfig `yaml:"mouse_move" mapstructure:"mouse_move"` - MouseClick MouseClickToolConfig `yaml:"mouse_click" mapstructure:"mouse_click"` - KeyboardType KeyboardTypeToolConfig `yaml:"keyboard_type" mapstructure:"keyboard_type"` - RateLimit RateLimitConfig `yaml:"rate_limit" mapstructure:"rate_limit"` + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + RestoreFocusOnApproval bool `yaml:"restore_focus_on_approval" mapstructure:"restore_focus_on_approval"` // Switch to terminal for approval, then restore focus + Screenshot ScreenshotToolConfig `yaml:"screenshot" mapstructure:"screenshot"` + MouseMove MouseMoveToolConfig `yaml:"mouse_move" mapstructure:"mouse_move"` + MouseClick MouseClickToolConfig `yaml:"mouse_click" mapstructure:"mouse_click"` + MouseScroll MouseScrollToolConfig `yaml:"mouse_scroll" mapstructure:"mouse_scroll"` + KeyboardType KeyboardTypeToolConfig `yaml:"keyboard_type" mapstructure:"keyboard_type"` + GetFocusedApp GetFocusedAppToolConfig `yaml:"get_focused_app" mapstructure:"get_focused_app"` + ActivateApp ActivateAppToolConfig `yaml:"activate_app" mapstructure:"activate_app"` + RateLimit RateLimitConfig `yaml:"rate_limit" mapstructure:"rate_limit"` } // ScreenshotToolConfig contains screenshot-specific tool settings @@ -286,6 +289,12 @@ type MouseClickToolConfig struct { RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` } +// MouseScrollToolConfig contains mouse scroll-specific tool settings +type MouseScrollToolConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` +} + // KeyboardTypeToolConfig contains keyboard type-specific tool settings type KeyboardTypeToolConfig struct { Enabled bool `yaml:"enabled" mapstructure:"enabled"` @@ -294,6 +303,18 @@ type KeyboardTypeToolConfig struct { RequireApproval *bool 
`yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` } +// GetFocusedAppToolConfig contains get focused app-specific tool settings +type GetFocusedAppToolConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` +} + +// ActivateAppToolConfig contains activate app-specific tool settings +type ActivateAppToolConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` +} + // RateLimitConfig contains rate limiting settings type RateLimitConfig struct { Enabled bool `yaml:"enabled" mapstructure:"enabled"` @@ -1093,8 +1114,8 @@ Write the AGENTS.md file to the project root when you have gathered enough infor Servers: []SSHServerConfig{}, }, ComputerUse: ComputerUseConfig{ - Enabled: false, - Display: ":0", + Enabled: false, + RestoreFocusOnApproval: true, // Switch to terminal for approval, then restore focus to original app Screenshot: ScreenshotToolConfig{ Enabled: true, MaxWidth: 1920, @@ -1104,24 +1125,36 @@ Write the AGENTS.md file to the project root when you have gathered enough infor RequireApproval: &[]bool{false}[0], StreamingEnabled: false, CaptureInterval: 3, - BufferSize: 30, + BufferSize: 5, TempDir: "", LogCaptures: false, }, MouseMove: MouseMoveToolConfig{ Enabled: true, - RequireApproval: &[]bool{true}[0], + RequireApproval: &[]bool{false}[0], }, MouseClick: MouseClickToolConfig{ Enabled: true, RequireApproval: &[]bool{true}[0], }, + MouseScroll: MouseScrollToolConfig{ + Enabled: true, + RequireApproval: &[]bool{false}[0], + }, KeyboardType: KeyboardTypeToolConfig{ Enabled: true, MaxTextLength: 1000, TypingDelayMs: 200, RequireApproval: &[]bool{true}[0], }, + GetFocusedApp: GetFocusedAppToolConfig{ + Enabled: true, + RequireApproval: &[]bool{false}[0], + }, + ActivateApp: ActivateAppToolConfig{ + Enabled: 
true, + RequireApproval: &[]bool{false}[0], + }, RateLimit: RateLimitConfig{ Enabled: true, MaxActionsPerMinute: 60, @@ -1208,10 +1241,24 @@ func (c *Config) IsApprovalRequired(toolName string) bool { // nolint:gocyclo,cy if c.ComputerUse.MouseClick.RequireApproval != nil { return *c.ComputerUse.MouseClick.RequireApproval } + case "MouseScroll": + if c.ComputerUse.MouseScroll.RequireApproval != nil { + return *c.ComputerUse.MouseScroll.RequireApproval + } case "KeyboardType": if c.ComputerUse.KeyboardType.RequireApproval != nil { return *c.ComputerUse.KeyboardType.RequireApproval } + case "GetFocusedApp": + if c.ComputerUse.GetFocusedApp.RequireApproval != nil { + return *c.ComputerUse.GetFocusedApp.RequireApproval + } + case "ActivateApp": + if c.ComputerUse.ActivateApp.RequireApproval != nil { + return *c.ComputerUse.ActivateApp.RequireApproval + } + case "GetLatestScreenshot": + return false } return globalApproval diff --git a/go.mod b/go.mod index 7c3667d5..0300a43a 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/creack/pty v1.1.24 github.com/go-redis/redis/v8 v8.11.5 + github.com/go-vgo/robotgo v1.0.0 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/inference-gateway/adk v0.17.0 @@ -52,18 +53,25 @@ require ( github.com/charmbracelet/x/term v0.2.1 // indirect github.com/cloudevents/sdk-go/v2 v2.15.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dblohm7/wingoes v0.0.0-20250822163801-6d8e6105c62d // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dlclark/regexp2 v1.11.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect + github.com/ebitengine/purego v0.9.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/gen2brain/shm v0.1.1 
// indirect + github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-resty/resty/v2 v2.16.5 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/godbus/dbus/v5 v5.2.0 // indirect github.com/gorilla/css v1.0.1 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/invopop/jsonschema v0.12.0 // indirect + github.com/jezek/xgb v1.2.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-localereader v0.0.1 // indirect @@ -77,27 +85,41 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/oapi-codegen/runtime v1.1.2 // indirect + github.com/otiai10/gosseract/v2 v2.4.1 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/robotn/xgb v0.10.0 // indirect + github.com/robotn/xgbutil v0.10.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect + github.com/shirou/gopsutil/v4 v4.25.10 // indirect github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + github.com/tailscale/win v0.0.0-20250627215312-f4da2b8ee071 // indirect github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + 
github.com/tklauser/go-sysconf v0.3.16 // indirect + github.com/tklauser/numcpus v0.11.0 // indirect + github.com/vcaesar/gops v0.41.0 // indirect + github.com/vcaesar/imgo v0.41.0 // indirect + github.com/vcaesar/keycode v0.10.1 // indirect + github.com/vcaesar/screenshot v0.11.1 // indirect + github.com/vcaesar/tt v0.20.1 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect github.com/yuin/goldmark v1.7.8 // indirect github.com/yuin/goldmark-emoji v1.0.5 // indirect + github.com/yusufpapurcu/wmi v1.2.4 // indirect go.uber.org/multierr v1.10.0 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect + golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39 // indirect golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 // indirect golang.org/x/mobile v0.0.0-20250606033058-a2a15c67f36f // indirect golang.org/x/mod v0.30.0 // indirect diff --git a/go.sum b/go.sum index 51f9ff39..fe64633d 100644 --- a/go.sum +++ b/go.sum @@ -65,12 +65,16 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dblohm7/wingoes v0.0.0-20250822163801-6d8e6105c62d h1:QRKpU+9ZBDs62LyBfwhZkJdB5DJX2Sm3p4kUh7l1aA0= +github.com/dblohm7/wingoes v0.0.0-20250822163801-6d8e6105c62d/go.mod h1:SUxUaAK/0UG5lYyZR1L1nC4AaYYvSSYTWQSH3FPcxKU= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dlclark/regexp2 v1.11.0 
h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/ebitengine/purego v0.9.1 h1:a/k2f2HQU3Pi399RPW1MOaZyhKJL9w/xFpKAg4q1s0A= +github.com/ebitengine/purego v0.9.1/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -79,10 +83,15 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= +github.com/gen2brain/shm v0.1.1 h1:1cTVA5qcsUFixnDHl14TmRoxgfWEEZlTezpUj1vm5uQ= +github.com/gen2brain/shm v0.1.1/go.mod h1:UgIcVtvmOu+aCJpqJX7GOtiN7X2ct+TKLg4RTxwPIUA= github.com/gin-contrib/sse v1.0.0 h1:y3bT1mUWUxDpW4JLQg/HnTqV4rozuW4tC9eFKTxYI9E= github.com/gin-contrib/sse v1.0.0/go.mod h1:zNuFdwarAygJBht0NTKiSi3jRf6RbqeILZ9Sp6Slhe0= github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ= github.com/gin-gonic/gin v1.10.1/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-playground/locales v0.14.1 
h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= @@ -93,10 +102,14 @@ github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-resty/resty/v2 v2.16.5 h1:hBKqmWrr7uRc3euHVqmh1HTHcKn99Smr7o5spptdhTM= github.com/go-resty/resty/v2 v2.16.5/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ07xAwp/fiA= +github.com/go-vgo/robotgo v1.0.0 h1:LTzPB8cQsP0E/iMMrh3sPhH9LgywyuuJHGPHk70UA74= +github.com/go-vgo/robotgo v1.0.0/go.mod h1:NcSL/tqNqkpWJ3rmT6YSDUVhQKZwyRsaanDMO4qkT5I= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM= github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= +github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -118,6 +131,8 @@ github.com/inference-gateway/sdk v1.14.1 h1:Rr1yWjrsMimUVX1nvbJEL4U4vhFhtiqxHFt3 github.com/inference-gateway/sdk v1.14.1/go.mod h1:2r2k+E38WtScmJk6e0GYCaUAbm48pH0oXcmIrVu/xhs= github.com/invopop/jsonschema v0.12.0 h1:6ovsNSuvn9wEQVOyc72aycBMVQFKz7cPdMJn10CvzRI= github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/jezek/xgb v1.2.0 
h1:LzgkD11wOrPnxXEqo588cnjUt4NwMHrFh/tgajo50Q0= +github.com/jezek/xgb v1.2.0/go.mod h1:nrhwO0FX/enq75I7Y7G8iN1ubpSGZEiA3v9e9GyRFlk= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= @@ -136,6 +151,8 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3 h1:PwQumkgq4/acIiZhtifTV5OUqqiP82UAl0h87xj/l9k= +github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= @@ -166,6 +183,8 @@ github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= 
github.com/oapi-codegen/runtime v1.1.2 h1:P2+CubHq8fO4Q6fV1tqDBZHCwpVpvPg7oKiYzQgXIyI= @@ -174,6 +193,10 @@ github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/gomega v1.37.0 h1:CdEG8g0S133B4OswTDC/5XPSzE1OeP29QOioj2PID2Y= github.com/onsi/gomega v1.37.0/go.mod h1:8D9+Txp43QWKhM24yyOBEdpkzN8FvJyAwecBgsU4KU0= +github.com/otiai10/gosseract/v2 v2.4.1 h1:G8AyBpXEeSlcq8TI85LH/pM5SXk8Djy2GEXisgyblRw= +github.com/otiai10/gosseract/v2 v2.4.1/go.mod h1:1gNWP4Hgr2o7yqWfs6r5bZxAatjOIdqWxJLWsTsembk= +github.com/otiai10/mint v1.6.3 h1:87qsV/aw1F5as1eH1zS/yqHY85ANKVMgkDrf9rcxbQs= +github.com/otiai10/mint v1.6.3/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -181,12 +204,19 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= +github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.1.0/go.mod 
h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/robotn/xgb v0.0.0-20190912153532-2cb92d044934/go.mod h1:SxQhJskUJ4rleVU44YvnrdvxQr0tKy5SRSigBrCgyyQ= +github.com/robotn/xgb v0.10.0 h1:O3kFbIwtwZ3pgLbp1h5slCQ4OpY8BdwugJLrUe6GPIM= +github.com/robotn/xgb v0.10.0/go.mod h1:SxQhJskUJ4rleVU44YvnrdvxQr0tKy5SRSigBrCgyyQ= +github.com/robotn/xgbutil v0.10.0 h1:gvf7mGQqCWQ68aHRtCxgdewRk+/KAJui6l3MJQQRCKw= +github.com/robotn/xgbutil v0.10.0/go.mod h1:svkDXUDQjUiWzLrA0OZgHc4lbOts3C+uRfP6/yjwYnU= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -196,6 +226,8 @@ github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDc github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= github.com/sclevine/spec v1.4.0 h1:z/Q9idDcay5m5irkZ28M7PtQM4aOISzOpj4bUPkDee8= github.com/sclevine/spec v1.4.0/go.mod h1:LvpgJaFyvQzRvc1kaDs0bulYwzC70PbiYjC4QnFHkOM= +github.com/shirou/gopsutil/v4 v4.25.10 h1:at8lk/5T1OgtuCp+AwrDofFRjnvosn0nkN2OLQ6g8tA= +github.com/shirou/gopsutil/v4 v4.25.10/go.mod h1:+kSwyC8DRUD9XXEHCAFjK+0nuArFJM0lva+StQAcskM= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8 h1:+jumHNA0Wrelhe64i8F6HNlS8pkoyMv5sreGx2Ry5Rw= github.com/sourcegraph/conc v0.3.1-0.20240121214520-5f936abd7ae8/go.mod h1:3n1Cwaq1E1/1lhQhtRK2ts/ZwZEhjcQeJQ1RuC6Q/8U= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= @@ -217,6 +249,10 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu 
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tailscale/win v0.0.0-20250627215312-f4da2b8ee071 h1:qo7kOhoN5DHioXNlFytBzIoA5glW6lsb8YqV0lP3IyE= +github.com/tailscale/win v0.0.0-20250627215312-f4da2b8ee071/go.mod h1:aMd4yDHLjbOuYP6fMxj1d9ACDQlSWwYztcpybGHCQc8= +github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= +github.com/tc-hib/winres v0.2.1/go.mod h1:C/JaNhH3KBvhNKVbvdlDWkbMDO9H4fKKDaN7/07SSuk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= @@ -227,12 +263,26 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYICU0nA= +github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI= +github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw= +github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= 
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/vcaesar/gops v0.41.0 h1:FG748Jyw3FOuZnbzSgB+CQSx2e5LbLCPWV2JU1brFdc= +github.com/vcaesar/gops v0.41.0/go.mod h1:/3048L7Rj7QjQKTSB+kKc7hDm63YhTWy5QJ10TCP37A= +github.com/vcaesar/imgo v0.41.0 h1:kNLYGrThXhB9Dd6IwFmfPnxq9P6yat2g7dpPjr7OWO8= +github.com/vcaesar/imgo v0.41.0/go.mod h1:/LGOge8etlzaVu/7l+UfhJxR6QqaoX5yeuzGIMfWb4I= +github.com/vcaesar/keycode v0.10.1 h1:0DesGmMAPWpYTCYddOFiCMKCDKgNnwiQa2QXindVUHw= +github.com/vcaesar/keycode v0.10.1/go.mod h1:JNlY7xbKsh+LAGfY2j4M3znVrGEm5W1R8s/Uv6BJcfQ= +github.com/vcaesar/screenshot v0.11.1 h1:GgPuN89XC4Yh38dLx4quPlSo3YiWWhwIria/j3LtrqU= +github.com/vcaesar/screenshot v0.11.1/go.mod h1:gJNwHBiP1v1v7i8TQ4yV1XJtcyn2I/OJL7OziVQkwjs= +github.com/vcaesar/tt v0.20.1 h1:D/jUeeVCNbq3ad8M7hhtB3J9x5RZ6I1n1eZ0BJp7M+4= +github.com/vcaesar/tt v0.20.1/go.mod h1:cH2+AwGAJm19Wa6xvEa+0r+sXDJBT0QgNQey6mwqLeU= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= @@ -242,6 +292,8 @@ github.com/yuin/goldmark v1.7.8 h1:iERMLn0/QJeHFhxSt3p6PeN9mGnvIKSpG9YYorDMnic= github.com/yuin/goldmark v1.7.8/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark-emoji v1.0.5 h1:EMVWyCGPlXJfUXBXpuMu+ii3TIaxbVBnEX9uaDC4cIk= github.com/yuin/goldmark-emoji v1.0.5/go.mod h1:tTkZEbwu5wkPmgTcitqddVxY9osFZiavD+r4AzQrh1U= +github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0= +github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod 
h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= @@ -256,8 +308,8 @@ golang.org/x/arch v0.13.0 h1:KCkqVVV1kGg0X87TFysjCJ8MxtZEIU4Ja/yXGeoECdA= golang.org/x/arch v0.13.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39 h1:DHNhtq3sNNzrvduZZIiFyXWOL9IWaDPHqTnLJp+rCBY= +golang.org/x/exp v0.0.0-20251125195548-87e1e737ad39/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476 h1:Wdx0vgH5Wgsw+lF//LJKmWOJBLWX6nprsMqnf99rYDE= golang.org/x/exp/shiny v0.0.0-20250606033433-dcc06ee1d476/go.mod h1:ygj7T6vSGhhm/9yTpOQQNvuAUFziTH7RUiH74EoE2C8= golang.org/x/image v0.34.0 h1:33gCkyw9hmwbZJeZkct8XyR11yH889EQt/QH4VmXMn8= @@ -270,7 +322,10 @@ golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201204225414-ed752295db88/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys 
v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= diff --git a/internal/app/chat.go b/internal/app/chat.go index 187362f3..3a4c0935 100644 --- a/internal/app/chat.go +++ b/internal/app/chat.go @@ -31,7 +31,7 @@ import ( // ChatApplication represents the main application model using state management type ChatApplication struct { // Dependencies - configService *config.Config + configService domain.ConfigService agentService domain.AgentService conversationRepo domain.ConversationRepository conversationOptimizer domain.ConversationOptimizerService @@ -102,7 +102,7 @@ func NewChatApplication( conversationRepo domain.ConversationRepository, conversationOptimizer domain.ConversationOptimizerService, modelService domain.ModelService, - configService *config.Config, + configService domain.ConfigService, toolService domain.ToolService, fileService domain.FileService, imageService domain.ImageService, @@ -175,7 +175,7 @@ func NewChatApplication( iv.SetThemeService(app.themeService) iv.SetStateManager(app.stateManager) iv.SetImageService(app.imageService) - iv.SetConfigService(app.configService) + iv.SetConfigService(app.configService.GetConfig()) iv.SetConversationRepo(app.conversationRepo) } @@ -189,7 +189,7 @@ func NewChatApplication( isb.SetModelService(app.modelService) isb.SetThemeService(app.themeService) isb.SetStateManager(app.stateManager) - isb.SetConfigService(app.configService) + isb.SetConfigService(app.configService.GetConfig()) isb.SetConversationRepo(app.conversationRepo) isb.SetToolService(app.toolService) isb.SetTokenEstimator(services.NewTokenizerService(services.DefaultTokenizerConfig())) @@ -209,7 +209,7 @@ func NewChatApplication( app.applicationViewRenderer = components.NewApplicationViewRenderer(styleProvider) app.fileSelectionHandler = components.NewFileSelectionHandler(styleProvider) - app.keyBindingManager = keybinding.NewKeyBindingManager(app, app.configService) 
+ app.keyBindingManager = keybinding.NewKeyBindingManager(app, app.configService.GetConfig()) app.updateHelpBarShortcuts() keyHintFormatter := app.keyBindingManager.GetHintFormatter() @@ -272,7 +272,7 @@ func NewChatApplication( app.backgroundTaskService, app.toolRegistry.GetBackgroundShellService(), agentManager, - configService, + app.configService.GetConfig(), ) app.messageHistoryHandler = handlers.NewMessageHistoryHandler( @@ -995,7 +995,7 @@ func (app *ChatApplication) handleA2ATaskManagementView(msg tea.Msg) []tea.Cmd { var cmds []tea.Cmd if app.taskManager == nil { - if !app.configService.A2A.Enabled { + if !app.configService.GetConfig().A2A.Enabled { cmds = append(cmds, func() tea.Msg { return domain.ShowErrorEvent{ Error: "Task management requires A2A to be enabled in configuration.", @@ -1544,7 +1544,7 @@ func (app *ChatApplication) GetImageService() domain.ImageService { // GetConfig returns the configuration for keybinding context func (app *ChatApplication) GetConfig() *config.Config { - return app.configService + return app.configService.GetConfig() } // GetConfigDir returns the configuration directory path diff --git a/internal/container/container.go b/internal/container/container.go index 05437f38..68d85015 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -218,7 +218,7 @@ func (c *ServiceContainer) initializeDomainServices() { c.initializeMCPManager() - c.toolRegistry = tools.NewRegistry(c.config, c.imageService, c.mcpManager, c.BackgroundShellService()) + c.toolRegistry = tools.NewRegistry(c.configService, c.imageService, c.mcpManager, c.BackgroundShellService()) c.taskTrackerService = c.toolRegistry.GetTaskTracker() toolFormatterService := services.NewToolFormatterService(c.toolRegistry) @@ -288,7 +288,7 @@ func (c *ServiceContainer) initializeDomainServices() { c.agentService = services.NewAgentService( agentClient, c.toolService, - c.config, + c.configService, c.conversationRepo, c.a2aAgentService, 
c.messageQueue, diff --git a/internal/display/interface.go b/internal/display/interface.go index 575350ad..ed843b5e 100644 --- a/internal/display/interface.go +++ b/internal/display/interface.go @@ -16,6 +16,7 @@ type DisplayController interface { GetCursorPosition(ctx context.Context) (x, y int, err error) MoveMouse(ctx context.Context, x, y int) error ClickMouse(ctx context.Context, button MouseButton, clicks int) error + ScrollMouse(ctx context.Context, clicks int, direction string) error // Keyboard operations TypeText(ctx context.Context, text string, delayMs int) error @@ -72,8 +73,8 @@ func ParseMouseButton(s string) MouseButton { // Provider creates DisplayController instances for a specific display server/protocol type Provider interface { - // GetController creates a new DisplayController for the specified display - GetController(display string) (DisplayController, error) + // GetController creates a new DisplayController (display is auto-detected from environment) + GetController() (DisplayController, error) // GetDisplayInfo returns information about the display server/protocol GetDisplayInfo() DisplayInfo @@ -91,3 +92,19 @@ type DisplayInfo struct { MaxTextLength int RequiresElevation bool } + +// FocusManager is an optional interface for window focus management +// Only implemented by platforms that support it (e.g., macOS) +type FocusManager interface { + // GetFrontmostApp returns the identifier of the currently focused application + GetFrontmostApp(ctx context.Context) (string, error) + + // ActivateApp brings an application to the foreground + ActivateApp(ctx context.Context, appIdentifier string) error + + // GetTerminalApp returns the identifier of the terminal application + GetTerminalApp(ctx context.Context) (string, error) + + // SwitchToTerminal switches focus to the terminal application + SwitchToTerminal(ctx context.Context) error +} diff --git a/internal/display/macos/client_darwin.go b/internal/display/macos/client_darwin.go new file 
mode 100644 index 00000000..941e4827 --- /dev/null +++ b/internal/display/macos/client_darwin.go @@ -0,0 +1,343 @@ +//go:build darwin + +package macos + +/* +#cgo CFLAGS: -x objective-c +#cgo LDFLAGS: -framework AppKit +#import <AppKit/AppKit.h> + +// Get the bundle identifier of the frontmost application +const char* getFrontmostApp() { + NSRunningApplication *app = [[NSWorkspace sharedWorkspace] frontmostApplication]; + if (app == nil) { + return ""; + } + const char *bundleID = [app.bundleIdentifier UTF8String]; + return bundleID ? bundleID : ""; +} + +// Activate application by bundle identifier +bool activateApp(const char *bundleIdentifier) { + @autoreleasepool { + NSString *bundleID = [NSString stringWithUTF8String:bundleIdentifier]; + NSArray *apps = [NSRunningApplication runningApplicationsWithBundleIdentifier:bundleID]; + if ([apps count] == 0) { + return false; + } + NSRunningApplication *app = [apps firstObject]; + // NOTE(review): this still calls activateWithOptions:; an earlier note said to use `activate` instead (activateWithOptions deprecated in macOS 14+) - confirm which API is intended + return [app activateWithOptions:NSApplicationActivateAllWindows]; + } +} + +// Get the terminal app bundle ID (Terminal.app, iTerm2, VS Code, etc.) 
+const char* getTerminalApp() { + @autoreleasepool { + // Common terminal applications + NSArray *terminalBundles = @[ + @"com.apple.Terminal", // Terminal.app + @"com.googlecode.iterm2", // iTerm2 + @"com.microsoft.VSCode", // VS Code + @"com.sublimetext.4", // Sublime Text + @"com.jetbrains.goland", // GoLand + @"com.jetbrains.intellij", // IntelliJ IDEA + @"org.alacritty", // Alacritty + @"net.kovidgoyal.kitty", // Kitty + ]; + + for (NSString *bundleID in terminalBundles) { + NSArray *apps = [NSRunningApplication runningApplicationsWithBundleIdentifier:bundleID]; + if ([apps count] > 0) { + return [bundleID UTF8String]; + } + } + + return ""; + } +} +*/ +import "C" + +import ( + "bytes" + "fmt" + "image" + "image/png" + "strings" + "time" + "unsafe" + + robotgo "github.com/go-vgo/robotgo" +) + +// MacOSClient provides macOS screen control operations using RobotGo +type MacOSClient struct { + screenWidth int + screenHeight int +} + +// Modifier and key mapping tables +var ( + modifierMap = map[string]string{ + "super": "cmd", + "command": "cmd", + "cmd": "cmd", + "ctrl": "ctrl", + "control": "ctrl", + "alt": "alt", + "option": "alt", + "shift": "shift", + } + + specialKeyMap = map[string]string{ + "enter": "enter", + "return": "enter", + "tab": "tab", + "space": "space", + "backspace": "backspace", + "delete": "delete", + "del": "delete", + "esc": "esc", + "escape": "esc", + "up": "up", + "down": "down", + "left": "left", + "right": "right", + "home": "home", + "end": "end", + "pageup": "pageup", + "pagedown": "pagedown", + "f1": "f1", + "f2": "f2", + "f3": "f3", + "f4": "f4", + "f5": "f5", + "f6": "f6", + "f7": "f7", + "f8": "f8", + "f9": "f9", + "f10": "f10", + "f11": "f11", + "f12": "f12", + } +) + +// NewMacOSClient creates a new macOS client +func NewMacOSClient() (*MacOSClient, error) { + // Get screen dimensions + width, height := robotgo.GetScreenSize() + + return &MacOSClient{ + screenWidth: width, + screenHeight: height, + }, nil +} + +// Close closes 
the macOS client (no-op for RobotGo) +func (c *MacOSClient) Close() { + // Nothing to close for RobotGo +} + +// GetScreenDimensions returns the screen width and height +func (c *MacOSClient) GetScreenDimensions() (int, int) { + return c.screenWidth, c.screenHeight +} + +// CaptureScreen captures a screenshot and returns it as an image.Image +func (c *MacOSClient) CaptureScreen(x, y, width, height int) (image.Image, error) { + if width == 0 || height == 0 { + width = c.screenWidth + height = c.screenHeight + } + + if x < 0 || y < 0 || x+width > c.screenWidth || y+height > c.screenHeight { + return nil, fmt.Errorf("invalid region: (%d,%d,%d,%d) exceeds screen bounds (%d,%d)", + x, y, width, height, c.screenWidth, c.screenHeight) + } + + bitmap := robotgo.CaptureScreen(x, y, width, height) + if bitmap == nil { + return nil, fmt.Errorf("failed to capture screen") + } + + img := robotgo.ToImage(bitmap) + if img == nil { + return nil, fmt.Errorf("failed to convert bitmap to image") + } + + return img, nil +} + +// CaptureScreenBytes captures a screenshot and returns it as PNG bytes +func (c *MacOSClient) CaptureScreenBytes(x, y, width, height int) ([]byte, error) { + img, err := c.CaptureScreen(x, y, width, height) + if err != nil { + return nil, err + } + + var buf bytes.Buffer + if err := png.Encode(&buf, img); err != nil { + return nil, fmt.Errorf("failed to encode image to PNG: %w", err) + } + + return buf.Bytes(), nil +} + +// GetCursorPosition returns the current cursor position +func (c *MacOSClient) GetCursorPosition() (int, int, error) { + x, y := robotgo.Location() + return x, y, nil +} + +// MoveMouse moves the cursor directly to the specified coordinates (uses robotgo.Move, not MoveSmooth) +func (c *MacOSClient) MoveMouse(x, y int) error { + if x < 0 || y < 0 || x > c.screenWidth || y > c.screenHeight { + return fmt.Errorf("invalid coordinates: (%d,%d) exceeds screen bounds (%d,%d)", + x, y, c.screenWidth, c.screenHeight) + } + + robotgo.Move(x, y) + return nil +} + +// ClickMouse 
clicks the specified mouse button +func (c *MacOSClient) ClickMouse(button string, clicks int) error { + robotButton := button + if button == "middle" { + robotButton = "center" + } + + if robotButton != "left" && robotButton != "right" && robotButton != "center" { + return fmt.Errorf("invalid button: %s (must be left, right, or middle)", button) + } + + if clicks < 1 || clicks > 3 { + return fmt.Errorf("invalid click count: %d (must be 1-3)", clicks) + } + + for i := range clicks { + if i > 0 { + time.Sleep(100 * time.Millisecond) + } + robotgo.Click(robotButton, false) + } + + return nil +} + +// ScrollMouse scrolls the mouse wheel +func (c *MacOSClient) ScrollMouse(clicks int, direction string) error { + if clicks == 0 { + return nil + } + + scrollAmount := clicks * 100 + absAmount := scrollAmount + if scrollAmount < 0 { + absAmount = -scrollAmount + } + + var scrollDir string + if direction == "horizontal" { + scrollDir = "right" + if scrollAmount < 0 { + scrollDir = "left" + } + } else { + scrollDir = "down" + if scrollAmount < 0 { + scrollDir = "up" + } + } + + robotgo.ScrollDir(absAmount, scrollDir) + return nil +} + +// TypeText types the specified text with delay between characters +func (c *MacOSClient) TypeText(text string, delayMs int) error { + if text == "" { + return fmt.Errorf("text cannot be empty") + } + + if delayMs > 0 { + for _, char := range text { + robotgo.Type(string(char)) + time.Sleep(time.Duration(delayMs) * time.Millisecond) + } + } else { + robotgo.Type(text) + } + + return nil +} + +// SendKeyCombo sends a key combination (e.g., "ctrl+c", "cmd+shift+t") +func (c *MacOSClient) SendKeyCombo(combo string) error { + if combo == "" { + return fmt.Errorf("key combo cannot be empty") + } + + parts := strings.Split(combo, "+") + if len(parts) == 0 { + return fmt.Errorf("invalid key combo: %s", combo) + } + + key := strings.ToLower(strings.TrimSpace(parts[len(parts)-1])) + var modifiers []any + + for i := 0; i < len(parts)-1; i++ { + mod := 
strings.ToLower(strings.TrimSpace(parts[i])) + if mappedMod, ok := modifierMap[mod]; ok { + modifiers = append(modifiers, mappedMod) + } else { + return fmt.Errorf("unknown modifier: %s", mod) + } + } + + if mappedKey, ok := specialKeyMap[key]; ok { + key = mappedKey + } + + if err := robotgo.KeyTap(key, modifiers...); err != nil { + return fmt.Errorf("failed to send key combo: %w", err) + } + + return nil +} + +// GetFrontmostApp returns the bundle identifier of the currently focused application +func (c *MacOSClient) GetFrontmostApp() string { + cAppID := C.getFrontmostApp() + return C.GoString(cAppID) +} + +// ActivateApp brings an application to the foreground by bundle identifier +func (c *MacOSClient) ActivateApp(bundleIdentifier string) error { + cBundleID := C.CString(bundleIdentifier) + defer C.free(unsafe.Pointer(cBundleID)) + + success := C.activateApp(cBundleID) + if !success { + return fmt.Errorf("failed to activate app: %s", bundleIdentifier) + } + + return nil +} + +// GetTerminalApp returns the bundle identifier of the running terminal application +func (c *MacOSClient) GetTerminalApp() string { + cTerminalID := C.getTerminalApp() + return C.GoString(cTerminalID) +} + +// SwitchToTerminal switches focus to the terminal application +func (c *MacOSClient) SwitchToTerminal() error { + terminalID := c.GetTerminalApp() + if terminalID == "" { + return fmt.Errorf("no terminal application found") + } + + return c.ActivateApp(terminalID) +} diff --git a/internal/display/macos/controller_darwin.go b/internal/display/macos/controller_darwin.go index 1391c0fd..276f3af6 100644 --- a/internal/display/macos/controller_darwin.go +++ b/internal/display/macos/controller_darwin.go @@ -2,74 +2,110 @@ package macos +/* +#cgo CFLAGS: -x objective-c +#cgo LDFLAGS: -framework ApplicationServices +#include <ApplicationServices/ApplicationServices.h> + +bool checkAccessibilityPermissions() { + return AXIsProcessTrusted(); +} +*/ +import "C" + import ( "context" "fmt" "image" + "os" "runtime" display 
"github.com/inference-gateway/cli/internal/display" + logger "github.com/inference-gateway/cli/internal/logger" ) -// Controller implements display.DisplayController for macOS -// This is a placeholder for future CGO implementation using: -// - CGDisplayCreateImage for screenshots -// - CGEventPost for mouse/keyboard control -// - Accessibility API for permissions -type Controller struct{} +// Controller implements display.DisplayController for macOS using RobotGo +type Controller struct { + client *MacOSClient +} var _ display.DisplayController = (*Controller)(nil) +var _ display.FocusManager = (*Controller)(nil) func (c *Controller) CaptureScreenBytes(ctx context.Context, region *display.Region) ([]byte, error) { - // TODO: Implement using CGDisplayCreateImage + CGO - // Sample code structure: - // /* - // #cgo LDFLAGS: -framework CoreGraphics -framework CoreFoundation - // #include - // CGImageRef CGDisplayCreateImage(CGDirectDisplayID displayID); - // */ - // import "C" - return nil, fmt.Errorf("macOS screenshot not yet implemented (requires CGO)") + if region == nil { + return c.client.CaptureScreenBytes(0, 0, 0, 0) + } + return c.client.CaptureScreenBytes(region.X, region.Y, region.Width, region.Height) } func (c *Controller) CaptureScreen(ctx context.Context, region *display.Region) (image.Image, error) { - return nil, fmt.Errorf("macOS screenshot not yet implemented (requires CGO)") + if region == nil { + return c.client.CaptureScreen(0, 0, 0, 0) + } + return c.client.CaptureScreen(region.X, region.Y, region.Width, region.Height) } func (c *Controller) GetScreenDimensions(ctx context.Context) (width, height int, err error) { - // TODO: Implement using CGDisplayBounds - return 0, 0, fmt.Errorf("macOS screen dimensions not yet implemented (requires CGO)") + w, h := c.client.GetScreenDimensions() + return w, h, nil } func (c *Controller) GetCursorPosition(ctx context.Context) (x, y int, err error) { - // TODO: Implement using CGEventGetLocation - return 0, 0, 
fmt.Errorf("macOS cursor position not yet implemented (requires CGO)") + return c.client.GetCursorPosition() } func (c *Controller) MoveMouse(ctx context.Context, x, y int) error { - // TODO: Implement using CGEventCreateMouseEvent + CGEventPost - return fmt.Errorf("macOS mouse move not yet implemented (requires CGO)") + return c.client.MoveMouse(x, y) } func (c *Controller) ClickMouse(ctx context.Context, button display.MouseButton, clicks int) error { - // TODO: Implement using CGEventCreateMouseEvent + CGEventPost - return fmt.Errorf("macOS mouse click not yet implemented (requires CGO)") + return c.client.ClickMouse(button.String(), clicks) +} + +func (c *Controller) ScrollMouse(ctx context.Context, clicks int, direction string) error { + return c.client.ScrollMouse(clicks, direction) } func (c *Controller) TypeText(ctx context.Context, text string, delayMs int) error { - // TODO: Implement using CGEventCreateKeyboardEvent + CGEventPost - return fmt.Errorf("macOS keyboard type not yet implemented (requires CGO)") + return c.client.TypeText(text, delayMs) } func (c *Controller) SendKeyCombo(ctx context.Context, combo string) error { - // TODO: Implement using CGEventCreateKeyboardEvent with modifiers - return fmt.Errorf("macOS key combo not yet implemented (requires CGO)") + return c.client.SendKeyCombo(combo) } func (c *Controller) Close() error { + c.client.Close() return nil } +// FocusManager implementation for macOS + +func (c *Controller) GetFrontmostApp(ctx context.Context) (string, error) { + appID := c.client.GetFrontmostApp() + if appID == "" { + return "", fmt.Errorf("no frontmost application found") + } + return appID, nil +} + +func (c *Controller) ActivateApp(ctx context.Context, appIdentifier string) error { + return c.client.ActivateApp(appIdentifier) +} + +func (c *Controller) GetTerminalApp(ctx context.Context) (string, error) { + terminalID := c.client.GetTerminalApp() + if terminalID == "" { + return "", fmt.Errorf("no terminal application 
found") + } + return terminalID, nil +} + +func (c *Controller) SwitchToTerminal(ctx context.Context) error { + return c.client.SwitchToTerminal() +} + // Provider implements the display.Provider interface for macOS type Provider struct{} @@ -79,11 +115,34 @@ func NewProvider() *Provider { return &Provider{} } -func (p *Provider) GetController(display string) (display.DisplayController, error) { - // TODO: Check Accessibility permissions - // Sample code: - // AXIsProcessTrustedWithOptions() - return nil, fmt.Errorf("macOS provider not yet implemented (requires CGO)") +func (p *Provider) GetController() (display.DisplayController, error) { + if os.Getenv("SSH_CONNECTION") != "" { + return nil, fmt.Errorf("macOS display not available in SSH session") + } + + if !hasAccessibilityPermissions() { + return nil, fmt.Errorf("accessibility permissions required. Grant access in System Settings > Privacy & Security > Accessibility (or System Preferences > Security & Privacy > Privacy > Accessibility on older macOS)") + } + + client, err := NewMacOSClient() + if err != nil { + return nil, fmt.Errorf("failed to create macOS client: %w", err) + } + + return &Controller{client: client}, nil +} + +// hasAccessibilityPermissions checks if the app has accessibility permissions +// Uses native macOS AXIsProcessTrusted() API for reliable detection +func hasAccessibilityPermissions() bool { + trusted := C.checkAccessibilityPermissions() + hasPerm := bool(trusted) + + if !hasPerm { + logger.Debug("Accessibility permissions not granted") + } + + return hasPerm } func (p *Provider) GetDisplayInfo() display.DisplayInfo { @@ -105,9 +164,6 @@ func (p *Provider) IsAvailable() bool { // Register the macOS provider in the global registry (darwin only) func init() { - // TODO: Uncomment when implementation is ready - // display.Register(NewProvider()) - - // For now, don't register to avoid false positives - // The stub implementation will prevent compilation errors + 
display.Register(NewProvider()) + logger.Debug("Registered macOS display provider") } diff --git a/internal/display/macos/controller_stub.go b/internal/display/macos/controller_stub.go index 0c262bb7..7b9baa92 100644 --- a/internal/display/macos/controller_stub.go +++ b/internal/display/macos/controller_stub.go @@ -39,6 +39,10 @@ func (c *Controller) ClickMouse(ctx context.Context, button display.MouseButton, return fmt.Errorf("macOS platform not available on this system") } +func (c *Controller) ScrollMouse(ctx context.Context, clicks int, direction string) error { + return fmt.Errorf("macOS platform not available on this system") +} + func (c *Controller) TypeText(ctx context.Context, text string, delayMs int) error { return fmt.Errorf("macOS platform not available on this system") } @@ -60,7 +64,7 @@ func NewProvider() *Provider { return &Provider{} } -func (p *Provider) GetController(display string) (display.DisplayController, error) { +func (p *Provider) GetController() (display.DisplayController, error) { return nil, fmt.Errorf("macOS platform not available on this system") } diff --git a/internal/display/wayland/client.go b/internal/display/wayland/client.go index 502b4ab1..1dd5654d 100644 --- a/internal/display/wayland/client.go +++ b/internal/display/wayland/client.go @@ -7,6 +7,8 @@ import ( "strconv" "strings" "time" + + robotgo "github.com/go-vgo/robotgo" ) // WaylandClient provides Wayland screen control operations using command-line tools @@ -124,6 +126,16 @@ func (c *WaylandClient) ClickMouse(button string, clicks int) error { return nil } +// ScrollMouse scrolls the mouse wheel +func (c *WaylandClient) ScrollMouse(clicks int, direction string) error { + if direction == "horizontal" { + robotgo.ScrollDir(clicks, "right") + } else { + robotgo.Scroll(0, clicks) + } + return nil +} + // TypeText types the given text with a configurable delay between keystrokes (in milliseconds) func (c *WaylandClient) TypeText(text string, delayMs int) error { if _, err 
:= exec.LookPath("wtype"); err == nil { diff --git a/internal/display/wayland/controller.go b/internal/display/wayland/controller.go index fe61224c..74533775 100644 --- a/internal/display/wayland/controller.go +++ b/internal/display/wayland/controller.go @@ -28,7 +28,6 @@ func (c *Controller) CaptureScreenBytes(ctx context.Context, region *display.Reg // CaptureScreen captures a screenshot and returns an image.Image func (c *Controller) CaptureScreen(ctx context.Context, region *display.Region) (image.Image, error) { - // WaylandClient only returns bytes, so we need to decode them var imgBytes []byte var err error @@ -42,7 +41,6 @@ func (c *Controller) CaptureScreen(ctx context.Context, region *display.Region) return nil, err } - // Decode PNG bytes to image.Image img, err := png.Decode(bytes.NewReader(imgBytes)) if err != nil { return nil, fmt.Errorf("failed to decode screenshot: %w", err) @@ -74,6 +72,11 @@ func (c *Controller) ClickMouse(ctx context.Context, button display.MouseButton, return c.client.ClickMouse(button.String(), clicks) } +// ScrollMouse scrolls the mouse wheel +func (c *Controller) ScrollMouse(ctx context.Context, clicks int, direction string) error { + return c.client.ScrollMouse(clicks, direction) +} + // TypeText types the given text with the specified delay between keystrokes func (c *Controller) TypeText(ctx context.Context, text string, delayMs int) error { return c.client.TypeText(text, delayMs) @@ -100,9 +103,11 @@ func NewProvider() *Provider { return &Provider{} } -// GetController creates a new DisplayController for the specified display -func (p *Provider) GetController(display string) (display.DisplayController, error) { - client, err := NewWaylandClient(display) +// GetController creates a new DisplayController (auto-detects display from $WAYLAND_DISPLAY env var) +func (p *Provider) GetController() (display.DisplayController, error) { + displayName := os.Getenv("WAYLAND_DISPLAY") + + client, err := NewWaylandClient(displayName) if 
err != nil { return nil, err } @@ -123,8 +128,6 @@ func (p *Provider) GetDisplayInfo() display.DisplayInfo { // IsAvailable returns true if Wayland is available on the current system func (p *Provider) IsAvailable() bool { - // Wayland is available if the WAYLAND_DISPLAY environment variable is set - // Wayland takes priority over X11 return os.Getenv("WAYLAND_DISPLAY") != "" } diff --git a/internal/display/x11/client.go b/internal/display/x11/client.go index e98b5462..88f5449d 100644 --- a/internal/display/x11/client.go +++ b/internal/display/x11/client.go @@ -47,9 +47,6 @@ var ( // NewX11Client creates a new X11 client connection func NewX11Client(display string) (*X11Client, error) { - if display == "" { - display = ":0" - } oldStderr := os.Stderr devNull, devErr := os.OpenFile(os.DevNull, os.O_WRONLY, 0) @@ -205,6 +202,53 @@ func (c *X11Client) ClickMouse(button string, clicks int) error { return nil } +// ScrollMouse scrolls the mouse wheel +// For X11: button 4 = scroll up, button 5 = scroll down +// +// button 6 = scroll left, button 7 = scroll right +func (c *X11Client) ScrollMouse(clicks int, direction string) error { + root := c.screen.Root + + var buttonCode byte + absClicks := clicks + if clicks < 0 { + absClicks = -clicks + } + + if direction == "horizontal" { + buttonCode = 7 + if clicks < 0 { + buttonCode = 6 + } + } else { + buttonCode = 5 + if clicks < 0 { + buttonCode = 4 + } + } + + absClicks = absClicks * 100 + + for i := 0; i < absClicks; i++ { + cookie := xtest.FakeInputChecked(c.conn, xproto.ButtonPress, buttonCode, 0, root, 0, 0, 0) + if err := cookie.Check(); err != nil { + return fmt.Errorf("failed to send scroll press: %w", err) + } + + cookie = xtest.FakeInputChecked(c.conn, xproto.ButtonRelease, buttonCode, 0, root, 0, 0, 0) + if err := cookie.Check(); err != nil { + return fmt.Errorf("failed to send scroll release: %w", err) + } + + if i < absClicks-1 { + time.Sleep(50 * time.Millisecond) + } + } + + c.conn.Sync() + return nil +} + // 
charToKeyInfo maps a character to its X11 key string and shift requirement type charToKeyInfo struct { keyStr string diff --git a/internal/display/x11/controller.go b/internal/display/x11/controller.go index 187a29a3..3593b344 100644 --- a/internal/display/x11/controller.go +++ b/internal/display/x11/controller.go @@ -52,6 +52,11 @@ func (c *Controller) ClickMouse(ctx context.Context, button display.MouseButton, return c.client.ClickMouse(button.String(), clicks) } +// ScrollMouse scrolls the mouse wheel +func (c *Controller) ScrollMouse(ctx context.Context, clicks int, direction string) error { + return c.client.ScrollMouse(clicks, direction) +} + // TypeText types the given text with the specified delay between keystrokes func (c *Controller) TypeText(ctx context.Context, text string, delayMs int) error { return c.client.TypeText(text, delayMs) @@ -78,9 +83,15 @@ func NewProvider() *Provider { return &Provider{} } -// GetController creates a new DisplayController for the specified display -func (p *Provider) GetController(display string) (display.DisplayController, error) { - client, err := NewX11Client(display) +// GetController creates a new DisplayController (auto-detects display from $DISPLAY env var) +func (p *Provider) GetController() (display.DisplayController, error) { + // Detect display from environment + displayName := os.Getenv("DISPLAY") + if displayName == "" { + displayName = ":0" // Fallback to default + } + + client, err := NewX11Client(displayName) if err != nil { return nil, err } @@ -101,8 +112,6 @@ func (p *Provider) GetDisplayInfo() display.DisplayInfo { // IsAvailable returns true if X11 is available on the current system func (p *Provider) IsAvailable() bool { - // X11 is available if the DISPLAY environment variable is set - // and WAYLAND_DISPLAY is not set (Wayland takes priority) return os.Getenv("DISPLAY") != "" && os.Getenv("WAYLAND_DISPLAY") == "" } diff --git a/internal/domain/config_service.go b/internal/domain/config_service.go 
index 1c0f4d76..4904bb16 100644 --- a/internal/domain/config_service.go +++ b/internal/domain/config_service.go @@ -22,4 +22,7 @@ type ConfigService interface { // Sandbox configuration GetSandboxDirectories() []string GetProtectedPaths() []string + + // Full configuration access + GetConfig() *config.Config } diff --git a/internal/services/agent.go b/internal/services/agent.go index b21ce844..63b47eff 100644 --- a/internal/services/agent.go +++ b/internal/services/agent.go @@ -9,6 +9,7 @@ import ( "time" constants "github.com/inference-gateway/cli/internal/constants" + display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" sdk "github.com/inference-gateway/sdk" @@ -1275,6 +1276,13 @@ func (s *AgentServiceImpl) requestToolApproval( tc sdk.ChatCompletionMessageToolCall, eventPublisher *eventPublisher, ) (bool, error) { + var savedAppID string + shouldRestoreFocus := s.shouldRestoreFocusForTool(tc.Function.Name) + + if shouldRestoreFocus { + savedAppID = s.saveFocusAndSwitchToTerminal(ctx) + } + responseChan := make(chan domain.ApprovalAction, 1) eventPublisher.chatEvents <- domain.ToolApprovalRequestedEvent{ @@ -1284,14 +1292,140 @@ func (s *AgentServiceImpl) requestToolApproval( ResponseChan: responseChan, } + var approved bool + var err error + select { case response := <-responseChan: - return response == domain.ApprovalApprove, nil + approved = response == domain.ApprovalApprove case <-ctx.Done(): - return false, fmt.Errorf("approval request cancelled: %w", ctx.Err()) + err = fmt.Errorf("approval request cancelled: %w", ctx.Err()) case <-time.After(5 * time.Minute): - return false, fmt.Errorf("approval request timed out") + err = fmt.Errorf("approval request timed out") + } + + if shouldRestoreFocus && savedAppID != "" { + s.restoreFocus(ctx, savedAppID) + time.Sleep(500 * time.Millisecond) + } + + return approved, err +} + +// 
shouldRestoreFocusForTool determines if focus should be restored for a given tool +func (s *AgentServiceImpl) shouldRestoreFocusForTool(toolName string) bool { + cfg := s.config.GetConfig() + if cfg == nil { + return false + } + + if !cfg.ComputerUse.Enabled || !cfg.ComputerUse.RestoreFocusOnApproval { + return false + } + + // Only restore focus for computer use tools + computerUseTools := map[string]bool{ + "MouseMove": true, + "MouseClick": true, + "KeyboardType": true, + } + + return computerUseTools[toolName] +} + +// saveFocusAndSwitchToTerminal saves the currently focused app and switches to approval target +// In terminal mode: switches to terminal +// In web mode (web enabled without SSH): no focus switch is performed +// Returns the saved app ID for later restoration +func (s *AgentServiceImpl) saveFocusAndSwitchToTerminal(ctx context.Context) string { + displayProvider, err := display.DetectDisplay() + if err != nil { + logger.Debug("Failed to detect display for focus management", "error", err) + return "" + } + + controller, err := displayProvider.GetController() + if err != nil { + logger.Debug("Failed to get display controller for focus management", "error", err) + return "" + } + defer func() { + if err := controller.Close(); err != nil { + logger.Debug("Failed to close display controller", "error", err) + } + }() + + // Check if controller supports focus management + focusManager, ok := controller.(display.FocusManager) + if !ok { + logger.Debug("Display controller does not support focus management") + return "" } + + // Save currently focused app + savedAppID, err := focusManager.GetFrontmostApp(ctx) + if err != nil { + logger.Debug("Failed to get frontmost app", "error", err) + return "" + } + + cfg := s.config.GetConfig() + if cfg == nil { + return savedAppID + } + + if cfg.Web.Enabled && !cfg.Web.SSH.Enabled { + logger.Debug("Web mode - no focus switch needed for approval", "saved_app", savedAppID) + return savedAppID + } + + // Terminal mode: switch to terminal + 
if err := focusManager.SwitchToTerminal(ctx); err != nil { + logger.Warn("Failed to switch to terminal for approval", "error", err) + return savedAppID + } + + logger.Debug("Switched to terminal for approval", "saved_app", savedAppID) + return savedAppID +} + +// restoreFocus restores focus to the previously focused application +func (s *AgentServiceImpl) restoreFocus(ctx context.Context, appID string) { + if appID == "" { + return + } + + displayProvider, err := display.DetectDisplay() + if err != nil { + logger.Debug("Failed to detect display for focus restoration", "error", err) + return + } + + controller, err := displayProvider.GetController() + if err != nil { + logger.Debug("Failed to get display controller for focus restoration", "error", err) + return + } + defer func() { + if err := controller.Close(); err != nil { + logger.Debug("Failed to close display controller", "error", err) + } + }() + + focusManager, ok := controller.(display.FocusManager) + if !ok { + return + } + + // Small delay to allow approval UI to process before switching away + time.Sleep(200 * time.Millisecond) + + if err := focusManager.ActivateApp(ctx, appID); err != nil { + logger.Debug("Failed to restore focus", "app", appID, "error", err) + return + } + + logger.Debug("Restored focus after approval", "app", appID) } // isBashCommandWhitelisted checks if a Bash tool command is whitelisted diff --git a/internal/services/agent_utils.go b/internal/services/agent_utils.go index 9a68bfc7..20dde8d0 100644 --- a/internal/services/agent_utils.go +++ b/internal/services/agent_utils.go @@ -3,6 +3,7 @@ package services import ( "encoding/json" "fmt" + "runtime" "strings" "time" @@ -85,8 +86,10 @@ func (s *AgentServiceImpl) addSystemPrompt(messages []sdk.Message) []sdk.Message a2aAgentInfo := s.buildA2AAgentInfo() - systemPromptWithInfo := fmt.Sprintf("%s\n\n%s%s\n\nCurrent date and time: %s", - baseSystemPrompt, sandboxInfo, a2aAgentInfo, currentTime) + osInfo := s.buildOSInfo() + + 
systemPromptWithInfo := fmt.Sprintf("%s\n\n%s%s%s\n\nCurrent date and time: %s", + baseSystemPrompt, sandboxInfo, a2aAgentInfo, osInfo, currentTime) systemMessages = append(systemMessages, sdk.Message{ Role: sdk.System, @@ -172,6 +175,27 @@ func (s *AgentServiceImpl) buildSandboxInfo() string { return sandboxInfo.String() } +// buildOSInfo creates dynamic OS information for the system prompt +func (s *AgentServiceImpl) buildOSInfo() string { + osInfo := fmt.Sprintf("\n\nOPERATING SYSTEM: %s", runtime.GOOS) + + switch runtime.GOOS { + case "darwin": + osInfo += "\n- Use 'cmd' modifier for keyboard shortcuts (e.g., 'cmd+c' for copy)" + osInfo += "\n- Use 'open -a AppName' to launch applications (e.g., 'open -a Firefox')" + osInfo += "\n- Use 'open file.txt' to open files with default app" + osInfo += "\n- IMPORTANT: When opening URLs in browsers, ALWAYS open in a NEW WINDOW, not a new tab" + osInfo += "\n Example: 'open -n -a \"Google Chrome\" --args --new-window https://example.com'" + osInfo += "\n Or for Safari: 'open -n -a Safari https://example.com'" + osInfo += "\n This allows proper focus management between windows" + case "linux": + osInfo += "\n- Use 'ctrl' modifier for keyboard shortcuts (e.g., 'ctrl+c' for copy)" + osInfo += "\n- Use command name directly or 'xdg-open' to launch applications" + } + + return osInfo +} + // validateRequest validates the agent request func (s *AgentServiceImpl) validateRequest(req *domain.AgentRequest) error { if req == nil { diff --git a/internal/services/config_service.go b/internal/services/config_service.go index c3c3c4c5..307b9e37 100644 --- a/internal/services/config_service.go +++ b/internal/services/config_service.go @@ -3,9 +3,10 @@ package services import ( "fmt" - "github.com/inference-gateway/cli/config" - "github.com/inference-gateway/cli/internal/utils" - "github.com/spf13/viper" + viper "github.com/spf13/viper" + + config "github.com/inference-gateway/cli/config" + utils 
"github.com/inference-gateway/cli/internal/utils" ) // ConfigService handles configuration management and reloading @@ -60,3 +61,41 @@ func (cs *ConfigService) SetValue(key, value string) error { return nil } + +// Domain ConfigService interface implementation (delegates to underlying config) + +func (cs *ConfigService) IsApprovalRequired(toolName string) bool { + return cs.config.IsApprovalRequired(toolName) +} + +func (cs *ConfigService) IsBashCommandWhitelisted(command string) bool { + return cs.config.IsBashCommandWhitelisted(command) +} + +func (cs *ConfigService) GetOutputDirectory() string { + return cs.config.GetOutputDirectory() +} + +func (cs *ConfigService) GetGatewayURL() string { + return cs.config.Gateway.URL +} + +func (cs *ConfigService) GetAPIKey() string { + return cs.config.Gateway.APIKey +} + +func (cs *ConfigService) GetTimeout() int { + return cs.config.Gateway.Timeout +} + +func (cs *ConfigService) GetAgentConfig() *config.AgentConfig { + return cs.config.GetAgentConfig() +} + +func (cs *ConfigService) GetSandboxDirectories() []string { + return cs.config.GetSandboxDirectories() +} + +func (cs *ConfigService) GetProtectedPaths() []string { + return cs.config.GetProtectedPaths() +} diff --git a/internal/services/screenshot_server.go b/internal/services/screenshot_server.go index c47ab365..23e67a0f 100644 --- a/internal/services/screenshot_server.go +++ b/internal/services/screenshot_server.go @@ -16,6 +16,7 @@ import ( domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" + _ "github.com/inference-gateway/cli/internal/display/macos" _ "github.com/inference-gateway/cli/internal/display/wayland" _ "github.com/inference-gateway/cli/internal/display/x11" ) @@ -186,7 +187,7 @@ func (s *ScreenshotServer) captureScreenshot() error { return fmt.Errorf("no compatible display platform detected: %w", err) } - controller, err := displayProvider.GetController(s.cfg.ComputerUse.Display) + 
controller, err := displayProvider.GetController() if err != nil { return fmt.Errorf("failed to get platform controller: %w", err) } diff --git a/internal/services/tools.go b/internal/services/tools.go index f8317685..af23c158 100644 --- a/internal/services/tools.go +++ b/internal/services/tools.go @@ -19,15 +19,6 @@ type LLMToolService struct { config *config.Config } -// NewLLMToolService creates a new LLM tool service with a new registry -func NewLLMToolService(cfg *config.Config) *LLMToolService { - return &LLMToolService{ - registry: tools.NewRegistry(cfg, nil, nil, nil), - enabled: cfg.Tools.Enabled, - config: cfg, - } -} - // NewLLMToolServiceWithRegistry creates a new LLM tool service with an existing registry func NewLLMToolServiceWithRegistry(cfg *config.Config, registry *tools.Registry) *LLMToolService { return &LLMToolService{ diff --git a/internal/services/tools/activate_app.go b/internal/services/tools/activate_app.go new file mode 100644 index 00000000..65b91d47 --- /dev/null +++ b/internal/services/tools/activate_app.go @@ -0,0 +1,175 @@ +package tools + +import ( + "context" + "fmt" + "time" + + display "github.com/inference-gateway/cli/internal/display" + domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" + sdk "github.com/inference-gateway/sdk" +) + +// ActivateAppTool switches focus to a specific application +type ActivateAppTool struct { + config domain.ConfigService +} + +// NewActivateAppTool creates a new ActivateApp tool +func NewActivateAppTool(config domain.ConfigService) *ActivateAppTool { + return &ActivateAppTool{ + config: config, + } +} + +// Definition returns the tool definition for ActivateApp +func (t *ActivateAppTool) Definition() sdk.ChatCompletionTool { + description := "Activates (brings to foreground/focus) a specific application by its bundle identifier. 
Use GetFocusedApp first to check the current state, then use this tool to switch to the target app before performing computer use actions. After activation, wait briefly before sending keyboard/mouse commands." + return sdk.ChatCompletionTool{ + Type: sdk.Function, + Function: sdk.FunctionObject{ + Name: "ActivateApp", + Description: &description, + Parameters: &sdk.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "bundle_id": map[string]any{ + "type": "string", + "description": "The bundle identifier of the application to activate (e.g., 'org.mozilla.firefox', 'com.google.Chrome', 'com.apple.Terminal'). Common apps: Firefox='org.mozilla.firefox', Chrome='com.google.Chrome', Safari='com.apple.Safari', Terminal='com.apple.Terminal', VSCode='com.microsoft.VSCode'", + }, + }, + "required": []string{"bundle_id"}, + }, + }, + } +} + +// Validate validates ActivateApp arguments +func (t *ActivateAppTool) Validate(args map[string]any) error { + bundleID, ok := args["bundle_id"].(string) + if !ok || bundleID == "" { + return fmt.Errorf("bundle_id is required and must be a non-empty string") + } + return nil +} + +// Execute executes the ActivateApp tool +func (t *ActivateAppTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { + bundleID, ok := args["bundle_id"].(string) + if !ok { + return nil, fmt.Errorf("bundle_id must be a string") + } + + displayProvider, err := display.DetectDisplay() + if err != nil { + return nil, fmt.Errorf("failed to detect display: %w", err) + } + + controller, err := displayProvider.GetController() + if err != nil { + return nil, fmt.Errorf("failed to get display controller: %w", err) + } + defer func() { + if err := controller.Close(); err != nil { + logger.Debug("Failed to close display controller", "error", err) + } + }() + + focusManager, ok := controller.(display.FocusManager) + if !ok { + return nil, fmt.Errorf("display controller does not support focus management") + } + + 
// Activate the application + if err := focusManager.ActivateApp(ctx, bundleID); err != nil { + return nil, fmt.Errorf("failed to activate app '%s': %w (app may not be running)", bundleID, err) + } + + // Give the OS time to fully switch focus + time.Sleep(300 * time.Millisecond) + + // Verify activation + currentApp, err := focusManager.GetFrontmostApp(ctx) + if err == nil && currentApp == bundleID { + result := fmt.Sprintf("Successfully activated %s (bundle ID: %s). The application is now in focus.", parseAppName(bundleID), bundleID) + return &domain.ToolExecutionResult{ + ToolName: "ActivateApp", + Success: true, + Data: map[string]any{ + "bundle_id": bundleID, + "app_name": parseAppName(bundleID), + "message": result, + }, + }, nil + } + + // Activation succeeded but couldn't verify + result := fmt.Sprintf("Attempted to activate %s (bundle ID: %s)", parseAppName(bundleID), bundleID) + return &domain.ToolExecutionResult{ + ToolName: "ActivateApp", + Success: true, + Data: map[string]any{ + "bundle_id": bundleID, + "app_name": parseAppName(bundleID), + "message": result, + }, + }, nil +} + +// IsEnabled returns whether the tool is enabled +func (t *ActivateAppTool) IsEnabled() bool { + return t.config.GetConfig().ComputerUse.Enabled +} + +// FormatPreview formats the result for display preview +func (t *ActivateAppTool) FormatPreview(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return "Failed to activate app" + } + data, ok := result.Data.(map[string]any) + if !ok { + return "Activated app" + } + if appName, ok := data["app_name"].(string); ok { + return fmt.Sprintf("Activated: %s", appName) + } + return "Activated app" +} + +// FormatForLLM formats the result for LLM consumption +func (t *ActivateAppTool) FormatForLLM(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return fmt.Sprintf("Error: %s", result.Error) + } + data, ok := result.Data.(map[string]any) + if !ok { + return 
"Successfully activated application" + } + if message, ok := data["message"].(string); ok { + return message + } + return "Successfully activated application" +} + +// ShouldCollapseArg determines if an argument should be collapsed in display +func (t *ActivateAppTool) ShouldCollapseArg(key string) bool { + return false +} + +// ShouldAlwaysExpand determines if tool results should always be expanded in UI +func (t *ActivateAppTool) ShouldAlwaysExpand() bool { + return false +} + +// FormatResult formats the result based on the requested format type +func (t *ActivateAppTool) FormatResult(result *domain.ToolExecutionResult, formatType domain.FormatterType) string { + switch formatType { + case domain.FormatterLLM: + return t.FormatForLLM(result) + case domain.FormatterShort: + return t.FormatPreview(result) + default: + return t.FormatForLLM(result) + } +} diff --git a/internal/services/tools/get_focused_app.go b/internal/services/tools/get_focused_app.go new file mode 100644 index 00000000..bbe152e1 --- /dev/null +++ b/internal/services/tools/get_focused_app.go @@ -0,0 +1,188 @@ +package tools + +import ( + "context" + "fmt" + + display "github.com/inference-gateway/cli/internal/display" + domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" + sdk "github.com/inference-gateway/sdk" +) + +// GetFocusedAppTool gets the currently focused application +type GetFocusedAppTool struct { + config domain.ConfigService +} + +// NewGetFocusedAppTool creates a new GetFocusedApp tool +func NewGetFocusedAppTool(config domain.ConfigService) *GetFocusedAppTool { + return &GetFocusedAppTool{ + config: config, + } +} + +// Definition returns the tool definition for GetFocusedApp +func (t *GetFocusedAppTool) Definition() sdk.ChatCompletionTool { + description := "Gets the currently focused (frontmost) application. Returns the application name and bundle identifier. 
Use this before performing computer use actions to verify the correct application is in focus." + return sdk.ChatCompletionTool{ + Type: sdk.Function, + Function: sdk.FunctionObject{ + Name: "GetFocusedApp", + Description: &description, + Parameters: &sdk.FunctionParameters{ + "type": "object", + "properties": map[string]any{}, + }, + }, + } +} + +// Validate validates GetFocusedApp arguments +func (t *GetFocusedAppTool) Validate(args map[string]any) error { + return nil +} + +// Execute executes the GetFocusedApp tool +func (t *GetFocusedAppTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { + displayProvider, err := display.DetectDisplay() + if err != nil { + return nil, fmt.Errorf("failed to detect display: %w", err) + } + + controller, err := displayProvider.GetController() + if err != nil { + return nil, fmt.Errorf("failed to get display controller: %w", err) + } + defer func() { + if err := controller.Close(); err != nil { + logger.Debug("Failed to close display controller", "error", err) + } + }() + + focusManager, ok := controller.(display.FocusManager) + if !ok { + return nil, fmt.Errorf("display controller does not support focus management") + } + + appID, err := focusManager.GetFrontmostApp(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get focused app: %w", err) + } + + if appID == "" { + return nil, fmt.Errorf("no application is currently focused") + } + + // Parse app name from bundle ID (e.g., "org.mozilla.firefox" -> "Firefox") + appName := parseAppName(appID) + + result := fmt.Sprintf("Currently focused application:\n- Name: %s\n- Bundle ID: %s", appName, appID) + + return &domain.ToolExecutionResult{ + ToolName: "GetFocusedApp", + Success: true, + Data: map[string]any{ + "app_name": appName, + "bundle_id": appID, + "message": result, + }, + }, nil +} + +// IsEnabled returns whether the tool is enabled +func (t *GetFocusedAppTool) IsEnabled() bool { + return 
t.config.GetConfig().ComputerUse.Enabled +} + +// FormatPreview formats the result for display preview +func (t *GetFocusedAppTool) FormatPreview(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return "Failed to get focused app" + } + data, ok := result.Data.(map[string]any) + if !ok { + return "Got focused app" + } + if appName, ok := data["app_name"].(string); ok { + return fmt.Sprintf("Focused: %s", appName) + } + return "Got focused app" +} + +// FormatForLLM formats the result for LLM consumption +func (t *GetFocusedAppTool) FormatForLLM(result *domain.ToolExecutionResult) string { + if result == nil || !result.Success { + return fmt.Sprintf("Error: %s", result.Error) + } + data, ok := result.Data.(map[string]any) + if !ok { + return "Successfully retrieved focused application" + } + if message, ok := data["message"].(string); ok { + return message + } + return "Successfully retrieved focused application" +} + +// ShouldCollapseArg determines if an argument should be collapsed in display +func (t *GetFocusedAppTool) ShouldCollapseArg(key string) bool { + return false +} + +// ShouldAlwaysExpand determines if tool results should always be expanded in UI +func (t *GetFocusedAppTool) ShouldAlwaysExpand() bool { + return false +} + +// FormatResult formats the result based on the requested format type +func (t *GetFocusedAppTool) FormatResult(result *domain.ToolExecutionResult, formatType domain.FormatterType) string { + switch formatType { + case domain.FormatterLLM: + return t.FormatForLLM(result) + case domain.FormatterShort: + return t.FormatPreview(result) + default: + return t.FormatForLLM(result) + } +} + +// parseAppName extracts a human-readable app name from bundle ID +func parseAppName(bundleID string) string { + // Common mappings + appNames := map[string]string{ + "com.apple.Terminal": "Terminal", + "com.googlecode.iterm2": "iTerm2", + "com.microsoft.VSCode": "Visual Studio Code", + "org.mozilla.firefox": 
"Firefox", + "com.google.Chrome": "Google Chrome", + "com.apple.Safari": "Safari", + "com.microsoft.edgemac": "Microsoft Edge", + "com.brave.Browser": "Brave Browser", + "com.sublimetext.4": "Sublime Text", + "com.jetbrains.goland": "GoLand", + "com.jetbrains.intellij": "IntelliJ IDEA", + "org.alacritty": "Alacritty", + "net.kovidgoyal.kitty": "Kitty", + "com.apple.finder": "Finder", + "com.apple.TextEdit": "TextEdit", + "com.spotify.client": "Spotify", + "com.tinyspeck.slackmacgap": "Slack", + "us.zoom.xos": "Zoom", + "com.microsoft.teams": "Microsoft Teams", + "com.docker.docker": "Docker Desktop", + "com.postmanlabs.app": "Postman", + "com.notion.desktop": "Notion", + "com.figma.Desktop": "Figma", + "com.apple.Notes": "Notes", + "com.apple.mail": "Mail", + "com.apple.iCal": "Calendar", + } + + if name, ok := appNames[bundleID]; ok { + return name + } + + // Fallback: return bundle ID + return bundleID +} diff --git a/internal/services/tools/get_latest_screenshot.go b/internal/services/tools/get_latest_screenshot.go index ef97fe07..7bbc4d2d 100644 --- a/internal/services/tools/get_latest_screenshot.go +++ b/internal/services/tools/get_latest_screenshot.go @@ -102,12 +102,11 @@ func (t *GetLatestScreenshotTool) Execute(ctx context.Context, args map[string]a } result := domain.ScreenshotToolResult{ - Display: t.config.ComputerUse.Display, - Region: nil, - Width: screenshot.Width, - Height: screenshot.Height, - Format: screenshot.Format, - Method: screenshot.Method, + Region: nil, + Width: screenshot.Width, + Height: screenshot.Height, + Format: screenshot.Format, + Method: screenshot.Method, } return &domain.ToolExecutionResult{ diff --git a/internal/services/tools/keyboard_type.go b/internal/services/tools/keyboard_type.go index 5524e28f..0b269520 100644 --- a/internal/services/tools/keyboard_type.go +++ b/internal/services/tools/keyboard_type.go @@ -34,7 +34,7 @@ func NewKeyboardTypeTool(cfg *config.Config, rateLimiter domain.RateLimiter, dis // Definition 
returns the tool definition for the LLM func (t *KeyboardTypeTool) Definition() sdk.ChatCompletionTool { - description := "Types text or sends key combinations. Can type regular text or send special key combinations like 'ctrl+c'. Requires user approval unless in auto-accept mode. Note: Exactly one of 'text' or 'key_combo' must be provided." + description := "Types text or sends key combinations INTO GUI APPLICATIONS at the current cursor position (e.g., typing in a text editor, browser search box, or form field). DO NOT use this to run shell commands - use the Bash tool instead. To open applications on macOS, use Bash with 'open -a AppName'. Requires user approval unless in auto-accept mode. Note: Exactly one of 'text' or 'key_combo' must be provided." return sdk.ChatCompletionTool{ Type: sdk.Function, Function: sdk.FunctionObject{ @@ -45,16 +45,11 @@ func (t *KeyboardTypeTool) Definition() sdk.ChatCompletionTool { "properties": map[string]any{ "text": map[string]any{ "type": "string", - "description": "Text to type. Mutually exclusive with key_combo.", + "description": "Text to type into the active GUI application (NOT for running commands).", }, "key_combo": map[string]any{ "type": "string", - "description": "Key combination to send (e.g., 'ctrl+c', 'alt+tab', 'shift+enter'). Mutually exclusive with text.", - }, - "display": map[string]any{ - "type": "string", - "description": "Display to use (e.g., ':0'). Defaults to ':0'.", - "default": ":0", + "description": "Key combination to send (e.g., 'cmd+c' for copy, 'cmd+v' for paste, 'cmd+tab' to switch apps). 
Use platform-specific modifiers: 'cmd' on macOS, 'ctrl' on Linux/Windows.", }, }, }, @@ -99,11 +94,6 @@ func (t *KeyboardTypeTool) Execute(ctx context.Context, args map[string]any) (*d }, nil } - displayName := t.config.ComputerUse.Display - if displayArg, ok := args["display"].(string); ok && displayArg != "" { - displayName = displayArg - } - if t.displayProvider == nil { return &domain.ToolExecutionResult{ ToolName: "KeyboardType", @@ -114,7 +104,7 @@ func (t *KeyboardTypeTool) Execute(ctx context.Context, args map[string]any) (*d }, nil } - controller, err := t.displayProvider.GetController(displayName) + controller, err := t.displayProvider.GetController() if err != nil { return &domain.ToolExecutionResult{ ToolName: "KeyboardType", @@ -150,7 +140,6 @@ func (t *KeyboardTypeTool) Execute(ctx context.Context, args map[string]any) (*d result := domain.KeyboardTypeToolResult{ Text: text, KeyCombo: keyCombo, - Display: displayName, Method: t.displayProvider.GetDisplayInfo().Name, } diff --git a/internal/services/tools/mouse_click.go b/internal/services/tools/mouse_click.go index 5ccf5a05..ab5be236 100644 --- a/internal/services/tools/mouse_click.go +++ b/internal/services/tools/mouse_click.go @@ -63,11 +63,6 @@ func (t *MouseClickTool) Definition() sdk.ChatCompletionTool { "type": "integer", "description": "Optional: Y coordinate to move to before clicking", }, - "display": map[string]any{ - "type": "string", - "description": "Display to use (e.g., ':0'). 
Defaults to ':0'.", - "default": ":0", - }, }, "required": []string{"button"}, }, @@ -99,11 +94,6 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom clicks = int(clicksArg) } - displayName := t.config.ComputerUse.Display - if displayArg, ok := args["display"].(string); ok && displayArg != "" { - displayName = displayArg - } - var finalX, finalY int shouldMove := false @@ -125,7 +115,7 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom }, nil } - controller, err := t.displayProvider.GetController(displayName) + controller, err := t.displayProvider.GetController() if err != nil { return &domain.ToolExecutionResult{ ToolName: "MouseClick", @@ -168,12 +158,11 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom } result := domain.MouseClickToolResult{ - Button: button, - Clicks: clicks, - X: finalX, - Y: finalY, - Display: displayName, - Method: t.displayProvider.GetDisplayInfo().Name, + Button: button, + Clicks: clicks, + X: finalX, + Y: finalY, + Method: t.displayProvider.GetDisplayInfo().Name, } return &domain.ToolExecutionResult{ diff --git a/internal/services/tools/mouse_move.go b/internal/services/tools/mouse_move.go index 2210b077..48f890fd 100644 --- a/internal/services/tools/mouse_move.go +++ b/internal/services/tools/mouse_move.go @@ -51,11 +51,6 @@ func (t *MouseMoveTool) Definition() sdk.ChatCompletionTool { "type": "integer", "description": "Y coordinate (absolute position from top edge of screen)", }, - "display": map[string]any{ - "type": "string", - "description": "Display to use (e.g., ':0'). 
Defaults to ':0'.", - "default": ":0", - }, }, "required": []string{"x", "y"}, }, @@ -90,11 +85,6 @@ func (t *MouseMoveTool) Execute(ctx context.Context, args map[string]any) (*doma }, nil } - displayName := t.config.ComputerUse.Display - if displayArg, ok := args["display"].(string); ok && displayArg != "" { - displayName = displayArg - } - if t.displayProvider == nil { return &domain.ToolExecutionResult{ ToolName: "MouseMove", @@ -105,7 +95,7 @@ func (t *MouseMoveTool) Execute(ctx context.Context, args map[string]any) (*doma }, nil } - controller, err := t.displayProvider.GetController(displayName) + controller, err := t.displayProvider.GetController() if err != nil { return &domain.ToolExecutionResult{ ToolName: "MouseMove", @@ -134,12 +124,11 @@ func (t *MouseMoveTool) Execute(ctx context.Context, args map[string]any) (*doma } result := domain.MouseMoveToolResult{ - FromX: fromX, - FromY: fromY, - ToX: int(x), - ToY: int(y), - Display: displayName, - Method: t.displayProvider.GetDisplayInfo().Name, + FromX: fromX, + FromY: fromY, + ToX: int(x), + ToY: int(y), + Method: t.displayProvider.GetDisplayInfo().Name, } return &domain.ToolExecutionResult{ diff --git a/internal/services/tools/mouse_scroll.go b/internal/services/tools/mouse_scroll.go new file mode 100644 index 00000000..7e8c3390 --- /dev/null +++ b/internal/services/tools/mouse_scroll.go @@ -0,0 +1,217 @@ +package tools + +import ( + "context" + "fmt" + "time" + + config "github.com/inference-gateway/cli/config" + display "github.com/inference-gateway/cli/internal/display" + domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" + sdk "github.com/inference-gateway/sdk" +) + +// MouseScrollTool scrolls the mouse wheel +type MouseScrollTool struct { + config *config.Config + enabled bool + formatter domain.BaseFormatter + rateLimiter domain.RateLimiter + displayProvider display.Provider +} + +// NewMouseScrollTool creates a new mouse scroll tool 
+func NewMouseScrollTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider) *MouseScrollTool { + return &MouseScrollTool{ + config: cfg, + enabled: cfg.ComputerUse.Enabled, + formatter: domain.NewBaseFormatter("MouseScroll"), + rateLimiter: rateLimiter, + displayProvider: displayProvider, + } +} + +// Definition returns the tool definition for the LLM +func (t *MouseScrollTool) Definition() sdk.ChatCompletionTool { + description := "Scrolls the mouse wheel up or down. Useful for navigating web pages, documents, and long content. Positive values scroll down, negative values scroll up." + return sdk.ChatCompletionTool{ + Type: sdk.Function, + Function: sdk.FunctionObject{ + Name: "MouseScroll", + Description: &description, + Parameters: &sdk.FunctionParameters{ + "type": "object", + "properties": map[string]any{ + "clicks": map[string]any{ + "type": "integer", + "description": "Number of scroll clicks. Positive = scroll down, negative = scroll up. Each click scrolls by a few lines. 
Example: 5 scrolls down, -3 scrolls up.", + }, + "direction": map[string]any{ + "type": "string", + "description": "Scroll direction: 'vertical' (default, up/down) or 'horizontal' (left/right)", + "enum": []string{"vertical", "horizontal"}, + }, + }, + "required": []string{"clicks"}, + }, + }, + } +} + +// Execute runs the mouse scroll tool with given arguments +func (t *MouseScrollTool) Execute(ctx context.Context, args map[string]any) (*domain.ToolExecutionResult, error) { + start := time.Now() + + if err := t.rateLimiter.CheckAndRecord("MouseScroll"); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseScroll", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: err.Error(), + }, nil + } + + clicks, clicksOk := args["clicks"].(float64) + if !clicksOk { + return &domain.ToolExecutionResult{ + ToolName: "MouseScroll", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: "clicks must be an integer", + }, nil + } + + direction := "vertical" + if dirVal, ok := args["direction"].(string); ok { + direction = dirVal + } + + controller, err := t.displayProvider.GetController() + if err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseScroll", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("failed to get display controller: %v", err), + }, nil + } + defer func() { + if err := controller.Close(); err != nil { + logger.Debug("Failed to close display controller", "error", err) + } + }() + + clicksInt := int(clicks) + + if err := controller.ScrollMouse(ctx, clicksInt, direction); err != nil { + return &domain.ToolExecutionResult{ + ToolName: "MouseScroll", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: fmt.Sprintf("failed to scroll: %v", err), + }, nil + } + + scrollDir := "down" + if clicksInt < 0 { + scrollDir = "up" + clicksInt = -clicksInt + } + if direction == "horizontal" { + if int(clicks) > 0 { + scrollDir = "right" 
+ } else { + scrollDir = "left" + } + } + + message := fmt.Sprintf("Scrolled %s by %d clicks", scrollDir, clicksInt) + logger.Info("Mouse scroll executed", "direction", direction, "clicks", clicks) + + return &domain.ToolExecutionResult{ + ToolName: "MouseScroll", + Arguments: args, + Success: true, + Duration: time.Since(start), + Data: map[string]any{ + "clicks": clicks, + "direction": direction, + "message": message, + }, + }, nil +} + +// Validate validates the tool arguments +func (t *MouseScrollTool) Validate(args map[string]any) error { + if _, ok := args["clicks"].(float64); !ok { + return fmt.Errorf("clicks must be an integer") + } + + if direction, ok := args["direction"].(string); ok { + if direction != "vertical" && direction != "horizontal" { + return fmt.Errorf("direction must be 'vertical' or 'horizontal'") + } + } + + return nil +} + +// IsEnabled returns whether the tool is enabled +func (t *MouseScrollTool) IsEnabled() bool { + return t.enabled +} + +// FormatPreview formats a short preview of tool execution +func (t *MouseScrollTool) FormatPreview(result *domain.ToolExecutionResult) string { + if !result.Success { + return "Scroll failed" + } + + data, ok := result.Data.(map[string]any) + if !ok { + return "Scrolled" + } + + return fmt.Sprintf("%s", data["message"]) +} + +// FormatForLLM formats the result for LLM consumption +func (t *MouseScrollTool) FormatForLLM(result *domain.ToolExecutionResult) string { + if !result.Success { + return fmt.Sprintf("Scroll failed: %s", result.Error) + } + + data, ok := result.Data.(map[string]any) + if !ok { + return "Scrolled successfully" + } + + return fmt.Sprintf("%s. 
Use GetLatestScreenshot to see the new content.", data["message"]) +} + +// ShouldCollapseArg determines if an argument should be collapsed in UI +func (t *MouseScrollTool) ShouldCollapseArg(argName string) bool { + return false +} + +// ShouldAlwaysExpand determines if tool results should always be expanded in UI +func (t *MouseScrollTool) ShouldAlwaysExpand() bool { + return false +} + +// FormatResult formats the result based on the requested format type +func (t *MouseScrollTool) FormatResult(result *domain.ToolExecutionResult, formatType domain.FormatterType) string { + switch formatType { + case domain.FormatterLLM: + return t.FormatForLLM(result) + case domain.FormatterShort: + return t.FormatPreview(result) + default: + return t.FormatForLLM(result) + } +} diff --git a/internal/services/tools/registry.go b/internal/services/tools/registry.go index e2d15352..e909103c 100644 --- a/internal/services/tools/registry.go +++ b/internal/services/tools/registry.go @@ -6,7 +6,6 @@ import ( "strings" "time" - config "github.com/inference-gateway/cli/config" display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" @@ -20,7 +19,7 @@ import ( // Registry manages all available tools type Registry struct { - config *config.Config + config domain.ConfigService tools map[string]domain.Tool readToolUsed bool taskTracker domain.TaskTracker @@ -30,7 +29,7 @@ type Registry struct { } // NewRegistry creates a new tool registry with self-contained tools -func NewRegistry(cfg *config.Config, imageService domain.ImageService, mcpManager domain.MCPManager, shellService domain.BackgroundShellService) *Registry { +func NewRegistry(cfg domain.ConfigService, imageService domain.ImageService, mcpManager domain.MCPManager, shellService domain.BackgroundShellService) *Registry { registry := &Registry{ config: cfg, tools: make(map[string]domain.Tool), @@ -47,62 +46,68 @@ 
func NewRegistry(cfg *config.Config, imageService domain.ImageService, mcpManage // registerTools initializes and registers all available tools func (r *Registry) registerTools() { - r.tools["Bash"] = NewBashTool(r.config, r.shellService) + cfg := r.config.GetConfig() - if r.config.Tools.Bash.BackgroundShells.Enabled && r.shellService != nil { - r.tools["BashOutput"] = NewBashOutputTool(r.config, r.shellService) - r.tools["KillShell"] = NewKillShellTool(r.config, r.shellService) - r.tools["ListShells"] = NewListShellsTool(r.config, r.shellService) + r.tools["Bash"] = NewBashTool(cfg, r.shellService) + + if cfg.Tools.Bash.BackgroundShells.Enabled && r.shellService != nil { + r.tools["BashOutput"] = NewBashOutputTool(cfg, r.shellService) + r.tools["KillShell"] = NewKillShellTool(cfg, r.shellService) + r.tools["ListShells"] = NewListShellsTool(cfg, r.shellService) } - r.tools["Read"] = NewReadTool(r.config) - r.tools["Write"] = NewWriteTool(r.config) - r.tools["Edit"] = NewEditToolWithRegistry(r.config, r) - r.tools["MultiEdit"] = NewMultiEditToolWithRegistry(r.config, r) - r.tools["Delete"] = NewDeleteTool(r.config) - r.tools["Grep"] = NewGrepTool(r.config) - r.tools["Tree"] = NewTreeTool(r.config) - r.tools["TodoWrite"] = NewTodoWriteTool(r.config) - r.tools["RequestPlanApproval"] = NewRequestPlanApprovalTool(r.config) - - if r.config.Tools.WebFetch.Enabled { - r.tools["WebFetch"] = NewWebFetchTool(r.config) + r.tools["Read"] = NewReadTool(cfg) + r.tools["Write"] = NewWriteTool(cfg) + r.tools["Edit"] = NewEditToolWithRegistry(cfg, r) + r.tools["MultiEdit"] = NewMultiEditToolWithRegistry(cfg, r) + r.tools["Delete"] = NewDeleteTool(cfg) + r.tools["Grep"] = NewGrepTool(cfg) + r.tools["Tree"] = NewTreeTool(cfg) + r.tools["TodoWrite"] = NewTodoWriteTool(cfg) + r.tools["RequestPlanApproval"] = NewRequestPlanApprovalTool(cfg) + + if cfg.Tools.WebFetch.Enabled { + r.tools["WebFetch"] = NewWebFetchTool(cfg) } - if r.config.Tools.WebSearch.Enabled { - r.tools["WebSearch"] = 
NewWebSearchTool(r.config) + if cfg.Tools.WebSearch.Enabled { + r.tools["WebSearch"] = NewWebSearchTool(cfg) } - if r.config.Tools.Github.Enabled { - r.tools["Github"] = NewGithubTool(r.config, r.imageService) + if cfg.Tools.Github.Enabled { + r.tools["Github"] = NewGithubTool(cfg, r.imageService) } - if r.config.IsA2AToolsEnabled() { - r.tools["A2A_QueryAgent"] = NewA2AQueryAgentTool(r.config) - r.tools["A2A_QueryTask"] = NewA2AQueryTaskTool(r.config, r.taskTracker) - r.tools["A2A_SubmitTask"] = NewA2ASubmitTaskTool(r.config, r.taskTracker) + if cfg.IsA2AToolsEnabled() { + r.tools["A2A_QueryAgent"] = NewA2AQueryAgentTool(cfg) + r.tools["A2A_QueryTask"] = NewA2AQueryTaskTool(cfg, r.taskTracker) + r.tools["A2A_SubmitTask"] = NewA2ASubmitTaskTool(cfg, r.taskTracker) } - if r.config.ComputerUse.Enabled { + if cfg.ComputerUse.Enabled { displayProvider, err := display.DetectDisplay() if err != nil { logger.Warn("No compatible display platform detected, computer use tools will be disabled", "error", err) } else { - rateLimiter := utils.NewRateLimiter(r.config.ComputerUse.RateLimit) - r.tools["MouseMove"] = NewMouseMoveTool(r.config, rateLimiter, displayProvider) - r.tools["MouseClick"] = NewMouseClickTool(r.config, rateLimiter, displayProvider) - r.tools["KeyboardType"] = NewKeyboardTypeTool(r.config, rateLimiter, displayProvider) + rateLimiter := utils.NewRateLimiter(cfg.ComputerUse.RateLimit) + r.tools["MouseMove"] = NewMouseMoveTool(cfg, rateLimiter, displayProvider) + r.tools["MouseClick"] = NewMouseClickTool(cfg, rateLimiter, displayProvider) + r.tools["MouseScroll"] = NewMouseScrollTool(cfg, rateLimiter, displayProvider) + r.tools["KeyboardType"] = NewKeyboardTypeTool(cfg, rateLimiter, displayProvider) + r.tools["GetFocusedApp"] = NewGetFocusedAppTool(r.config) + r.tools["ActivateApp"] = NewActivateAppTool(r.config) } } - if r.config.MCP.Enabled && r.mcpManager != nil { + if cfg.MCP.Enabled && r.mcpManager != nil { r.registerMCPTools() } } // registerMCPTools 
discovers and registers tools from enabled MCP servers func (r *Registry) registerMCPTools() { - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(r.config.MCP.DiscoveryTimeout)*time.Second) + cfg := r.config.GetConfig() + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.MCP.DiscoveryTimeout)*time.Second) defer cancel() toolCount := 0 @@ -197,6 +202,8 @@ func (r *Registry) RegisterMCPServerTools(serverName string, tools []domain.MCPD } toolCount := 0 + cfg := r.config.GetConfig() + for _, tool := range tools { fullToolName := fmt.Sprintf("MCP_%s_%s", serverName, tool.Name) @@ -206,7 +213,7 @@ func (r *Registry) RegisterMCPServerTools(serverName string, tools []domain.MCPD tool.Description, tool.InputSchema, targetClient, - &r.config.MCP, + &cfg.MCP, ) r.tools[fullToolName] = mcpTool @@ -246,7 +253,8 @@ func (r *Registry) UnregisterMCPServerTools(serverName string) int { // SetScreenshotServer dynamically registers the GetLatestScreenshot tool // This should be called after the screenshot server is started func (r *Registry) SetScreenshotServer(provider domain.ScreenshotProvider) { - if !r.config.ComputerUse.Enabled || !r.config.ComputerUse.Screenshot.StreamingEnabled { + cfg := r.config.GetConfig() + if !cfg.ComputerUse.Enabled || !cfg.ComputerUse.Screenshot.StreamingEnabled { logger.Debug("Screenshot streaming not enabled, skipping GetLatestScreenshot tool registration") return } @@ -256,7 +264,7 @@ func (r *Registry) SetScreenshotServer(provider domain.ScreenshotProvider) { return } - getLatestTool := NewGetLatestScreenshotTool(r.config, provider) + getLatestTool := NewGetLatestScreenshotTool(cfg, provider) r.tools["GetLatestScreenshot"] = getLatestTool logger.Info("Dynamically registered GetLatestScreenshot tool for streaming mode") diff --git a/internal/services/tools/registry_test.go b/internal/services/tools/registry_test.go index 2d0b30e6..195838fd 100644 --- a/internal/services/tools/registry_test.go +++ 
b/internal/services/tools/registry_test.go @@ -7,9 +7,66 @@ import ( config "github.com/inference-gateway/cli/config" domain "github.com/inference-gateway/cli/internal/domain" mocks "github.com/inference-gateway/cli/tests/mocks/domain" - "github.com/inference-gateway/sdk" + sdk "github.com/inference-gateway/sdk" ) +// testConfigService is a minimal mock implementation of domain.ConfigService for testing +type testConfigService struct { + config *config.Config +} + +func newTestConfigService(cfg *config.Config) domain.ConfigService { + return &testConfigService{config: cfg} +} + +func (t *testConfigService) GetConfig() *config.Config { + return t.config +} + +func (t *testConfigService) Reload() (*config.Config, error) { + return t.config, nil +} + +func (t *testConfigService) SetValue(key, value string) error { + return nil +} + +func (t *testConfigService) IsApprovalRequired(toolName string) bool { + return t.config.IsApprovalRequired(toolName) +} + +func (t *testConfigService) IsBashCommandWhitelisted(command string) bool { + return t.config.IsBashCommandWhitelisted(command) +} + +func (t *testConfigService) GetOutputDirectory() string { + return t.config.GetOutputDirectory() +} + +func (t *testConfigService) GetGatewayURL() string { + return t.config.Gateway.URL +} + +func (t *testConfigService) GetAPIKey() string { + return t.config.Gateway.APIKey +} + +func (t *testConfigService) GetTimeout() int { + return t.config.Gateway.Timeout +} + +func (t *testConfigService) GetAgentConfig() *config.AgentConfig { + return t.config.GetAgentConfig() +} + +func (t *testConfigService) GetSandboxDirectories() []string { + return t.config.GetSandboxDirectories() +} + +func (t *testConfigService) GetProtectedPaths() []string { + return t.config.GetProtectedPaths() +} + func createTestRegistry() *Registry { cfg := &config.Config{ Tools: config.ToolsConfig{ @@ -33,7 +90,7 @@ func createTestRegistry() *Registry { }, } - return NewRegistry(cfg, nil, nil, nil) + return 
NewRegistry(newTestConfigService(cfg), nil, nil, nil) } func TestRegistry_GetTool_Unknown(t *testing.T) { @@ -65,7 +122,7 @@ func TestRegistry_DisabledTools(t *testing.T) { }, } - registry := NewRegistry(cfg, nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) tools := registry.ListAvailableTools() @@ -117,13 +174,14 @@ func TestRegistry_NewRegistry(t *testing.T) { }, } - registry := NewRegistry(cfg, nil, nil, nil) + configService := newTestConfigService(cfg) + registry := NewRegistry(configService, nil, nil, nil) if registry == nil { t.Fatal("Expected non-nil registry") } - if registry.config != cfg { + if registry.config.GetConfig() != cfg { t.Error("Expected config to be set correctly") } @@ -152,7 +210,7 @@ func TestRegistry_GetTool(t *testing.T) { }, } - registry := NewRegistry(cfg, nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) tests := []struct { name string @@ -303,7 +361,7 @@ func TestRegistry_ListAvailableTools(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - registry := NewRegistry(tt.config, nil, nil, nil) + registry := NewRegistry(newTestConfigService(tt.config), nil, nil, nil) tools := registry.ListAvailableTools() if len(tools) < tt.expectedMin || len(tools) > tt.expectedMax { @@ -356,7 +414,7 @@ func TestRegistry_GetToolDefinitions(t *testing.T) { }, } - registry := NewRegistry(cfg, nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) definitions := registry.GetToolDefinitions() if len(definitions) < 5 || len(definitions) > 15 { @@ -406,7 +464,7 @@ func TestRegistry_IsToolEnabled(t *testing.T) { }, } - registry := NewRegistry(cfg, nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) tests := []struct { name string @@ -459,7 +517,7 @@ func TestRegistry_WithMockedTool(t *testing.T) { }, } - registry := NewRegistry(cfg, nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, 
nil) fakeTool := &mocks.FakeTool{} fakeTool.IsEnabledReturns(true) diff --git a/internal/web/pty_manager.go b/internal/web/pty_manager.go index 80a3b956..920774c4 100644 --- a/internal/web/pty_manager.go +++ b/internal/web/pty_manager.go @@ -165,7 +165,10 @@ func (s *LocalPTYSession) Start(cols, rows int) error { } s.cmd = exec.Command(execPath, "chat") - s.cmd.Env = append(os.Environ(), "TERM=xterm-256color") + s.cmd.Env = append(os.Environ(), + "TERM=xterm-256color", + "INFER_WEB_MODE=true", + ) ptyFile, err := pty.Start(s.cmd) if err != nil { diff --git a/tests/mocks/domain/fake_config_service.go b/tests/mocks/domain/fake_config_service.go index 11ae1f3d..368bac6f 100644 --- a/tests/mocks/domain/fake_config_service.go +++ b/tests/mocks/domain/fake_config_service.go @@ -29,6 +29,16 @@ type FakeConfigService struct { getAgentConfigReturnsOnCall map[int]struct { result1 *config.AgentConfig } + GetConfigStub func() *config.Config + getConfigMutex sync.RWMutex + getConfigArgsForCall []struct { + } + getConfigReturns struct { + result1 *config.Config + } + getConfigReturnsOnCall map[int]struct { + result1 *config.Config + } GetGatewayURLStub func() string getGatewayURLMutex sync.RWMutex getGatewayURLArgsForCall []struct { @@ -211,6 +221,59 @@ func (fake *FakeConfigService) GetAgentConfigReturnsOnCall(i int, result1 *confi }{result1} } +func (fake *FakeConfigService) GetConfig() *config.Config { + fake.getConfigMutex.Lock() + ret, specificReturn := fake.getConfigReturnsOnCall[len(fake.getConfigArgsForCall)] + fake.getConfigArgsForCall = append(fake.getConfigArgsForCall, struct { + }{}) + stub := fake.GetConfigStub + fakeReturns := fake.getConfigReturns + fake.recordInvocation("GetConfig", []interface{}{}) + fake.getConfigMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeConfigService) GetConfigCallCount() int { + fake.getConfigMutex.RLock() + defer 
fake.getConfigMutex.RUnlock() + return len(fake.getConfigArgsForCall) +} + +func (fake *FakeConfigService) GetConfigCalls(stub func() *config.Config) { + fake.getConfigMutex.Lock() + defer fake.getConfigMutex.Unlock() + fake.GetConfigStub = stub +} + +func (fake *FakeConfigService) GetConfigReturns(result1 *config.Config) { + fake.getConfigMutex.Lock() + defer fake.getConfigMutex.Unlock() + fake.GetConfigStub = nil + fake.getConfigReturns = struct { + result1 *config.Config + }{result1} +} + +func (fake *FakeConfigService) GetConfigReturnsOnCall(i int, result1 *config.Config) { + fake.getConfigMutex.Lock() + defer fake.getConfigMutex.Unlock() + fake.GetConfigStub = nil + if fake.getConfigReturnsOnCall == nil { + fake.getConfigReturnsOnCall = make(map[int]struct { + result1 *config.Config + }) + } + fake.getConfigReturnsOnCall[i] = struct { + result1 *config.Config + }{result1} +} + func (fake *FakeConfigService) GetGatewayURL() string { fake.getGatewayURLMutex.Lock() ret, specificReturn := fake.getGatewayURLReturnsOnCall[len(fake.getGatewayURLArgsForCall)] From 69ead75e80aece0e09b1d70a8c66fffd86e44063 Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sun, 4 Jan 2026 02:44:02 +0200 Subject: [PATCH 07/14] chore: Enable computer use --- .infer/config.yaml | 4 ++-- cspell.yaml | 9 +++++++++ internal/display/wayland/client.go | 25 +++++++++++++++++++++---- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/.infer/config.yaml b/.infer/config.yaml index ac406a65..103bb61b 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -700,7 +700,7 @@ web: install_version: latest servers: [] computer_use: - enabled: false + enabled: true restore_focus_on_approval: true screenshot: enabled: true @@ -709,7 +709,7 @@ computer_use: format: jpeg quality: 80 require_approval: false - streaming_enabled: false + streaming_enabled: true capture_interval: 3 buffer_size: 5 temp_dir: "" diff --git a/cspell.yaml b/cspell.yaml index 74b1850f..ad638051 100644 --- a/cspell.yaml +++ 
b/cspell.yaml @@ -7,6 +7,7 @@ words: - apapsch - asciicircum - asciitilde + - autoreleasepool - aymanbagabas - aymerick - bahlo @@ -20,6 +21,7 @@ words: - cancelreader - cellbuf - cespare + - CFLAGS - charmbracelet - cloudevents - codegen @@ -38,12 +40,14 @@ words: - easyjson - erikgeiser - exclam + - frontmost - fsnotify - funlen - gjson - gocognit - gocyclo - goldmark + - googlecode - gopkg - gotenv - groq @@ -53,6 +57,7 @@ words: - invopop - isatty - ISPEED + - iterm - journalctl - jsonmerge - jsonparser @@ -60,6 +65,7 @@ words: - keybind - keygen - kimi + - kovidgoyal - ledongthuc - libsqlite - lipgloss @@ -74,6 +80,7 @@ words: - metoro - mixtral - moonshotai + - mousemove - mtoken - multierr - myshortcuts @@ -98,6 +105,7 @@ words: - resty - retryable - rivo + - robotgo - runewidth - sabhiram - sagents @@ -108,6 +116,7 @@ words: - sname - sourcegraph - stretchr + - sublimetext - subosito - sysinfo - systemctl diff --git a/internal/display/wayland/client.go b/internal/display/wayland/client.go index 1dd5654d..2e94eb8f 100644 --- a/internal/display/wayland/client.go +++ b/internal/display/wayland/client.go @@ -7,8 +7,6 @@ import ( "strconv" "strings" "time" - - robotgo "github.com/go-vgo/robotgo" ) // WaylandClient provides Wayland screen control operations using command-line tools @@ -128,11 +126,30 @@ func (c *WaylandClient) ClickMouse(button string, clicks int) error { // ScrollMouse scrolls the mouse wheel func (c *WaylandClient) ScrollMouse(clicks int, direction string) error { + if _, err := exec.LookPath("ydotool"); err != nil { + return fmt.Errorf("ydotool not found (install with: sudo apt install ydotool)") + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // ydotool uses scroll codes: 0x150007 for vertical scroll + // Positive clicks = scroll down/right, negative = scroll up/left + var scrollCode string if direction == "horizontal" { - robotgo.ScrollDir(clicks, "right") + // Horizontal scroll not 
commonly supported by ydotool + return fmt.Errorf("horizontal scrolling not supported on Wayland") } else { - robotgo.Scroll(0, clicks) + scrollCode = "0x150007" + } + + // Execute scroll command + cmd := exec.CommandContext(ctx, "ydotool", "click", scrollCode, "--", strconv.Itoa(clicks)) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ydotool scroll failed: %s", string(output)) } + return nil } From c863fef2c3543a3ef2b8e94051edbbdc145281fe Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sun, 4 Jan 2026 21:43:01 +0200 Subject: [PATCH 08/14] feat: Add floating window for computer use --- .infer/config.yaml | 7 +- cmd/chat.go | 14 + cmd/floating_window_darwin.go | 39 + cmd/floating_window_stub.go | 25 + cmd/root.go | 5 +- config/config.go | 49 +- cspell.yaml | 2 + internal/display/macos/event_bridge.go | 139 +++ internal/display/macos/event_bridge_stub.go | 34 + internal/display/macos/manager.go | 1099 +++++++++++++++++ internal/display/macos/manager_stub.go | 23 + internal/display/macos/overlay_darwin.go | 155 +++ internal/display/macos/overlay_darwin_test.go | 149 +++ internal/display/macos/overlay_stub.go | 33 + internal/display/macos/types.go | 39 + internal/display/wayland/client.go | 4 - internal/domain/chat_events.go | 7 +- internal/domain/events.go | 13 +- internal/domain/interfaces.go | 14 + internal/domain/state.go | 2 +- internal/handlers/chat_event_handler.go | 14 + internal/services/agent.go | 160 +-- internal/services/screenshot_server.go | 62 +- internal/services/state_manager.go | 32 + internal/services/tools/get_focused_app.go | 3 - tests/mocks/domain/fake_state_manager.go | 74 ++ 26 files changed, 2013 insertions(+), 184 deletions(-) create mode 100644 cmd/floating_window_darwin.go create mode 100644 cmd/floating_window_stub.go create mode 100644 internal/display/macos/event_bridge.go create mode 100644 internal/display/macos/event_bridge_stub.go create mode 100644 internal/display/macos/manager.go create mode 100644 
internal/display/macos/manager_stub.go create mode 100644 internal/display/macos/overlay_darwin.go create mode 100644 internal/display/macos/overlay_darwin_test.go create mode 100644 internal/display/macos/overlay_stub.go create mode 100644 internal/display/macos/types.go diff --git a/.infer/config.yaml b/.infer/config.yaml index 103bb61b..1ba02935 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -701,7 +701,11 @@ web: servers: [] computer_use: enabled: true - restore_focus_on_approval: true + floating_window: + enabled: true + respawn_on_close: true + position: top-right + always_on_top: true screenshot: enabled: true max_width: 1920 @@ -714,6 +718,7 @@ computer_use: buffer_size: 5 temp_dir: "" log_captures: false + show_overlay: true mouse_move: enabled: true require_approval: false diff --git a/cmd/chat.go b/cmd/chat.go index f21fb6c8..ff191363 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -91,6 +91,8 @@ and have a conversational interface with the inference gateway.`, } // StartChatSession starts a chat session +// +//nolint:funlen // Chat session initialization requires multiple setup steps func StartChatSession(cfg *config.Config, v *viper.Viper) error { _ = clipboard.Init() @@ -182,6 +184,18 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { } } + floatingWindowMgr, err := initFloatingWindow(config, stateManager) + if err != nil { + return fmt.Errorf("failed to initialize floating window: %w", err) + } + if floatingWindowMgr != nil { + defer func() { + if err := floatingWindowMgr.Shutdown(); err != nil { + logger.Error("Failed to shutdown floating window", "error", err) + } + }() + } + versionInfo := GetVersionInfo() application := app.NewChatApplication( models, diff --git a/cmd/floating_window_darwin.go b/cmd/floating_window_darwin.go new file mode 100644 index 00000000..20e5ba5d --- /dev/null +++ b/cmd/floating_window_darwin.go @@ -0,0 +1,39 @@ +//go:build darwin + +package cmd + +import ( + "fmt" + + config 
"github.com/inference-gateway/cli/config" + macos "github.com/inference-gateway/cli/internal/display/macos" + domain "github.com/inference-gateway/cli/internal/domain" + logger "github.com/inference-gateway/cli/internal/logger" +) + +// FloatingWindowManager is the platform-specific interface for the floating window +type FloatingWindowManager interface { + Shutdown() error +} + +// initFloatingWindow initializes the floating window manager if enabled +func initFloatingWindow(config *config.Config, stateManager domain.StateManager) (FloatingWindowManager, error) { + logger.Info("Checking floating window conditions", + "computer_use_enabled", config.ComputerUse.Enabled, + "floating_window_enabled", config.ComputerUse.FloatingWindow.Enabled) + + if !config.ComputerUse.Enabled || !config.ComputerUse.FloatingWindow.Enabled { + return nil, nil + } + + logger.Info("Initializing floating window manager") + eventBridge := macos.NewEventBridge() + stateManager.SetEventBridge(eventBridge) + + floatingWindowMgr, err := macos.NewFloatingWindowManager(config, eventBridge, stateManager) + if err != nil { + return nil, fmt.Errorf("failed to create floating window manager: %w", err) + } + + return floatingWindowMgr, nil +} diff --git a/cmd/floating_window_stub.go b/cmd/floating_window_stub.go new file mode 100644 index 00000000..4ca6fce8 --- /dev/null +++ b/cmd/floating_window_stub.go @@ -0,0 +1,25 @@ +//go:build !darwin + +package cmd + +import ( + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" +) + +// FloatingWindowManager is the platform-specific interface for the floating window +type FloatingWindowManager interface { + Shutdown() error +} + +// noopFloatingWindowManager is a no-op implementation for non-darwin platforms +type noopFloatingWindowManager struct{} + +func (n *noopFloatingWindowManager) Shutdown() error { + return nil +} + +// initFloatingWindow returns a no-op manager on non-darwin platforms +func 
initFloatingWindow(config *config.Config, stateManager domain.StateManager) (FloatingWindowManager, error) { + return &noopFloatingWindowManager{}, nil +} diff --git a/cmd/root.go b/cmd/root.go index 16a6e9cb..09093f69 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -93,7 +93,10 @@ func initConfig() { // nolint:funlen v.SetDefault("web.servers", defaults.Web.Servers) v.SetDefault("computer_use", defaults.ComputerUse) v.SetDefault("computer_use.enabled", defaults.ComputerUse.Enabled) - v.SetDefault("computer_use.restore_focus_on_approval", defaults.ComputerUse.RestoreFocusOnApproval) + v.SetDefault("computer_use.floating_window.enabled", defaults.ComputerUse.FloatingWindow.Enabled) + v.SetDefault("computer_use.floating_window.respawn_on_close", defaults.ComputerUse.FloatingWindow.RespawnOnClose) + v.SetDefault("computer_use.floating_window.position", defaults.ComputerUse.FloatingWindow.Position) + v.SetDefault("computer_use.floating_window.always_on_top", defaults.ComputerUse.FloatingWindow.AlwaysOnTop) v.SetDefault("computer_use.screenshot.enabled", defaults.ComputerUse.Screenshot.Enabled) v.SetDefault("computer_use.screenshot.max_width", defaults.ComputerUse.Screenshot.MaxWidth) v.SetDefault("computer_use.screenshot.max_height", defaults.ComputerUse.Screenshot.MaxHeight) diff --git a/config/config.go b/config/config.go index c0eae334..52bb17c8 100644 --- a/config/config.go +++ b/config/config.go @@ -250,16 +250,16 @@ type SandboxConfig struct { // ComputerUseConfig contains computer use tool settings type ComputerUseConfig struct { - Enabled bool `yaml:"enabled" mapstructure:"enabled"` - RestoreFocusOnApproval bool `yaml:"restore_focus_on_approval" mapstructure:"restore_focus_on_approval"` // Switch to terminal for approval, then restore focus - Screenshot ScreenshotToolConfig `yaml:"screenshot" mapstructure:"screenshot"` - MouseMove MouseMoveToolConfig `yaml:"mouse_move" mapstructure:"mouse_move"` - MouseClick MouseClickToolConfig `yaml:"mouse_click" 
mapstructure:"mouse_click"` - MouseScroll MouseScrollToolConfig `yaml:"mouse_scroll" mapstructure:"mouse_scroll"` - KeyboardType KeyboardTypeToolConfig `yaml:"keyboard_type" mapstructure:"keyboard_type"` - GetFocusedApp GetFocusedAppToolConfig `yaml:"get_focused_app" mapstructure:"get_focused_app"` - ActivateApp ActivateAppToolConfig `yaml:"activate_app" mapstructure:"activate_app"` - RateLimit RateLimitConfig `yaml:"rate_limit" mapstructure:"rate_limit"` + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + FloatingWindow FloatingWindowConfig `yaml:"floating_window" mapstructure:"floating_window"` + Screenshot ScreenshotToolConfig `yaml:"screenshot" mapstructure:"screenshot"` + MouseMove MouseMoveToolConfig `yaml:"mouse_move" mapstructure:"mouse_move"` + MouseClick MouseClickToolConfig `yaml:"mouse_click" mapstructure:"mouse_click"` + MouseScroll MouseScrollToolConfig `yaml:"mouse_scroll" mapstructure:"mouse_scroll"` + KeyboardType KeyboardTypeToolConfig `yaml:"keyboard_type" mapstructure:"keyboard_type"` + GetFocusedApp GetFocusedAppToolConfig `yaml:"get_focused_app" mapstructure:"get_focused_app"` + ActivateApp ActivateAppToolConfig `yaml:"activate_app" mapstructure:"activate_app"` + RateLimit RateLimitConfig `yaml:"rate_limit" mapstructure:"rate_limit"` } // ScreenshotToolConfig contains screenshot-specific tool settings @@ -271,10 +271,19 @@ type ScreenshotToolConfig struct { Quality int `yaml:"quality" mapstructure:"quality"` RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` StreamingEnabled bool `yaml:"streaming_enabled" mapstructure:"streaming_enabled"` - CaptureInterval int `yaml:"capture_interval" mapstructure:"capture_interval"` // seconds - BufferSize int `yaml:"buffer_size" mapstructure:"buffer_size"` // number of screenshots - TempDir string `yaml:"temp_dir" mapstructure:"temp_dir"` // path for disk storage - LogCaptures bool `yaml:"log_captures" mapstructure:"log_captures"` // log every capture 
(debug) + CaptureInterval int `yaml:"capture_interval" mapstructure:"capture_interval"` + BufferSize int `yaml:"buffer_size" mapstructure:"buffer_size"` + TempDir string `yaml:"temp_dir" mapstructure:"temp_dir"` + LogCaptures bool `yaml:"log_captures" mapstructure:"log_captures"` + ShowOverlay bool `yaml:"show_overlay" mapstructure:"show_overlay"` +} + +// FloatingWindowConfig contains floating progress window settings +type FloatingWindowConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + RespawnOnClose bool `yaml:"respawn_on_close" mapstructure:"respawn_on_close"` + Position string `yaml:"position" mapstructure:"position"` + AlwaysOnTop bool `yaml:"always_on_top" mapstructure:"always_on_top"` } // MouseMoveToolConfig contains mouse move-specific tool settings @@ -1114,8 +1123,13 @@ Write the AGENTS.md file to the project root when you have gathered enough infor Servers: []SSHServerConfig{}, }, ComputerUse: ComputerUseConfig{ - Enabled: false, - RestoreFocusOnApproval: true, // Switch to terminal for approval, then restore focus to original app + Enabled: false, + FloatingWindow: FloatingWindowConfig{ + Enabled: true, + RespawnOnClose: true, + Position: "top-right", + AlwaysOnTop: true, + }, Screenshot: ScreenshotToolConfig{ Enabled: true, MaxWidth: 1920, @@ -1123,11 +1137,12 @@ Write the AGENTS.md file to the project root when you have gathered enough infor Format: "jpeg", Quality: 80, RequireApproval: &[]bool{false}[0], - StreamingEnabled: false, + StreamingEnabled: true, CaptureInterval: 3, BufferSize: 5, TempDir: "", LogCaptures: false, + ShowOverlay: true, }, MouseMove: MouseMoveToolConfig{ Enabled: true, diff --git a/cspell.yaml b/cspell.yaml index ad638051..c73794aa 100644 --- a/cspell.yaml +++ b/cspell.yaml @@ -40,6 +40,7 @@ words: - easyjson - erikgeiser - exclam + - floatingwindow - frontmost - fsnotify - funlen @@ -102,6 +103,7 @@ words: - pmezard - quotedbl - qwen + - Respawn - resty - retryable - rivo diff --git 
a/internal/display/macos/event_bridge.go b/internal/display/macos/event_bridge.go new file mode 100644 index 00000000..9309767c --- /dev/null +++ b/internal/display/macos/event_bridge.go @@ -0,0 +1,139 @@ +//go:build darwin + +package macos + +import ( + "container/ring" + "sync" + + domain "github.com/inference-gateway/cli/internal/domain" +) + +// EventBridge multicasts chat events to multiple subscribers +// without modifying the existing event flow to the terminal UI. +type EventBridge struct { + subscribers []*subscriber + subMutex sync.RWMutex + eventBuffer *ring.Ring + bufferSize int +} + +type subscriber struct { + ch chan domain.ChatEvent + closed bool + mu sync.Mutex +} + +// NewEventBridge creates a new event bridge with a circular buffer +func NewEventBridge() *EventBridge { + bufferSize := 50 + return &EventBridge{ + subscribers: make([]*subscriber, 0), + eventBuffer: ring.New(bufferSize), + bufferSize: bufferSize, + } +} + +// Publish broadcasts an event to all subscribers +// Non-blocking: if a subscriber's channel is full, the event is dropped for that subscriber +func (eb *EventBridge) Publish(event domain.ChatEvent) { + eb.subMutex.Lock() + eb.eventBuffer.Value = event + eb.eventBuffer = eb.eventBuffer.Next() + subscribers := make([]*subscriber, len(eb.subscribers)) + copy(subscribers, eb.subscribers) + eb.subMutex.Unlock() + + for _, sub := range subscribers { + sub.mu.Lock() + if !sub.closed { + select { + case sub.ch <- event: + default: + } + } + sub.mu.Unlock() + } +} + +// Subscribe creates a new event channel and returns it +// The subscriber will receive all future events published to the bridge +func (eb *EventBridge) Subscribe() chan domain.ChatEvent { + ch := make(chan domain.ChatEvent, 100) + sub := &subscriber{ + ch: ch, + closed: false, + } + + eb.subMutex.Lock() + defer eb.subMutex.Unlock() + + eb.subscribers = append(eb.subscribers, sub) + + eb.eventBuffer.Do(func(val interface{}) { + if val != nil { + event, ok := 
val.(domain.ChatEvent) + if ok { + select { + case ch <- event: + default: + } + } + } + }) + + return ch +} + +// Unsubscribe removes a subscriber and closes its channel +func (eb *EventBridge) Unsubscribe(ch chan domain.ChatEvent) { + eb.subMutex.Lock() + defer eb.subMutex.Unlock() + + for i, sub := range eb.subscribers { + if sub.ch == ch { + sub.mu.Lock() + if !sub.closed { + close(sub.ch) + sub.closed = true + } + sub.mu.Unlock() + + eb.subscribers = append(eb.subscribers[:i], eb.subscribers[i+1:]...) + break + } + } +} + +// Tap intercepts an event stream and multicasts it to all subscribers +// Returns a new channel that mirrors the input channel for the terminal UI +func (eb *EventBridge) Tap(input <-chan domain.ChatEvent) <-chan domain.ChatEvent { + output := make(chan domain.ChatEvent, 100) + + go func() { + defer close(output) + for event := range input { + output <- event + eb.Publish(event) + } + }() + + return output +} + +// Close closes all subscriber channels and clears the subscribers list +func (eb *EventBridge) Close() { + eb.subMutex.Lock() + defer eb.subMutex.Unlock() + + for _, sub := range eb.subscribers { + sub.mu.Lock() + if !sub.closed { + close(sub.ch) + sub.closed = true + } + sub.mu.Unlock() + } + + eb.subscribers = nil +} diff --git a/internal/display/macos/event_bridge_stub.go b/internal/display/macos/event_bridge_stub.go new file mode 100644 index 00000000..ca994f2a --- /dev/null +++ b/internal/display/macos/event_bridge_stub.go @@ -0,0 +1,34 @@ +//go:build !darwin + +package macos + +import ( + domain "github.com/inference-gateway/cli/internal/domain" +) + +// EventBridge stub for non-darwin platforms +type EventBridge struct{} + +// NewEventBridge creates a stub event bridge +func NewEventBridge() *EventBridge { + return &EventBridge{} +} + +// Tap returns the input channel unchanged on non-darwin platforms +func (eb *EventBridge) Tap(input <-chan domain.ChatEvent) <-chan domain.ChatEvent { + return input +} + +// Publish is a 
no-op on non-darwin platforms
+func (eb *EventBridge) Publish(event domain.ChatEvent) {}
+
+// Subscribe returns a dummy channel that never receives events
+func (eb *EventBridge) Subscribe() chan domain.ChatEvent {
+	return make(chan domain.ChatEvent)
+}
+
+// Unsubscribe is a no-op on non-darwin platforms
+func (eb *EventBridge) Unsubscribe(ch chan domain.ChatEvent) {}
+
+// Close is a no-op on non-darwin platforms
+func (eb *EventBridge) Close() {}
diff --git a/internal/display/macos/manager.go b/internal/display/macos/manager.go
new file mode 100644
index 00000000..082cf95e
--- /dev/null
+++ b/internal/display/macos/manager.go
@@ -0,0 +1,1161 @@
+//go:build darwin
+
+package macos
+
+import (
+	"bufio"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"os"
+	"os/exec"
+	"path/filepath"
+	"runtime"
+	"sync"
+	"syscall"
+	"time"
+
+	config "github.com/inference-gateway/cli/config"
+	domain "github.com/inference-gateway/cli/internal/domain"
+	logger "github.com/inference-gateway/cli/internal/logger"
+)
+
+// FloatingWindowManager manages the lifecycle of the floating progress window
+type FloatingWindowManager struct {
+	cfg          *config.Config
+	eventBridge  *EventBridge
+	stateManager domain.StateManager
+	cmd          *exec.Cmd
+	enabled      bool
+	eventSub     chan domain.ChatEvent
+	stopForward  chan struct{}
+	swiftTmpFile string
+	monitorWg    sync.WaitGroup
+	// lifecycleMu guards enabled and serializes respawns against Shutdown
+	lifecycleMu sync.Mutex
+	// IPC fields (merged from ProcessManager)
+	stdin                io.Writer
+	stdout               io.Reader
+	stdinMutex           sync.Mutex
+	approvalChans        map[string]chan domain.ApprovalAction
+	approvalMutex        sync.RWMutex
+	stopListener         chan struct{}
+	listenerStopped      bool
+	listenerStoppedMutex sync.Mutex
+}
+
+// NewFloatingWindowManager creates and starts a new floating window manager
+func NewFloatingWindowManager(cfg *config.Config, eventBridge *EventBridge, stateManager domain.StateManager) (*FloatingWindowManager, error) {
+	if runtime.GOOS != "darwin" {
+		return &FloatingWindowManager{enabled: false}, nil
+	}
+
+	if !cfg.ComputerUse.Enabled || !cfg.ComputerUse.FloatingWindow.Enabled {
+		return &FloatingWindowManager{enabled: false}, nil
+	}
+
+	mgr := &FloatingWindowManager{
+		cfg:             cfg,
+		eventBridge:     eventBridge,
+		stateManager:    stateManager,
+		enabled:         true,
+		stopForward:     make(chan struct{}),
+		approvalChans:   make(map[string]chan domain.ApprovalAction),
+		stopListener:    make(chan struct{}),
+		listenerStopped: false,
+	}
+
+	if err := mgr.launchWindow(); err != nil {
+		return nil, fmt.Errorf("failed to launch floating window: %w", err)
+	}
+
+	mgr.eventSub = eventBridge.Subscribe()
+	go mgr.forwardEvents()
+
+	mgr.monitorWg.Add(1)
+	go mgr.monitorProcess()
+
+	return mgr, nil
+}
+
+// launchWindow writes the generated Swift script to a temp file under the
+// config directory, starts `swift` on it and wires up the stdin/stdout/stderr
+// pipes. On any failure the freshly written script is removed again.
+func (mgr *FloatingWindowManager) launchWindow() error {
+	swiftScript := mgr.generateSwiftScript()
+
+	tmpDir := filepath.Join(mgr.cfg.GetConfigDir(), "tmp")
+	if err := os.MkdirAll(tmpDir, 0755); err != nil {
+		return fmt.Errorf("failed to create temp directory: %w", err)
+	}
+
+	tmpFile, err := os.CreateTemp(tmpDir, "floating_window_*.swift")
+	if err != nil {
+		return fmt.Errorf("failed to create temp file: %w", err)
+	}
+	scriptPath := tmpFile.Name()
+
+	// Nothing else knows about the script until the process is running, so
+	// every early return below has to delete it here.
+	started := false
+	defer func() {
+		if !started {
+			_ = os.Remove(scriptPath)
+		}
+	}()
+
+	if _, err := tmpFile.Write([]byte(swiftScript)); err != nil {
+		_ = tmpFile.Close()
+		return fmt.Errorf("failed to write Swift script: %w", err)
+	}
+	if err := tmpFile.Close(); err != nil {
+		return fmt.Errorf("failed to close temp file: %w", err)
+	}
+
+	cmd := exec.Command("swift", scriptPath)
+
+	stdin, err := cmd.StdinPipe()
+	if err != nil {
+		return fmt.Errorf("failed to create stdin pipe: %w", err)
+	}
+
+	stdout, err := cmd.StdoutPipe()
+	if err != nil {
+		return fmt.Errorf("failed to create stdout pipe: %w", err)
+	}
+
+	stderr, err := cmd.StderrPipe()
+	if err != nil {
+		return fmt.Errorf("failed to create stderr pipe: %w", err)
+	}
+
+	if err := cmd.Start(); err != nil {
+		return fmt.Errorf("failed to start Swift process: %w", err)
+	}
+	started = true
+
+	// A respawn supersedes the previous script; that process has already
+	// exited, so its file can go.
+	if prev := mgr.swiftTmpFile; prev != "" {
+		_ = os.Remove(prev)
+	}
+	mgr.swiftTmpFile = scriptPath
+
+	go func() {
+		buf := make([]byte, 1024)
+		for {
+			n, err := stderr.Read(buf)
+			if n > 0 {
+				logger.Debug("Swift stderr", "output", string(buf[:n]))
+			}
+			if err != nil {
+				break
+			}
+		}
+	}()
+
+	// writeEvent reads mgr.stdin under stdinMutex and may run concurrently
+	// with a respawn, so the swap has to take the same lock.
+	mgr.stdinMutex.Lock()
+	mgr.stdin = stdin
+	mgr.stdinMutex.Unlock()
+	mgr.stdout = stdout
+	mgr.cmd = cmd
+
+	go mgr.startApprovalListener()
+
+	return nil
+}
+
+// forwardEvents forwards chat events from the EventBridge to the Swift window
+// until Shutdown closes stopForward or the subscription channel is closed.
+func (mgr *FloatingWindowManager) forwardEvents() {
+	for {
+		select {
+		case event, open := <-mgr.eventSub:
+			if !open {
+				// A closed channel is always ready; without this check the
+				// loop would keep forwarding nil events to the window.
+				logger.Debug("Event forwarding stopped")
+				return
+			}
+
+			if approvalEvent, ok := event.(domain.ToolApprovalRequestedEvent); ok {
+				mgr.registerApprovalChannel(approvalEvent.ToolCall.Id, approvalEvent.ResponseChan)
+			}
+
+			if err := mgr.writeEvent(event); err != nil {
+				logger.Warn("Failed to forward event to window", "error", err)
+			}
+
+		case <-mgr.stopForward:
+			logger.Debug("Event forwarding stopped")
+			return
+		}
+	}
+}
+
+// monitorProcess waits for the Swift process to exit and, when
+// RespawnOnClose is set and the manager has not been shut down, relaunches
+// it after a one second back-off.
+func (mgr *FloatingWindowManager) monitorProcess() {
+	defer mgr.monitorWg.Done()
+
+	if mgr.cmd == nil {
+		return
+	}
+
+	if err := mgr.cmd.Wait(); err != nil {
+		logger.Error("Swift process exited", "error", err)
+	}
+
+	if !mgr.cfg.ComputerUse.FloatingWindow.RespawnOnClose {
+		return
+	}
+
+	// Back off before respawning, but give up at once if Shutdown runs.
+	select {
+	case <-mgr.stopForward:
+		return
+	case <-time.After(1 * time.Second):
+	}
+
+	// Holding lifecycleMu across the relaunch means Shutdown either sees the
+	// new process (and terminates it) or has already cleared enabled, in
+	// which case nothing is started. Without this a respawn racing with
+	// Shutdown left an unsupervised process and blocked monitorWg.Wait.
+	mgr.lifecycleMu.Lock()
+	defer mgr.lifecycleMu.Unlock()
+
+	if !mgr.enabled {
+		return
+	}
+
+	if err := mgr.launchWindow(); err != nil {
+		logger.Error("Failed to respawn floating window", "error", err)
+		return
+	}
+
+	mgr.monitorWg.Add(1)
+	go mgr.monitorProcess()
+}
+
+// Shutdown stops event forwarding and the approval listener, terminates the
+// Swift process, waits for the monitor goroutine and removes the temp script.
+// It is a no-op for a disabled manager and safe to call more than once.
+func (mgr *FloatingWindowManager) Shutdown() error {
+	// Flip enabled under lifecycleMu so a concurrent respawn is either
+	// finished (and its process is the one terminated below) or never starts.
+	mgr.lifecycleMu.Lock()
+	wasEnabled := mgr.enabled
+	mgr.enabled = false
+	mgr.lifecycleMu.Unlock()
+
+	if !wasEnabled {
+		return nil
+	}
+
+	logger.Info("Shutting down floating window manager")
+
+	close(mgr.stopForward)
+
+	if mgr.eventSub != nil {
+		mgr.eventBridge.Unsubscribe(mgr.eventSub)
+	}
+
+	mgr.stopApprovalListener()
+
+	if err := mgr.shutdownProcess(); err != nil {
+		return err
+	}
+
+	mgr.monitorWg.Wait()
+
+	if mgr.swiftTmpFile != "" {
+		if err := os.Remove(mgr.swiftTmpFile); err != nil {
+			logger.Debug("Failed to remove temp Swift file", "error", err, "path", mgr.swiftTmpFile)
+		} else {
+			logger.Debug("Removed temp Swift file", "path", mgr.swiftTmpFile)
+		}
+		mgr.swiftTmpFile = ""
+	}
+
+	logger.Info("Floating window manager shutdown complete")
+
+	return nil
+}
+
+// shutdownProcess terminates the Swift process gracefully
+func (mgr *FloatingWindowManager) shutdownProcess() error {
+	if mgr.cmd == nil || mgr.cmd.Process == nil {
+		return nil
+	}
+
+	return mgr.sendTermSignal()
+}
+
+// sendTermSignal sends SIGTERM to the Swift process and falls back to SIGKILL
+// if that fails. A process that has already exited is not an error.
+func (mgr *FloatingWindowManager) sendTermSignal() error {
+	err := mgr.cmd.Process.Signal(syscall.SIGTERM)
+	if err == nil || errors.Is(err, os.ErrProcessDone) {
+		return nil
+	}
+
+	logger.Debug("Failed to send SIGTERM, using SIGKILL", "error", err)
+	if killErr := mgr.cmd.Process.Kill(); killErr != nil && !errors.Is(killErr, os.ErrProcessDone) {
+		logger.Warn("Failed to kill Swift process", "error", killErr)
+		return fmt.Errorf("failed to kill process: %w", killErr)
+	}
+	return nil
+}
+
+// IPC Methods (merged from ProcessManager)
+
+// writeEvent sends an event to the Swift process via stdin as one JSON line
+func (mgr *FloatingWindowManager) writeEvent(event domain.ChatEvent) error {
+	data, err := json.Marshal(event)
+	if err != nil {
+		return fmt.Errorf("failed to marshal event: %w", err)
+	}
+
+	mgr.stdinMutex.Lock()
+	defer mgr.stdinMutex.Unlock()
+
+	if _, err := fmt.Fprintf(mgr.stdin, "%s\n", data); err != nil {
+		return fmt.Errorf("failed to write to stdin: %w", err)
+	}
+
+	return nil
+}
+
+// startApprovalListener reads approval responses from the Swift process via stdout
+func (mgr *FloatingWindowManager) startApprovalListener() {
+	scanner := bufio.NewScanner(mgr.stdout)
+	for scanner.Scan() {
+		select {
+		case <-mgr.stopListener:
+			logger.Debug("Approval listener stopped")
+			return
+		default:
+		}
+
+		line := scanner.Text()
+
+		if line == "" {
+			continue
+		}
+
+		var response ApprovalResponse
+		if err := json.Unmarshal([]byte(line), &response); err != nil {
+			logger.Warn("Failed to parse approval response", "error", err, "line", line)
+			continue
+		}
+
+		mgr.handleApprovalResponse(response)
+	}
+
+	if err := scanner.Err(); err != nil {
+		mgr.listenerStoppedMutex.Lock()
+		if !mgr.listenerStopped {
+			logger.Warn("Approval listener error", "error", err)
+		}
+		mgr.listenerStoppedMutex.Unlock()
+	}
+}
+
+// handleApprovalResponse processes an approval response from the window
+func (mgr *FloatingWindowManager) handleApprovalResponse(resp ApprovalResponse) {
+	mgr.approvalMutex.Lock()
+	defer mgr.approvalMutex.Unlock()
+
+	logger.Debug("handleApprovalResponse called", "call_id", resp.CallID, "registered_channels", len(mgr.approvalChans))
+
+	ch, exists := mgr.approvalChans[resp.CallID]
+	if !exists {
+		logger.Warn("Received approval for unknown call ID", "call_id", resp.CallID, "known_call_ids", mgr.getCallIDs())
+		return
+	}
+
+	logger.Info("Sending approval to channel", "call_id", resp.CallID, "action", resp.Action)
+
+	select {
+	case ch <- resp.Action:
+		delete(mgr.approvalChans, resp.CallID)
+		logger.Debug("Approval processed", "call_id", resp.CallID, "action", resp.Action)
+
+		if mgr.stateManager != nil {
+			mgr.stateManager.ClearApprovalUIState()
+			logger.Debug("Cleared approval UI state from floating window")
+		}
+	default:
+		logger.Warn("Approval channel blocked", "call_id", resp.CallID)
+	}
+}
+
+// registerApprovalChannel registers a response channel for a specific tool call
+func (mgr *FloatingWindowManager) registerApprovalChannel(callID string, ch chan domain.ApprovalAction) {
+	mgr.approvalMutex.Lock()
+	defer mgr.approvalMutex.Unlock()
+
+	mgr.approvalChans[callID] = ch
+	logger.Debug("Registered approval channel", "call_id", callID)
+}
+
+// getCallIDs returns a list of registered call IDs (for debugging); the caller must hold approvalMutex
+func (mgr *FloatingWindowManager) getCallIDs() []string {
+	ids := make([]string, 0, len(mgr.approvalChans))
+	for id := range mgr.approvalChans {
+		ids = append(ids, id)
+	}
+	return ids
+}
+
+// stopApprovalListener signals the
approval listener to stop +func (mgr *FloatingWindowManager) stopApprovalListener() { + mgr.listenerStoppedMutex.Lock() + defer mgr.listenerStoppedMutex.Unlock() + + if !mgr.listenerStopped { + close(mgr.stopListener) + mgr.listenerStopped = true + logger.Debug("Approval listener stopped") + } +} + +// generateSwiftScript generates the Swift script for the floating window +// +//nolint:funlen // Swift script embedding requires long function +func (mgr *FloatingWindowManager) generateSwiftScript() string { + position := mgr.cfg.ComputerUse.FloatingWindow.Position + alwaysOnTop := mgr.cfg.ComputerUse.FloatingWindow.AlwaysOnTop + + return fmt.Sprintf(` +import Cocoa +import Foundation +import WebKit + +// MARK: - Models + +struct ApprovalResponse: Codable { + let call_id: String + let action: Int // 0=Approve, 1=Reject, 2=AutoAccept +} + +// MARK: - Window Setup + +class AgentProgressWindow: NSPanel { + let webView = WKWebView() + var isTerminalReady = false + var isMinimized = false + var fullFrame: NSRect? + let minimizedWidth: CGFloat = 40 + let minimizedHeight: CGFloat = 150 + + init() { + let screenFrame = NSScreen.main!.visibleFrame + let windowWidth: CGFloat = 450 + let windowHeight: CGFloat = 600 + + // Position based on configuration + var xPos: CGFloat + let position = "%s" + switch position { + case "top-left": + xPos = screenFrame.minX + 20 + case "top-right": + xPos = screenFrame.maxX - windowWidth - 20 + default: + xPos = screenFrame.maxX - windowWidth - 20 + } + + let yPos = screenFrame.maxY - windowHeight - 20 + let frame = NSRect(x: xPos, y: yPos, width: windowWidth, height: windowHeight) + + let styleMask: NSWindow.StyleMask = [.titled, .resizable, .miniaturizable, .fullSizeContentView] + super.init(contentRect: frame, styleMask: styleMask, backing: .buffered, defer: false) + + self.title = "Computer Use" + self.isFloatingPanel = true + self.level = %t ? 
.floating : .normal + self.collectionBehavior = [.canJoinAllSpaces, .fullScreenAuxiliary] + self.hidesOnDeactivate = false + + self.isOpaque = false + self.alphaValue = 0.90 + self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 0.85) + + self.contentView?.wantsLayer = true + if let layer = self.contentView?.layer { + layer.cornerRadius = 12 + layer.masksToBounds = false + } + + self.hasShadow = true + self.invalidateShadow() + + self.titlebarAppearsTransparent = true + self.titleVisibility = .visible + + self.isMovableByWindowBackground = true + + self.standardWindowButton(.closeButton)?.alphaValue = 0 + self.standardWindowButton(.zoomButton)?.alphaValue = 0 + + if let minimizeButton = self.standardWindowButton(.miniaturizeButton) { + minimizeButton.target = self + minimizeButton.action = #selector(customMinimize) + } + + setupUI() + + self.orderFront(nil) + } + + @objc func customMinimize() { + if isMinimized { + restoreWindow() + } else { + minimizeToSide() + } + } + + func minimizeToSide() { + guard let screen = NSScreen.main else { return } + isMinimized = true + fullFrame = self.frame + + let screenFrame = screen.visibleFrame + let xPos = screenFrame.maxX - minimizedWidth + let yPos = screenFrame.midY - (minimizedHeight / 2) + let minimizedFrame = NSRect(x: xPos, y: yPos, width: minimizedWidth, height: minimizedHeight) + + NSAnimationContext.runAnimationGroup({ context in + context.duration = 0.3 + context.timingFunction = CAMediaTimingFunction(name: .easeInEaseOut) + self.animator().setFrame(minimizedFrame, display: true) + self.animator().alphaValue = 1.0 + }, completionHandler: { + self.webView.isHidden = true + self.titleVisibility = .hidden + self.titlebarAppearsTransparent = true + self.standardWindowButton(.closeButton)?.alphaValue = 0 + self.standardWindowButton(.miniaturizeButton)?.alphaValue = 0 + self.standardWindowButton(.zoomButton)?.alphaValue = 0 + self.isOpaque = true + self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 1.0) + 
self.updateMinimizedUI() + }) + } + + func restoreWindow() { + guard let savedFrame = fullFrame else { return } + isMinimized = false + + self.titleVisibility = .visible + self.titlebarAppearsTransparent = true + self.standardWindowButton(.closeButton)?.alphaValue = 0 + self.standardWindowButton(.miniaturizeButton)?.alphaValue = 1.0 + self.standardWindowButton(.zoomButton)?.alphaValue = 0 + self.isOpaque = false + self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 0.85) + + if let contentView = self.contentView { + contentView.subviews.forEach { view in + if view.identifier?.rawValue == "minimizedLabel" { + view.removeFromSuperview() + } + } + } + + self.webView.isHidden = false + + NSAnimationContext.runAnimationGroup({ context in + context.duration = 0.3 + context.timingFunction = CAMediaTimingFunction(name: .easeInEaseOut) + self.animator().setFrame(savedFrame, display: true) + self.animator().alphaValue = 0.95 + }, completionHandler: nil) + } + + func updateMinimizedUI() { + guard let contentView = self.contentView else { return } + + contentView.subviews.forEach { view in + if view.identifier?.rawValue == "minimizedLabel" { + view.removeFromSuperview() + } + } + + // Create a simple dot indicator, vertically centered + let labelHeight: CGFloat = 30 + let labelY = (minimizedHeight - labelHeight) / 2 + let label = NSTextField(labelWithString: "●") + label.identifier = NSUserInterfaceItemIdentifier("minimizedLabel") + label.frame = NSRect(x: 0, y: labelY, width: minimizedWidth, height: labelHeight) + label.alignment = .center + label.font = NSFont.systemFont(ofSize: 20) + label.textColor = NSColor(calibratedRed: 0.48, green: 0.64, blue: 0.97, alpha: 1.0) // Blue accent color + label.backgroundColor = .clear + label.isBordered = false + label.isEditable = false + label.isSelectable = false + contentView.addSubview(label) + } + + override func mouseDown(with event: NSEvent) { + super.mouseDown(with: event) + if isMinimized { + customMinimize() + } + } + + 
func setupUI() { + guard let contentView = self.contentView else { return } + + // WebView leaves 30px at top for draggable title bar + let titleBarHeight: CGFloat = 30 + webView.frame = NSRect(x: 0, y: 0, width: contentView.bounds.width, height: contentView.bounds.height - titleBarHeight) + webView.autoresizingMask = [.width, .height] + webView.setValue(false, forKey: "drawsBackground") + + let html = """ + + + + + + + + + +
+
+
+ + + +
+
+ + + + """ + + let userController = webView.configuration.userContentController + userController.add(self, name: "terminalReady") + userController.add(self, name: "approval") + + let consoleScript = """ + console.log = function(msg) { + window.webkit.messageHandlers.consoleLog.postMessage(String(msg)); + }; + """ + let consoleUserScript = WKUserScript(source: consoleScript, injectionTime: .atDocumentStart, forMainFrameOnly: true) + userController.addUserScript(consoleUserScript) + userController.add(self, name: "consoleLog") + + webView.loadHTMLString(html, baseURL: nil) + contentView.addSubview(webView) + } + + func escapeForJS(_ text: String) -> String { + return text.replacingOccurrences(of: "\\", with: "\\\\") + .replacingOccurrences(of: "'", with: "\\'") + .replacingOccurrences(of: "\n", with: "\\n") + .replacingOccurrences(of: "\r", with: "") + } + + func writeToTerminal(_ text: String) { + guard isTerminalReady else { return } + let escaped = escapeForJS(text) + let js = "window.term.write('\(escaped)');" + webView.evaluateJavaScript(js, completionHandler: nil) + } + + func writeLineToTerminal(_ text: String) { + guard isTerminalReady else { return } + let escaped = escapeForJS(text) + let js = "window.term.writeln('\(escaped)');" + webView.evaluateJavaScript(js, completionHandler: nil) + } + + func formatToolArguments(_ jsonString: String) -> String { + guard let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else { + return jsonString.count > 100 ? String(jsonString.prefix(100)) + "..." : jsonString + } + + var lines: [String] = [] + for (key, value) in json.sorted(by: { $0.key < $1.key }) { + let valueStr: String + if let str = value as? String { + valueStr = str.count > 60 ? String(str.prefix(60)) + "..." : str + } else if let num = value as? 
NSNumber { + valueStr = "\(num)" + } else { + valueStr = "\(value)" + } + lines.append("\(key): \(valueStr)") + } + return lines.joined(separator: "\n") + } + + func addEvent(type: String, description: String, callID: String? = nil, toolName: String? = nil, toolArgs: String? = nil) { + DispatchQueue.main.async { + let esc = "\u{001B}" + let cyan = "\(esc)[36m" + let yellow = "\(esc)[33m" + let magenta = "\(esc)[35m" + let gray = "\(esc)[90m" + let reset = "\(esc)[0m" + + switch type { + case "Chat Start": + self.writeLineToTerminal("") + self.writeLineToTerminal("\(cyan)●\(reset) Starting...") + self.writeLineToTerminal("") + case "Chat Chunk": + self.writeToTerminal(description) + case "Tool Approval": + if let cid = callID, let tool = toolName { + self.showApprovalButtons(callID: cid, toolName: tool) + } + case "Tool Execution": + if let tool = toolName { + let green = "\(esc)[32m" + let blue = "\(esc)[34m" + let bold = "\(esc)[1m" + let dim = "\(esc)[2m" + + self.writeLineToTerminal("") + self.writeLineToTerminal("\(blue)▶\(reset) \(bold)\(tool)\(reset)") + + // Format arguments nicely + if let args = toolArgs, !args.isEmpty && args != "{}" { + let formattedArgs = self.formatToolArguments(args) + for line in formattedArgs.split(separator: "\n") { + self.writeLineToTerminal(" \(dim)\(line)\(reset)") + } + } + } else { + self.writeLineToTerminal("\(gray) \(description)\(reset)") + } + case "Tool Failed", "Tool Rejected": + let red = "\(esc)[31m" + let bold = "\(esc)[1m" + self.writeLineToTerminal("") + self.writeLineToTerminal("\(red)✗ \(bold)\(description)\(reset)") + self.writeLineToTerminal("") + case "Approval Cleared": + self.hideApprovalBox() + case "Cancelled": + let red = "\(esc)[31m" + let bold = "\(esc)[1m" + self.writeLineToTerminal("") + self.writeLineToTerminal("\(red)✗ \(bold)\(description)\(reset)") + self.writeLineToTerminal("") + case "Event": + // Skip generic events to reduce noise + break + case "Optimization": + self.writeLineToTerminal("") + 
self.writeLineToTerminal("\(magenta)⚡ \(description)\(reset)") + default: + self.writeLineToTerminal("\(gray)[\(type)] \(description)\(reset)") + } + } + } + + func hideApprovalBox() { + guard isTerminalReady else { return } + let js = "document.getElementById('approvalBox').classList.remove('visible'); window.currentCallID = null;" + webView.evaluateJavaScript(js, completionHandler: nil) + } + + func showApprovalButtons(callID: String, toolName: String) { + guard isTerminalReady else { + return + } + let escapedCallID = escapeForJS(callID) + let escapedToolName = escapeForJS(toolName) + let js = "window.showApproval('\(escapedCallID)', '\(escapedToolName)');" + webView.evaluateJavaScript(js, completionHandler: nil) + } + + func sendApproval(callID: String, action: Int) { + let response = ApprovalResponse(call_id: callID, action: action) + if let jsonData = try? JSONEncoder().encode(response), + let jsonString = String(data: jsonData, encoding: .utf8) { + print(jsonString) // Send to stdout + fflush(stdout) + } + } +} + +// MARK: - WebKit Message Handler + +extension AgentProgressWindow: WKScriptMessageHandler { + func userContentController(_ userContentController: WKUserContentController, didReceive message: WKScriptMessage) { + if message.name == "terminalReady" { + isTerminalReady = true + fputs("Terminal ready for output\n", stderr) + } else if message.name == "approval", + let data = message.body as? [String: Any], + let callID = data["call_id"] as? String, + let action = data["action"] as? 
Int { + fputs("Received approval from UI: callID=\(callID), action=\(action)\n", stderr) + sendApproval(callID: callID, action: action) + } else if message.name == "consoleLog" { + fputs("JS console: \(message.body)\n", stderr) + } + } +} + +// MARK: - Event Reading + +class EventReader { + let window: AgentProgressWindow + + init(window: AgentProgressWindow) { + self.window = window + } + + func startReading() { + DispatchQueue.global(qos: .userInitiated).async { + let handle = FileHandle.standardInput + + while true { + var lineData = Data() + + while true { + do { + guard let byte = try handle.read(upToCount: 1), !byte.isEmpty else { + return + } + + if byte[0] == 10 { + break + } + lineData.append(byte[0]) + } catch { + fputs("Read error: \(error)\n", stderr) + return + } + } + + if let line = String(data: lineData, encoding: .utf8), !line.isEmpty { + self.handleEvent(line) + } + } + } + } + + func handleEvent(_ jsonString: String) { + guard let data = jsonString.data(using: .utf8) else { + fputs("ERROR: Failed to convert string to data\n", stderr) + return + } + + do { + if let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] { + + let eventType = self.extractEventType(from: json) + let description = self.extractDescription(from: json) + + var callID: String? = nil + var toolName: String? = nil + var toolArgs: String? = nil + + if eventType == "Parallel Tools Start" { + if let tools = json["Tools"] as? [[String: Any]] { + for tool in tools { + let tName = tool["Name"] as? String + let tArgs = tool["Arguments"] as? String + + DispatchQueue.main.async { + self.window.addEvent(type: "Tool Execution", description: "", callID: nil, toolName: tName, toolArgs: tArgs) + } + } + } + } else if eventType == "Chat Complete" { + if let toolCalls = json["ToolCalls"] as? [[String: Any]] { + for toolCall in toolCalls { + if let function = toolCall["function"] as? [String: Any] { + let tName = function["name"] as? 
String + let tArgs = function["arguments"] as? String + + DispatchQueue.main.async { + self.window.addEvent(type: "Tool Execution", description: "", callID: nil, toolName: tName, toolArgs: tArgs) + } + } + } + } + } else if eventType == "Tool Execution" { + toolName = json["ToolName"] as? String + toolArgs = json["Arguments"] as? String + } else if eventType == "Tool Execution Progress" { + if let status = json["Status"] as? String, status == "failed" { + if let tName = json["ToolName"] as? String { + let failureMsg = "Tool: \(tName) failed" + DispatchQueue.main.async { + self.window.addEvent(type: "Tool Failed", description: failureMsg, callID: nil, toolName: nil, toolArgs: nil) + } + } + } + return + } else if eventType == "Tool Approval" { + if let toolCall = json["ToolCall"] as? [String: Any] { + callID = toolCall["id"] as? String + if let function = toolCall["function"] as? [String: Any] { + toolName = function["name"] as? String + } + } + } + + if eventType != "Chat Chunk" { + fputs("Parsed event: type=\(eventType), desc=\(description), callID=\(callID ?? 
"nil")\n", stderr) + } + + // Only call addEvent for events we haven't already handled + if eventType != "Parallel Tools Start" && eventType != "Chat Complete" { + DispatchQueue.main.async { + self.window.addEvent(type: eventType, description: description, callID: callID, toolName: toolName, toolArgs: toolArgs) + } + } + } + } catch { + fputs("ERROR: JSON parse error: \(error)\n", stderr) + } + } + + func extractEventType(from json: [String: Any]) -> String { + + if json["Tools"] != nil { + return "Parallel Tools Start" + } + + if json["Content"] != nil { + return "Chat Chunk" + } + + if json["Message"] != nil && json["IsActive"] != nil { + return "Optimization" + } + + if json["ToolCalls"] != nil { + return "Chat Complete" + } + + if json["ToolCallID"] != nil && json["ToolName"] != nil && json["Status"] != nil { + return "Tool Execution Progress" + } + + if json["ToolName"] != nil && json["Arguments"] != nil { + return "Tool Execution" + } + + if json["ToolCall"] != nil { + return "Tool Approval" + } + + if json["Reason"] != nil { + return "Cancelled" + } + + if json["Model"] != nil { + return "Chat Start" + } + + if json["RequestID"] != nil && json["Timestamp"] != nil && + json["Content"] == nil && json["ToolCall"] == nil && json["Model"] == nil && json["Message"] == nil { + return "Approval Cleared" + } + + if json["RequestID"] != nil { + return "Event" + } + return "Unknown" + } + + func extractDescription(from json: [String: Any]) -> String { + // For Content field (ChatChunkEvent), preserve ALL whitespace including spaces and newlines + if let content = json["Content"] as? String { + return content // Don't trim - spaces and newlines are important! + } + + if let reason = json["Reason"] as? String { + return "Interrupted: \(reason)" + } + + if let message = json["Message"] as? String { + return message + } + + if let model = json["Model"] as? String { + return "Model: \(model)" + } + + if let toolCall = json["ToolCall"] as? 
[String: Any], + let function = toolCall["function"] as? [String: Any], + let toolName = function["name"] as? String { + return "Tool approval: \(toolName)" + } + + if let status = json["Status"] as? String { + return status + } + + if let error = json["Error"] as? String { + return "Error: \(error)" + } + + if let toolName = json["ToolName"] as? String { + if let status = json["Status"] as? String { + return "\(toolName): \(status)" + } + return "Tool: \(toolName)" + } + + if json["RequestID"] != nil { + return "Event received" + } + return "No description" + } +} + +// MARK: - Main + +signal(SIGTERM, SIG_IGN) +let sigTermSource = DispatchSource.makeSignalSource(signal: SIGTERM, queue: .main) +sigTermSource.setEventHandler { + exit(0) +} +sigTermSource.resume() + +NSApplication.shared.setActivationPolicy(.accessory) +NSApplication.shared.activate(ignoringOtherApps: true) + +let window = AgentProgressWindow() + +let reader = EventReader(window: window) +reader.startReading() + +NSApplication.shared.run() +`, position, alwaysOnTop) +} diff --git a/internal/display/macos/manager_stub.go b/internal/display/macos/manager_stub.go new file mode 100644 index 00000000..9359bfc4 --- /dev/null +++ b/internal/display/macos/manager_stub.go @@ -0,0 +1,23 @@ +//go:build !darwin + +package macos + +import ( + config "github.com/inference-gateway/cli/config" + domain "github.com/inference-gateway/cli/internal/domain" +) + +// FloatingWindowManager stub for non-macOS platforms +type FloatingWindowManager struct { + enabled bool +} + +// NewFloatingWindowManager returns a disabled manager on non-macOS platforms +func NewFloatingWindowManager(cfg *config.Config, eventBridge *EventBridge, stateManager domain.StateManager) (*FloatingWindowManager, error) { + return &FloatingWindowManager{enabled: false}, nil +} + +// Shutdown is a no-op on non-macOS platforms +func (mgr *FloatingWindowManager) Shutdown() error { + return nil +} diff --git a/internal/display/macos/overlay_darwin.go 
b/internal/display/macos/overlay_darwin.go
new file mode 100644
index 00000000..8820407d
--- /dev/null
+++ b/internal/display/macos/overlay_darwin.go
@@ -0,0 +1,164 @@
+//go:build darwin
+
+package macos
+
+import (
+	"fmt"
+	"os/exec"
+	"syscall"
+	"time"
+
+	logger "github.com/inference-gateway/cli/internal/logger"
+)
+
+// OverlayWindow represents a macOS screen border overlay drawn by a helper swift process
+type OverlayWindow struct {
+	cmd     *exec.Cmd
+	visible bool
+}
+
+// NewOverlayWindow creates a hidden overlay; no process is started until Show is called
+func NewOverlayWindow() (*OverlayWindow, error) {
+	logger.Info("Creating macOS overlay window")
+	return &OverlayWindow{
+		visible: false,
+	}, nil
+}
+
+// Show displays a screen border overlay using Swift. An overlay that is
+// already running is hidden first so its process is never orphaned.
+func (w *OverlayWindow) Show() error {
+	logger.Info("Attempting to show macOS screen border overlay")
+
+	// Starting a second process while one is tracked would overwrite w.cmd
+	// and leave the first swift process running with no way to stop it.
+	if w.cmd != nil {
+		if err := w.Hide(); err != nil {
+			return fmt.Errorf("failed to replace existing overlay: %w", err)
+		}
+	}
+
+	swiftScript := `
+import Cocoa
+import Foundation
+
+class BorderWindow: NSWindow {
+    init(frame: NSRect, color: NSColor) {
+        super.init(contentRect: frame, styleMask: .borderless, backing: .buffered, defer: false)
+        self.backgroundColor = color
+        self.isOpaque = false
+        self.level = .floating
+        self.ignoresMouseEvents = true
+        self.collectionBehavior = [.canJoinAllSpaces, .stationary]
+        self.orderFront(nil)
+    }
+}
+
+signal(SIGTERM, SIG_IGN)
+let sigTermSource = DispatchSource.makeSignalSource(signal: SIGTERM, queue: .main)
+sigTermSource.setEventHandler {
+    exit(0)
+}
+sigTermSource.resume()
+
+let screen = NSScreen.main!
+let frame = screen.visibleFrame
+let borderWidth: CGFloat = 3
+let borderColor = NSColor(red: 0.3, green: 0.6, blue: 1.0, alpha: 0.95)
+
+_ = BorderWindow(frame: NSRect(x: frame.minX, y: frame.maxY - borderWidth, width: frame.width, height: borderWidth), color: borderColor)
+_ = BorderWindow(frame: NSRect(x: frame.minX, y: frame.minY, width: frame.width, height: borderWidth), color: borderColor)
+_ = BorderWindow(frame: NSRect(x: frame.minX, y: frame.minY, width: borderWidth, height: frame.height), color: borderColor)
+_ = BorderWindow(frame: NSRect(x: frame.maxX - borderWidth, y: frame.minY, width: borderWidth, height: frame.height), color: borderColor)
+
+RunLoop.main.run()
+`
+
+	logger.Debug("Compiling and running Swift screen border overlay")
+
+	cmd := exec.Command("swift", "-")
+	stdin, err := cmd.StdinPipe()
+	if err != nil {
+		logger.Error("Failed to create stdin pipe", "error", err)
+		return fmt.Errorf("failed to create stdin pipe: %w", err)
+	}
+
+	if err := cmd.Start(); err != nil {
+		logger.Error("Failed to start Swift process", "error", err)
+		return fmt.Errorf("failed to start Swift process: %w", err)
+	}
+
+	go func() {
+		defer func() { _ = stdin.Close() }()
+		if _, err := stdin.Write([]byte(swiftScript)); err != nil {
+			logger.Error("Failed to write Swift script to stdin", "error", err)
+		}
+	}()
+
+	w.cmd = cmd
+	w.visible = true
+	logger.Info("Screen border overlay shown successfully")
+	return nil
+}
+
+// Hide hides the screen border overlay by terminating the process gracefully
+func (w *OverlayWindow) Hide() error {
+	logger.Info("Hiding macOS screen border overlay")
+
+	if w.cmd == nil {
+		w.visible = false
+		return nil
+	}
+
+	if w.cmd.Process == nil {
+		w.cmd = nil
+		w.visible = false
+		return nil
+	}
+
+	cmd := w.cmd
+	w.cmd = nil
+	w.visible = false
+
+	if err := cmd.Process.Signal(syscall.SIGTERM); err != nil {
+		logger.Debug("Failed to send SIGTERM, using SIGKILL", "error", err)
+		if err := cmd.Process.Kill(); err != nil {
+			logger.Warn("Failed to kill overlay process", "error", err)
+			return fmt.Errorf("failed to kill overlay process: %w", err)
+		}
+	}
+
+	done := make(chan error, 1)
+	go func() {
+		done <- cmd.Wait()
+	}()
+
+	select {
+	case err := <-done:
+		if err != nil {
+			logger.Debug("Overlay process exited with error", "error", err)
+		} else {
+			logger.Debug("Overlay process exited cleanly")
+		}
+	case <-time.After(5 * time.Second):
+		logger.Warn("Overlay process did not exit within timeout, force killing")
+		if err := cmd.Process.Kill(); err != nil {
+			logger.Error("Failed to force kill overlay process", "error", err)
+			return fmt.Errorf("failed to force kill overlay process: %w", err)
+		}
+		<-done
+	}
+
+	return nil
+}
+
+// IsVisible returns whether the overlay is currently visible
+func (w *OverlayWindow) IsVisible() bool {
+	return w.visible
+}
+
+// Destroy cleans up the screen border overlay by killing the process
+func (w *OverlayWindow) Destroy() error {
+	logger.Info("Destroying macOS screen border overlay")
+	return w.Hide()
+}
diff --git a/internal/display/macos/overlay_darwin_test.go b/internal/display/macos/overlay_darwin_test.go
new file mode 100644
index 00000000..5de72ac2
--- /dev/null
+++ b/internal/display/macos/overlay_darwin_test.go
@@ -0,0 +1,149 @@
+//go:build darwin
+
+package macos
+
+import (
+	"runtime"
+	"testing"
+	"time"
+
+	assert "github.com/stretchr/testify/assert"
+	require "github.com/stretchr/testify/require"
+)
+
+func TestOverlayWindow_Creation(t *testing.T) {
+	if runtime.GOOS != "darwin" {
+		t.Skip("macOS only test")
+	}
+
+	overlay, err := NewOverlayWindow()
+	require.NoError(t, err)
+	require.NotNil(t, overlay)
+	require.False(t, overlay.visible)
+	defer func() {
+		if overlay != nil {
+			_ = overlay.Destroy()
+		}
+	}()
+}
+
+func TestOverlayWindow_ShowHide(t *testing.T) {
+	if runtime.GOOS != "darwin" {
+		t.Skip("macOS only test")
+	}
+
+	overlay, err := NewOverlayWindow()
+	require.NoError(t, err)
+	require.NotNil(t, overlay)
+	defer func() { _ = overlay.Destroy() }()
+
+	err = overlay.Show()
+	require.NoError(t, err)
+	assert.True(t, overlay.visible)
+
+	time.Sleep(100 * time.Millisecond)
+
+	assert.True(t, overlay.IsVisible())
+
+	err = overlay.Hide()
+	require.NoError(t, err)
+	assert.False(t, overlay.visible)
+
+	time.Sleep(100 * time.Millisecond)
+
+	assert.False(t, overlay.IsVisible())
+}
+
+func TestOverlayWindow_Lifecycle(t *testing.T) {
+	if runtime.GOOS != "darwin" {
+		t.Skip("macOS only test")
+	}
+
+	overlay, err := NewOverlayWindow()
+	require.NoError(t, err)
+	require.NotNil(t, overlay)
+
+	err = overlay.Show()
+	require.NoError(t, err)
+	assert.True(t, overlay.visible)
+
+	time.Sleep(100 * time.Millisecond)
+
+	err = overlay.Hide()
+	require.NoError(t, err)
+	assert.False(t, overlay.visible)
+
+	err = overlay.Destroy()
+	require.NoError(t, err)
+	assert.False(t, overlay.visible)
+}
+
+func TestOverlayWindow_DestroyWithoutShow(t *testing.T) {
+	if runtime.GOOS != "darwin" {
+		t.Skip("macOS only test")
+	}
+
+	overlay, err := NewOverlayWindow()
+	require.NoError(t, err)
+	require.NotNil(t, overlay)
+
+	err = overlay.Destroy()
+	require.NoError(t, err)
+}
+
+func TestOverlayWindow_OperationsOnEmptyWindow(t *testing.T) {
+	if runtime.GOOS != "darwin" {
+		t.Skip("macOS only test")
+	}
+
+	overlay := &OverlayWindow{cmd: nil, visible: false}
+
+	err := overlay.Hide()
+	assert.NoError(t, err)
+
+	err = overlay.Destroy()
+	assert.NoError(t, err)
+
+	assert.False(t, overlay.IsVisible())
+}
+
+func TestOverlayWindow_MultipleShowCalls(t *testing.T) {
+	if runtime.GOOS != "darwin" {
+		t.Skip("macOS only test")
+	}
+
+	overlay, err := NewOverlayWindow()
+	require.NoError(t, err)
+	require.NotNil(t, overlay)
+	defer func() { _ = overlay.Destroy() }()
+
+	err = overlay.Show()
+	require.NoError(t, err)
+
+	err = overlay.Show()
+	require.NoError(t, err)
+
+	assert.True(t, overlay.visible)
+}
+
+func TestOverlayWindow_MultipleHideCalls(t *testing.T) {
+	if runtime.GOOS != "darwin" {
+		t.Skip("macOS
only test") + } + + overlay, err := NewOverlayWindow() + require.NoError(t, err) + require.NotNil(t, overlay) + defer func() { _ = overlay.Destroy() }() + + err = overlay.Show() + require.NoError(t, err) + + err = overlay.Hide() + require.NoError(t, err) + + err = overlay.Hide() + require.NoError(t, err) + + assert.False(t, overlay.visible) +} diff --git a/internal/display/macos/overlay_stub.go b/internal/display/macos/overlay_stub.go new file mode 100644 index 00000000..85e281bc --- /dev/null +++ b/internal/display/macos/overlay_stub.go @@ -0,0 +1,33 @@ +//go:build !darwin + +package macos + +import "fmt" + +// OverlayWindow is a stub for non-macOS platforms +type OverlayWindow struct{} + +// NewOverlayWindow returns an error on non-macOS platforms +func NewOverlayWindow() (*OverlayWindow, error) { + return nil, fmt.Errorf("overlay window only supported on macOS") +} + +// Show always returns an error on non-macOS platforms +func (w *OverlayWindow) Show() error { + return fmt.Errorf("overlay window only supported on macOS") +} + +// Hide always returns an error on non-macOS platforms +func (w *OverlayWindow) Hide() error { + return fmt.Errorf("overlay window only supported on macOS") +} + +// IsVisible always returns false on non-macOS platforms +func (w *OverlayWindow) IsVisible() bool { + return false +} + +// Destroy always returns an error on non-macOS platforms +func (w *OverlayWindow) Destroy() error { + return fmt.Errorf("overlay window only supported on macOS") +} diff --git a/internal/display/macos/types.go b/internal/display/macos/types.go new file mode 100644 index 00000000..091463fe --- /dev/null +++ b/internal/display/macos/types.go @@ -0,0 +1,39 @@ +//go:build darwin + +package macos + +import ( + domain "github.com/inference-gateway/cli/internal/domain" +) + +// ApprovalResponse represents a response from the Swift window +// when the user approves or rejects a tool execution +type ApprovalResponse struct { + CallID string `json:"call_id"` + Action domain.ApprovalAction `json:"action"`
+} + +// WindowEvent wraps domain.ChatEvent with additional metadata for the window +type WindowEvent struct { + Type string `json:"type"` + Timestamp int64 `json:"timestamp"` + Data interface{} `json:"data"` +} + +// WindowState represents the current state snapshot of the floating window +// Used for reconnection and state synchronization +type WindowState struct { + SessionID string `json:"session_id"` + IsActive bool `json:"is_active"` + CurrentStatus string `json:"current_status"` + PendingApprovals []PendingApproval `json:"pending_approvals"` + ActivityCount int `json:"activity_count"` +} + +// PendingApproval represents a tool waiting for user approval +type PendingApproval struct { + CallID string `json:"call_id"` + ToolName string `json:"tool_name"` + Arguments map[string]interface{} `json:"arguments"` + Timestamp int64 `json:"timestamp"` +} diff --git a/internal/display/wayland/client.go b/internal/display/wayland/client.go index 2e94eb8f..6353327b 100644 --- a/internal/display/wayland/client.go +++ b/internal/display/wayland/client.go @@ -133,17 +133,13 @@ func (c *WaylandClient) ScrollMouse(clicks int, direction string) error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - // ydotool uses scroll codes: 0x150007 for vertical scroll - // Positive clicks = scroll down/right, negative = scroll up/left var scrollCode string if direction == "horizontal" { - // Horizontal scroll not commonly supported by ydotool return fmt.Errorf("horizontal scrolling not supported on Wayland") } else { scrollCode = "0x150007" } - // Execute scroll command cmd := exec.CommandContext(ctx, "ydotool", "click", scrollCode, "--", strconv.Itoa(clicks)) output, err := cmd.CombinedOutput() if err != nil { diff --git a/internal/domain/chat_events.go b/internal/domain/chat_events.go index cca9effe..9cfe8e43 100644 --- a/internal/domain/chat_events.go +++ b/internal/domain/chat_events.go @@ -4,9 +4,10 @@ import "time" // ToolInfo represents 
basic tool information for UI display type ToolInfo struct { - CallID string - Name string - Status string + CallID string + Name string + Status string + Arguments string } // BaseChatEvent provides common implementation for ChatEvent interface diff --git a/internal/domain/events.go b/internal/domain/events.go index 3ac7346d..534868a0 100644 --- a/internal/domain/events.go +++ b/internal/domain/events.go @@ -210,7 +210,7 @@ type ToolApprovalRequestedEvent struct { RequestID string Timestamp time.Time ToolCall sdk.ChatCompletionMessageToolCall - ResponseChan chan ApprovalAction + ResponseChan chan ApprovalAction `json:"-"` } func (e ToolApprovalRequestedEvent) GetRequestID() string { return e.RequestID } @@ -236,12 +236,21 @@ type ToolRejectedEvent struct { func (e ToolRejectedEvent) GetRequestID() string { return e.RequestID } func (e ToolRejectedEvent) GetTimestamp() time.Time { return e.Timestamp } +// ToolApprovalClearedEvent indicates approval UI should be cleared (approval was processed) +type ToolApprovalClearedEvent struct { + RequestID string + Timestamp time.Time +} + +func (e ToolApprovalClearedEvent) GetRequestID() string { return e.RequestID } +func (e ToolApprovalClearedEvent) GetTimestamp() time.Time { return e.Timestamp } + // PlanApprovalRequestedEvent indicates plan mode completion requires user approval type PlanApprovalRequestedEvent struct { RequestID string Timestamp time.Time PlanContent string - ResponseChan chan PlanApprovalAction + ResponseChan chan PlanApprovalAction `json:"-"` } func (e PlanApprovalRequestedEvent) GetRequestID() string { return e.RequestID } diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 8039d592..4a50f643 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -201,6 +201,16 @@ type ChatEvent interface { GetTimestamp() time.Time } +// EventBridge multicasts chat events to multiple subscribers (e.g., terminal UI and floating window) +type EventBridge interface { 
+ // Tap intercepts an event stream and multicasts it to all subscribers + // Returns a new channel that mirrors the input channel + Tap(input <-chan ChatEvent) <-chan ChatEvent + + // Publish broadcasts an event to all subscribers + Publish(event ChatEvent) +} + // ChatMetrics holds performance and usage metrics type ChatMetrics struct { Duration time.Duration @@ -268,6 +278,10 @@ type StateManager interface { GetChatSession() *ChatSession IsAgentBusy() bool + // Event multicast for floating window + SetEventBridge(bridge EventBridge) + BroadcastEvent(event ChatEvent) + // Tool execution management StartToolExecution(toolCalls []sdk.ChatCompletionMessageToolCall) error CompleteCurrentTool(result *ToolExecutionResult) error diff --git a/internal/domain/state.go b/internal/domain/state.go index 1ecf1a57..e75b192e 100644 --- a/internal/domain/state.go +++ b/internal/domain/state.go @@ -763,7 +763,7 @@ func (s *ApplicationState) ClearFileSelectionState() { // SetupApprovalUIState initializes approval UI state with the pending tool call func (s *ApplicationState) SetupApprovalUIState(toolCall *sdk.ChatCompletionMessageToolCall, responseChan chan ApprovalAction) { s.approvalUIState = &ApprovalUIState{ - SelectedIndex: int(ApprovalApprove), // Default to approve + SelectedIndex: int(ApprovalApprove), PendingToolCall: toolCall, ResponseChan: responseChan, } diff --git a/internal/handlers/chat_event_handler.go b/internal/handlers/chat_event_handler.go index c54cea8f..8975727a 100644 --- a/internal/handlers/chat_event_handler.go +++ b/internal/handlers/chat_event_handler.go @@ -217,6 +217,18 @@ func (e *ChatEventHandler) handleChatComplete( var cmds []tea.Cmd + for _, toolCall := range msg.ToolCalls { + previewEvent := domain.ToolCallPreviewEvent{ + RequestID: msg.RequestID, + Timestamp: msg.Timestamp, + ToolCallID: toolCall.Id, + ToolName: toolCall.Function.Name, + Arguments: toolCall.Function.Arguments, + } + + e.handler.stateManager.BroadcastEvent(previewEvent) + } + 
cmds = append(cmds, func() tea.Msg { return domain.UpdateHistoryEvent{ History: e.handler.conversationRepo.GetMessages(), @@ -445,6 +457,8 @@ func (e *ChatEventHandler) handleToolExecutionProgress( ) tea.Cmd { var cmds []tea.Cmd + // Don't broadcast progress events - tool calls are broadcast from ChatCompleteEvent + switch msg.Status { case "starting": e.activeToolCallID = msg.ToolCallID diff --git a/internal/services/agent.go b/internal/services/agent.go index 63b47eff..9b3030dd 100644 --- a/internal/services/agent.go +++ b/internal/services/agent.go @@ -9,7 +9,6 @@ import ( "time" constants "github.com/inference-gateway/cli/internal/constants" - display "github.com/inference-gateway/cli/internal/display" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" sdk "github.com/inference-gateway/sdk" @@ -107,9 +106,10 @@ func (p *eventPublisher) publishParallelToolsStart(toolCalls []sdk.ChatCompletio tools := make([]domain.ToolInfo, len(toolCalls)) for i, tc := range toolCalls { tools[i] = domain.ToolInfo{ - CallID: tc.Id, - Name: tc.Function.Name, - Status: "queued", + CallID: tc.Id, + Name: tc.Function.Name, + Status: "queued", + Arguments: tc.Function.Arguments, } } @@ -700,6 +700,15 @@ func (s *AgentServiceImpl) CancelRequest(requestID string) error { } } + if s.stateManager != nil { + cancelEvent := domain.CancelledEvent{ + RequestID: requestID, + Timestamp: time.Now(), + Reason: "user cancelled", + } + s.stateManager.BroadcastEvent(cancelEvent) + } + return nil } @@ -775,20 +784,13 @@ func (s *AgentServiceImpl) optimizeConversation(_ context.Context, req *domain.A originalCount := len(conversation) - eventPublisher.publishOptimizationStatus("Optimizing conversation history...", true, originalCount, originalCount) - conversation = s.optimizer.OptimizeMessages(conversation, req.Model, false) optimizedCount := len(conversation) - var message string if originalCount != optimizedCount { - message = 
fmt.Sprintf("Conversation optimized (%d → %d messages)", originalCount, optimizedCount) - } else { - message = "Conversation optimization completed" + eventPublisher.publishOptimizationStatus(fmt.Sprintf("Conversation optimized (%d → %d messages)", originalCount, optimizedCount), false, originalCount, optimizedCount) } - eventPublisher.publishOptimizationStatus(message, false, originalCount, optimizedCount) - return conversation } @@ -1276,13 +1278,6 @@ func (s *AgentServiceImpl) requestToolApproval( tc sdk.ChatCompletionMessageToolCall, eventPublisher *eventPublisher, ) (bool, error) { - var savedAppID string - shouldRestoreFocus := s.shouldRestoreFocusForTool(tc.Function.Name) - - if shouldRestoreFocus { - savedAppID = s.saveFocusAndSwitchToTerminal(ctx) - } - responseChan := make(chan domain.ApprovalAction, 1) eventPublisher.chatEvents <- domain.ToolApprovalRequestedEvent{ @@ -1297,137 +1292,20 @@ func (s *AgentServiceImpl) requestToolApproval( select { case response := <-responseChan: - approved = response == domain.ApprovalApprove + if response == domain.ApprovalAutoAccept { + logger.Info("Switching to auto-accept mode from floating window") + s.stateManager.SetAgentMode(domain.AgentModeAutoAccept) + } + approved = response == domain.ApprovalApprove || response == domain.ApprovalAutoAccept case <-ctx.Done(): err = fmt.Errorf("approval request cancelled: %w", ctx.Err()) case <-time.After(5 * time.Minute): err = fmt.Errorf("approval request timed out") } - if shouldRestoreFocus && savedAppID != "" { - s.restoreFocus(ctx, savedAppID) - time.Sleep(500 * time.Millisecond) - } - return approved, err } -// shouldRestoreFocusForTool determines if focus should be restored for a given tool -func (s *AgentServiceImpl) shouldRestoreFocusForTool(toolName string) bool { - cfg := s.config.GetConfig() - if cfg == nil { - return false - } - - if !cfg.ComputerUse.Enabled || !cfg.ComputerUse.RestoreFocusOnApproval { - return false - } - - // Only restore focus for computer use 
tools - computerUseTools := map[string]bool{ - "MouseMove": true, - "MouseClick": true, - "KeyboardType": true, - } - - return computerUseTools[toolName] -} - -// saveFocusAndSwitchToTerminal saves the currently focused app and switches to approval target -// In terminal mode: switches to terminal -// In web mode: switches to browser running web UI -// Returns the saved app ID for later restoration -func (s *AgentServiceImpl) saveFocusAndSwitchToTerminal(ctx context.Context) string { - displayProvider, err := display.DetectDisplay() - if err != nil { - logger.Debug("Failed to detect display for focus management", "error", err) - return "" - } - - controller, err := displayProvider.GetController() - if err != nil { - logger.Debug("Failed to get display controller for focus management", "error", err) - return "" - } - defer func() { - if err := controller.Close(); err != nil { - logger.Debug("Failed to close display controller", "error", err) - } - }() - - // Check if controller supports focus management - focusManager, ok := controller.(display.FocusManager) - if !ok { - logger.Debug("Display controller does not support focus management") - return "" - } - - // Save currently focused app - savedAppID, err := focusManager.GetFrontmostApp(ctx) - if err != nil { - logger.Debug("Failed to get frontmost app", "error", err) - return "" - } - - cfg := s.config.GetConfig() - if cfg == nil { - return savedAppID - } - - if cfg.Web.Enabled && !cfg.Web.SSH.Enabled { - logger.Debug("Web mode - no focus switch needed for approval", "saved_app", savedAppID) - return savedAppID - } - - // Terminal mode: switch to terminal - if err := focusManager.SwitchToTerminal(ctx); err != nil { - logger.Warn("Failed to switch to terminal for approval", "error", err) - return savedAppID - } - - logger.Debug("Switched to terminal for approval", "saved_app", savedAppID) - return savedAppID -} - -// restoreFocus restores focus to the previously focused application -func (s *AgentServiceImpl) 
restoreFocus(ctx context.Context, appID string) { - if appID == "" { - return - } - - displayProvider, err := display.DetectDisplay() - if err != nil { - logger.Debug("Failed to detect display for focus restoration", "error", err) - return - } - - controller, err := displayProvider.GetController() - if err != nil { - logger.Debug("Failed to get display controller for focus restoration", "error", err) - return - } - defer func() { - if err := controller.Close(); err != nil { - logger.Debug("Failed to close display controller", "error", err) - } - }() - - focusManager, ok := controller.(display.FocusManager) - if !ok { - return - } - - // Small delay to allow approval UI to process before switching away - time.Sleep(200 * time.Millisecond) - - if err := focusManager.ActivateApp(ctx, appID); err != nil { - logger.Debug("Failed to restore focus", "app", appID, "error", err) - return - } - - logger.Debug("Restored focus after approval", "app", appID) -} - // isBashCommandWhitelisted checks if a Bash tool command is whitelisted func (s *AgentServiceImpl) isBashCommandWhitelisted(tc *sdk.ChatCompletionMessageToolCall) bool { var args map[string]any diff --git a/internal/services/screenshot_server.go b/internal/services/screenshot_server.go index 23e67a0f..b84b0e15 100644 --- a/internal/services/screenshot_server.go +++ b/internal/services/screenshot_server.go @@ -7,32 +7,34 @@ import ( "net" "net/http" "path/filepath" + "runtime" "strconv" "sync" "time" config "github.com/inference-gateway/cli/config" display "github.com/inference-gateway/cli/internal/display" + "github.com/inference-gateway/cli/internal/display/macos" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" - _ "github.com/inference-gateway/cli/internal/display/macos" _ "github.com/inference-gateway/cli/internal/display/wayland" _ "github.com/inference-gateway/cli/internal/display/x11" ) // ScreenshotServer provides an HTTP API for screenshot 
streaming type ScreenshotServer struct { - cfg *config.Config - port int - server *http.Server - buffer *CircularScreenshotBuffer - captureCtx context.Context - captureStop context.CancelFunc - mu sync.RWMutex - sessionID string - imageSvc domain.ImageService - running bool + cfg *config.Config + port int + server *http.Server + buffer *CircularScreenshotBuffer + captureCtx context.Context + captureStop context.CancelFunc + mu sync.RWMutex + sessionID string + imageSvc domain.ImageService + running bool + overlayWindow *macos.OverlayWindow } // NewScreenshotServer creates a new screenshot server @@ -102,6 +104,8 @@ func (s *ScreenshotServer) Start() error { s.running = true + s.showOverlayIfEnabled() + interval := s.cfg.ComputerUse.Screenshot.CaptureInterval if interval <= 0 { interval = 3 @@ -144,6 +148,17 @@ func (s *ScreenshotServer) Stop() error { } } + if s.overlayWindow != nil { + if err := s.overlayWindow.Hide(); err != nil { + logger.Warn("Failed to hide overlay window", "error", err) + } + if err := s.overlayWindow.Destroy(); err != nil { + logger.Warn("Failed to destroy overlay window", "error", err) + } + s.overlayWindow = nil + logger.Info("Screenshot overlay window destroyed") + } + s.running = false return nil @@ -156,6 +171,31 @@ func (s *ScreenshotServer) Port() int { return s.port } +// showOverlayIfEnabled shows the overlay window if configured +func (s *ScreenshotServer) showOverlayIfEnabled() { + if !s.cfg.ComputerUse.Screenshot.ShowOverlay { + return + } + + if runtime.GOOS != "darwin" { + return + } + + overlay, err := macos.NewOverlayWindow() + if err != nil { + logger.Warn("Failed to create overlay window", "error", err) + return + } + + s.overlayWindow = overlay + if err := s.overlayWindow.Show(); err != nil { + logger.Warn("Failed to show overlay window", "error", err) + return + } + + logger.Info("Screenshot overlay window shown") +} + // startCaptureLoop runs the background screenshot capture loop func (s *ScreenshotServer) 
startCaptureLoop() { interval := s.cfg.ComputerUse.Screenshot.CaptureInterval diff --git a/internal/services/state_manager.go b/internal/services/state_manager.go index 0b28ed72..8c2dd73c 100644 --- a/internal/services/state_manager.go +++ b/internal/services/state_manager.go @@ -19,6 +19,9 @@ type StateManager struct { // State change listeners listeners []StateChangeListener + // Event multicast for floating window (optional) + eventBridge domain.EventBridge + // Debug and audit trail debugMode bool stateHistory []domain.StateSnapshot @@ -187,6 +190,13 @@ func (sm *StateManager) SetChatPending() { } } +// SetEventBridge sets the event bridge for multicasting events to floating window +func (sm *StateManager) SetEventBridge(bridge domain.EventBridge) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.eventBridge = bridge +} + // StartChatSession starts a new chat session func (sm *StateManager) StartChatSession(requestID, model string, eventChan <-chan domain.ChatEvent) error { sm.mutex.Lock() @@ -194,6 +204,10 @@ func (sm *StateManager) StartChatSession(requestID, model string, eventChan <-ch oldState := sm.state.GetStateSnapshot() + if sm.eventBridge != nil { + eventChan = sm.eventBridge.Tap(eventChan) + } + sm.state.StartChatSession(requestID, model, eventChan) sm.captureStateChange(StateChangeTypeChatStatus, oldState) @@ -505,6 +519,24 @@ func (sm *StateManager) ClearApprovalUIState() { defer sm.mutex.Unlock() sm.state.ClearApprovalUIState() + + if sm.eventBridge != nil { + requestID := "" + if chatSession := sm.state.GetChatSession(); chatSession != nil { + requestID = chatSession.RequestID + } + sm.eventBridge.Publish(domain.ToolApprovalClearedEvent{ + RequestID: requestID, + Timestamp: time.Now(), + }) + } +} + +// BroadcastEvent publishes an event to the EventBridge for floating window +func (sm *StateManager) BroadcastEvent(event domain.ChatEvent) { + if sm.eventBridge != nil { + sm.eventBridge.Publish(event) + } } // Plan approval state methods diff 
--git a/internal/services/tools/get_focused_app.go b/internal/services/tools/get_focused_app.go index bbe152e1..c38493a9 100644 --- a/internal/services/tools/get_focused_app.go +++ b/internal/services/tools/get_focused_app.go @@ -74,7 +74,6 @@ func (t *GetFocusedAppTool) Execute(ctx context.Context, args map[string]any) (* return nil, fmt.Errorf("no application is currently focused") } - // Parse app name from bundle ID (e.g., "org.mozilla.firefox" -> "Firefox") appName := parseAppName(appID) result := fmt.Sprintf("Currently focused application:\n- Name: %s\n- Bundle ID: %s", appName, appID) @@ -149,7 +148,6 @@ func (t *GetFocusedAppTool) FormatResult(result *domain.ToolExecutionResult, for // parseAppName extracts a human-readable app name from bundle ID func parseAppName(bundleID string) string { - // Common mappings appNames := map[string]string{ "com.apple.Terminal": "Terminal", "com.googlecode.iterm2": "iTerm2", @@ -183,6 +181,5 @@ func parseAppName(bundleID string) string { return name } - // Fallback: return bundle ID return bundleID } diff --git a/tests/mocks/domain/fake_state_manager.go b/tests/mocks/domain/fake_state_manager.go index 9a78bd3a..e2748439 100644 --- a/tests/mocks/domain/fake_state_manager.go +++ b/tests/mocks/domain/fake_state_manager.go @@ -19,6 +19,11 @@ type FakeStateManager struct { areAllAgentsReadyReturnsOnCall map[int]struct { result1 bool } + BroadcastEventStub func(domain.ChatEvent) + broadcastEventMutex sync.RWMutex + broadcastEventArgsForCall []struct { + arg1 domain.ChatEvent + } ClearAgentReadinessStub func() clearAgentReadinessMutex sync.RWMutex clearAgentReadinessArgsForCall []struct { @@ -257,6 +262,11 @@ type FakeStateManager struct { arg1 int arg2 int } + SetEventBridgeStub func(domain.EventBridge) + setEventBridgeMutex sync.RWMutex + setEventBridgeArgsForCall []struct { + arg1 domain.EventBridge + } SetFileSelectedIndexStub func(int) setFileSelectedIndexMutex sync.RWMutex setFileSelectedIndexArgsForCall []struct { @@ 
-411,6 +421,38 @@ func (fake *FakeStateManager) AreAllAgentsReadyReturnsOnCall(i int, result1 bool }{result1} } +func (fake *FakeStateManager) BroadcastEvent(arg1 domain.ChatEvent) { + fake.broadcastEventMutex.Lock() + fake.broadcastEventArgsForCall = append(fake.broadcastEventArgsForCall, struct { + arg1 domain.ChatEvent + }{arg1}) + stub := fake.BroadcastEventStub + fake.recordInvocation("BroadcastEvent", []interface{}{arg1}) + fake.broadcastEventMutex.Unlock() + if stub != nil { + fake.BroadcastEventStub(arg1) + } +} + +func (fake *FakeStateManager) BroadcastEventCallCount() int { + fake.broadcastEventMutex.RLock() + defer fake.broadcastEventMutex.RUnlock() + return len(fake.broadcastEventArgsForCall) +} + +func (fake *FakeStateManager) BroadcastEventCalls(stub func(domain.ChatEvent)) { + fake.broadcastEventMutex.Lock() + defer fake.broadcastEventMutex.Unlock() + fake.BroadcastEventStub = stub +} + +func (fake *FakeStateManager) BroadcastEventArgsForCall(i int) domain.ChatEvent { + fake.broadcastEventMutex.RLock() + defer fake.broadcastEventMutex.RUnlock() + argsForCall := fake.broadcastEventArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeStateManager) ClearAgentReadiness() { fake.clearAgentReadinessMutex.Lock() fake.clearAgentReadinessArgsForCall = append(fake.clearAgentReadinessArgsForCall, struct { @@ -1717,6 +1759,38 @@ func (fake *FakeStateManager) SetDimensionsArgsForCall(i int) (int, int) { return argsForCall.arg1, argsForCall.arg2 } +func (fake *FakeStateManager) SetEventBridge(arg1 domain.EventBridge) { + fake.setEventBridgeMutex.Lock() + fake.setEventBridgeArgsForCall = append(fake.setEventBridgeArgsForCall, struct { + arg1 domain.EventBridge + }{arg1}) + stub := fake.SetEventBridgeStub + fake.recordInvocation("SetEventBridge", []interface{}{arg1}) + fake.setEventBridgeMutex.Unlock() + if stub != nil { + fake.SetEventBridgeStub(arg1) + } +} + +func (fake *FakeStateManager) SetEventBridgeCallCount() int { + fake.setEventBridgeMutex.RLock() 
+ defer fake.setEventBridgeMutex.RUnlock() + return len(fake.setEventBridgeArgsForCall) +} + +func (fake *FakeStateManager) SetEventBridgeCalls(stub func(domain.EventBridge)) { + fake.setEventBridgeMutex.Lock() + defer fake.setEventBridgeMutex.Unlock() + fake.SetEventBridgeStub = stub +} + +func (fake *FakeStateManager) SetEventBridgeArgsForCall(i int) domain.EventBridge { + fake.setEventBridgeMutex.RLock() + defer fake.setEventBridgeMutex.RUnlock() + argsForCall := fake.setEventBridgeArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeStateManager) SetFileSelectedIndex(arg1 int) { fake.setFileSelectedIndexMutex.Lock() fake.setFileSelectedIndexArgsForCall = append(fake.setFileSelectedIndexArgsForCall, struct { From 75610a0f4a51c910a4b9ffe2ff75f215496fc870 Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Sun, 4 Jan 2026 22:17:54 +0200 Subject: [PATCH 09/14] refactor: Replace interface{} with any type --- internal/display/macos/event_bridge.go | 2 +- internal/display/macos/types.go | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/internal/display/macos/event_bridge.go b/internal/display/macos/event_bridge.go index 9309767c..0f4576d7 100644 --- a/internal/display/macos/event_bridge.go +++ b/internal/display/macos/event_bridge.go @@ -70,7 +70,7 @@ func (eb *EventBridge) Subscribe() chan domain.ChatEvent { eb.subscribers = append(eb.subscribers, sub) - eb.eventBuffer.Do(func(val interface{}) { + eb.eventBuffer.Do(func(val any) { if val != nil { event, ok := val.(domain.ChatEvent) if ok { diff --git a/internal/display/macos/types.go b/internal/display/macos/types.go index 091463fe..ee822593 100644 --- a/internal/display/macos/types.go +++ b/internal/display/macos/types.go @@ -15,9 +15,9 @@ type ApprovalResponse struct { // WindowEvent wraps domain.ChatEvent with additional metadata for the window type WindowEvent struct { - Type string `json:"type"` - Timestamp int64 `json:"timestamp"` - Data interface{} `json:"data"` + Type 
string `json:"type"` + Timestamp int64 `json:"timestamp"` + Data any `json:"data"` } // WindowState represents the current state snapshot of the floating window @@ -32,8 +32,8 @@ type WindowState struct { // PendingApproval represents a tool waiting for user approval type PendingApproval struct { - CallID string `json:"call_id"` - ToolName string `json:"tool_name"` - Arguments map[string]interface{} `json:"arguments"` - Timestamp int64 `json:"timestamp"` + CallID string `json:"call_id"` + ToolName string `json:"tool_name"` + Arguments map[string]any `json:"arguments"` + Timestamp int64 `json:"timestamp"` } From b10a45fb558b29e0ff2b72d703d536090067df6f Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Thu, 8 Jan 2026 01:02:56 +0200 Subject: [PATCH 10/14] feat: Add native macOS floating window for computer use --- .github/workflows/artifacts.yml | 7 + .infer/config.yaml | 6 +- Taskfile.yml | 4 + cmd/export.go | 2 +- cmd/root.go | 2 + config/config.go | 8 +- internal/container/container.go | 2 +- internal/display/macos/ComputerUse/.gitignore | 2 + internal/display/macos/ComputerUse/Info.plist | 30 + internal/display/macos/ComputerUse/README.md | 49 + internal/display/macos/ComputerUse/build.sh | 41 + internal/display/macos/ComputerUse/main.swift | 811 ++++++++++++++++ internal/display/macos/client_darwin.go | 66 +- internal/display/macos/client_darwin_test.go | 207 ++++ internal/display/macos/manager.go | 889 +++--------------- internal/display/macos/overlay_darwin.go | 155 --- internal/display/macos/overlay_darwin_test.go | 149 --- internal/display/macos/overlay_stub.go | 33 - internal/display/x11/client.go | 2 - internal/domain/chat_events.go | 24 + internal/domain/context.go | 6 + internal/domain/interfaces.go | 24 +- internal/domain/state.go | 42 + internal/handlers/chat_command_handler.go | 2 + .../services/circular_screenshot_buffer.go | 14 +- internal/services/mcp_manager.go | 2 +- internal/services/screenshot_server.go | 201 ++-- 
internal/services/state_manager.go | 44 + internal/services/tools/coordinate_scaler.go | 54 ++ internal/services/tools/keyboard_type.go | 51 +- internal/services/tools/keyboard_type_test.go | 8 +- internal/services/tools/mouse_click.go | 208 ++-- internal/services/tools/mouse_move.go | 80 +- internal/services/tools/mouse_move_test.go | 215 +++++ internal/services/tools/registry.go | 56 +- internal/services/tools/registry_test.go | 16 +- internal/shortcuts/custom.go | 1 + .../mocks/display/fake_display_controller.go | 802 ++++++++++++++++ tests/mocks/display/fake_provider.go | 231 +++++ tests/mocks/domain/fake_rate_limiter.go | 200 ++++ tests/mocks/domain/fake_state_manager.go | 263 ++++++ 41 files changed, 3708 insertions(+), 1301 deletions(-) create mode 100644 internal/display/macos/ComputerUse/.gitignore create mode 100644 internal/display/macos/ComputerUse/Info.plist create mode 100644 internal/display/macos/ComputerUse/README.md create mode 100755 internal/display/macos/ComputerUse/build.sh create mode 100644 internal/display/macos/ComputerUse/main.swift create mode 100644 internal/display/macos/client_darwin_test.go delete mode 100644 internal/display/macos/overlay_darwin.go delete mode 100644 internal/display/macos/overlay_darwin_test.go delete mode 100644 internal/display/macos/overlay_stub.go create mode 100644 internal/services/tools/coordinate_scaler.go create mode 100644 internal/services/tools/mouse_move_test.go create mode 100644 tests/mocks/display/fake_display_controller.go create mode 100644 tests/mocks/display/fake_provider.go create mode 100644 tests/mocks/domain/fake_rate_limiter.go diff --git a/.github/workflows/artifacts.yml b/.github/workflows/artifacts.yml index 8e64b331..9b60afbd 100644 --- a/.github/workflows/artifacts.yml +++ b/.github/workflows/artifacts.yml @@ -66,6 +66,13 @@ jobs: golang:1.25-alpine3.23 \ sh -c "go build -ldflags '-w -s -X github.com/inference-gateway/cli/cmd.version=${{ steps.version.outputs.version }} -X 
github.com/inference-gateway/cli/cmd.commit=${{ steps.version.outputs.commit }} -X github.com/inference-gateway/cli/cmd.date=${{ steps.version.outputs.date }}' -o infer-${{ matrix.goos }}-${{ matrix.goarch }} ." + - name: Computer Use App (macOS only) + if: matrix.goos == 'darwin' + run: | + cd internal/display/macos/ComputerUse + ./build.sh + cd ../../../.. + - name: Build binary (macOS with CGO for clipboard image support) if: matrix.goos == 'darwin' env: diff --git a/.infer/config.yaml b/.infer/config.yaml index 1ba02935..7f1a1c77 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -710,8 +710,10 @@ computer_use: enabled: true max_width: 1920 max_height: 1080 + target_width: 1024 + target_height: 768 format: jpeg - quality: 80 + quality: 85 require_approval: false streaming_enabled: true capture_interval: 3 @@ -731,7 +733,7 @@ computer_use: keyboard_type: enabled: true max_text_length: 1000 - typing_delay_ms: 200 + typing_delay_ms: 100 require_approval: true get_focused_app: enabled: true diff --git a/Taskfile.yml b/Taskfile.yml index 85033627..f533ba0c 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -269,7 +269,11 @@ tasks: - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain TaskTracker - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain A2AAgentService - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain MCPClient + - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/domain RateLimiter - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/domain internal/infra/storage ConversationStorage + - mkdir -p tests/mocks/display + - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/display internal/display DisplayController + - go run github.com/maxbrunsfeld/counterfeiter/v6 -o tests/mocks/display internal/display Provider - mkdir -p tests/mocks/services - go run github.com/maxbrunsfeld/counterfeiter/v6 
-o tests/mocks/services internal/services TitleGenerator - mkdir -p tests/mocks/shortcuts diff --git a/cmd/export.go b/cmd/export.go index 0f066c01..7200affa 100644 --- a/cmd/export.go +++ b/cmd/export.go @@ -47,7 +47,7 @@ func runExport(sessionID string) error { } configService := services.NewConfigService(V, cfg) - toolRegistry := tools.NewRegistry(configService, nil, nil, nil) + toolRegistry := tools.NewRegistry(configService, nil, nil, nil, nil, nil) toolFormatterService := services.NewToolFormatterService(toolRegistry) pricingService := services.NewPricingService(&cfg.Pricing) persistentRepo := services.NewPersistentConversationRepository(toolFormatterService, pricingService, storageBackend) diff --git a/cmd/root.go b/cmd/root.go index 09093f69..f209f100 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -100,6 +100,8 @@ func initConfig() { // nolint:funlen v.SetDefault("computer_use.screenshot.enabled", defaults.ComputerUse.Screenshot.Enabled) v.SetDefault("computer_use.screenshot.max_width", defaults.ComputerUse.Screenshot.MaxWidth) v.SetDefault("computer_use.screenshot.max_height", defaults.ComputerUse.Screenshot.MaxHeight) + v.SetDefault("computer_use.screenshot.target_width", defaults.ComputerUse.Screenshot.TargetWidth) + v.SetDefault("computer_use.screenshot.target_height", defaults.ComputerUse.Screenshot.TargetHeight) v.SetDefault("computer_use.screenshot.format", defaults.ComputerUse.Screenshot.Format) v.SetDefault("computer_use.screenshot.quality", defaults.ComputerUse.Screenshot.Quality) v.SetDefault("computer_use.screenshot.streaming_enabled", defaults.ComputerUse.Screenshot.StreamingEnabled) diff --git a/config/config.go b/config/config.go index 52bb17c8..e9cfd5f3 100644 --- a/config/config.go +++ b/config/config.go @@ -267,6 +267,8 @@ type ScreenshotToolConfig struct { Enabled bool `yaml:"enabled" mapstructure:"enabled"` MaxWidth int `yaml:"max_width" mapstructure:"max_width"` MaxHeight int `yaml:"max_height" mapstructure:"max_height"` + TargetWidth 
int `yaml:"target_width" mapstructure:"target_width"` + TargetHeight int `yaml:"target_height" mapstructure:"target_height"` Format string `yaml:"format" mapstructure:"format"` Quality int `yaml:"quality" mapstructure:"quality"` RequireApproval *bool `yaml:"require_approval,omitempty" mapstructure:"require_approval,omitempty"` @@ -1134,8 +1136,10 @@ Write the AGENTS.md file to the project root when you have gathered enough infor Enabled: true, MaxWidth: 1920, MaxHeight: 1080, + TargetWidth: 1024, + TargetHeight: 768, Format: "jpeg", - Quality: 80, + Quality: 85, RequireApproval: &[]bool{false}[0], StreamingEnabled: true, CaptureInterval: 3, @@ -1159,7 +1163,7 @@ Write the AGENTS.md file to the project root when you have gathered enough infor KeyboardType: KeyboardTypeToolConfig{ Enabled: true, MaxTextLength: 1000, - TypingDelayMs: 200, + TypingDelayMs: 100, RequireApproval: &[]bool{true}[0], }, GetFocusedApp: GetFocusedAppToolConfig{ diff --git a/internal/container/container.go b/internal/container/container.go index 68d85015..de9a9b8c 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -218,7 +218,7 @@ func (c *ServiceContainer) initializeDomainServices() { c.initializeMCPManager() - c.toolRegistry = tools.NewRegistry(c.configService, c.imageService, c.mcpManager, c.BackgroundShellService()) + c.toolRegistry = tools.NewRegistry(c.configService, c.imageService, c.mcpManager, c.BackgroundShellService(), c.stateManager, nil) c.taskTrackerService = c.toolRegistry.GetTaskTracker() toolFormatterService := services.NewToolFormatterService(c.toolRegistry) diff --git a/internal/display/macos/ComputerUse/.gitignore b/internal/display/macos/ComputerUse/.gitignore new file mode 100644 index 00000000..1e90d0b2 --- /dev/null +++ b/internal/display/macos/ComputerUse/.gitignore @@ -0,0 +1,2 @@ +# Build artifacts (built in CI/CD, not committed to SCM) +build/ diff --git a/internal/display/macos/ComputerUse/Info.plist 
b/internal/display/macos/ComputerUse/Info.plist new file mode 100644 index 00000000..0473a3cc --- /dev/null +++ b/internal/display/macos/ComputerUse/Info.plist @@ -0,0 +1,30 @@ + + + + + CFBundleDevelopmentRegion + en + CFBundleExecutable + FloatingWindow + CFBundleIdentifier + com.inference-gateway.floatingwindow + CFBundleInfoDictionaryVersion + 6.0 + CFBundleName + FloatingWindow + CFBundlePackageType + APPL + CFBundleShortVersionString + 1.0 + CFBundleVersion + 1 + LSMinimumSystemVersion + 10.15.4 + LSUIElement + + NSHighResolutionCapable + + NSSupportsAutomaticGraphicsSwitching + + + diff --git a/internal/display/macos/ComputerUse/README.md b/internal/display/macos/ComputerUse/README.md new file mode 100644 index 00000000..54903e3e --- /dev/null +++ b/internal/display/macos/ComputerUse/README.md @@ -0,0 +1,49 @@ +# FloatingWindow.app + +Native macOS floating window for Computer Use tool visualization. + +## Building + +The FloatingWindow.app must be built **before** building the Go binary, as it's embedded using `go:embed`. + +### Local Development + +```bash +cd internal/display/macos/ComputerUse +./build.sh +``` + +This creates `build/ComputerUse.app/` which is embedded in the Go binary. + +### CI/CD + +Add this step to your build pipeline before `go build`: + +```bash +# Build FloatingWindow.app (macOS only) +if [ "$(uname)" = "Darwin" ]; then + cd internal/display/macos/ComputerUse + ./build.sh + cd - +fi + +# Build Go binary (embeds FloatingWindow.app) +go build -o infer . 
+``` + +## Requirements + +- macOS 10.15.4+ +- Swift compiler (included with Xcode Command Line Tools) +- `swiftc` available in PATH + +## Architecture + +- **Source**: `main.swift` - Native NSTextView-based window +- **Config**: `Info.plist` - App metadata +- **Build**: `build.sh` - Compiles to standalone .app bundle +- **Output**: `build/ComputerUse.app/` - Embedded in Go binary via `go:embed` + +## Runtime Behavior + +On first run, the embedded .app is extracted to `~/.infer/FloatingWindow.app`. Subsequent runs reuse this extracted copy. diff --git a/internal/display/macos/ComputerUse/build.sh b/internal/display/macos/ComputerUse/build.sh new file mode 100755 index 00000000..9c5f579a --- /dev/null +++ b/internal/display/macos/ComputerUse/build.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +APP_NAME="ComputerUse" +BUILD_DIR="${SCRIPT_DIR}/build" +APP_BUNDLE="${BUILD_DIR}/${APP_NAME}.app" +CONTENTS_DIR="${APP_BUNDLE}/Contents" +MACOS_DIR="${CONTENTS_DIR}/MacOS" +RESOURCES_DIR="${CONTENTS_DIR}/Resources" + +echo "Building ${APP_NAME}.app..." + +rm -rf "${BUILD_DIR}" + +mkdir -p "${MACOS_DIR}" +mkdir -p "${RESOURCES_DIR}" + +ARCH=$(uname -m) +if [ "$ARCH" = "arm64" ]; then + TARGET="arm64-apple-macos11.0" +else + TARGET="x86_64-apple-macos10.15.4" +fi + +echo "Compiling Swift source for ${ARCH} (${TARGET})..." +swiftc -O \ + -sdk "$(xcrun --show-sdk-path)" \ + -target "${TARGET}" \ + -o "${MACOS_DIR}/${APP_NAME}" \ + "${SCRIPT_DIR}/main.swift" + +echo "Copying Info.plist..." +cp "${SCRIPT_DIR}/Info.plist" "${CONTENTS_DIR}/Info.plist" + +echo "Signing app..." 
+codesign --force --deep --sign - "${APP_BUNDLE}" + +echo "Build complete: ${APP_BUNDLE}" +echo "App size: $(du -sh "${APP_BUNDLE}" | cut -f1)" diff --git a/internal/display/macos/ComputerUse/main.swift b/internal/display/macos/ComputerUse/main.swift new file mode 100644 index 00000000..21537b9f --- /dev/null +++ b/internal/display/macos/ComputerUse/main.swift @@ -0,0 +1,811 @@ +import Cocoa +import Foundation + +// MARK: - Models + +struct ApprovalResponse: Codable { + let call_id: String + let action: Int // 0=Approve, 1=Reject, 2=AutoAccept +} + +// MARK: - Border Overlay Windows + +class BorderWindow: NSWindow { + init(frame: NSRect, color: NSColor) { + super.init(contentRect: frame, styleMask: .borderless, backing: .buffered, defer: false) + self.backgroundColor = color + self.isOpaque = false + self.level = .screenSaver + self.ignoresMouseEvents = true + self.collectionBehavior = [.canJoinAllSpaces, .stationary] + self.orderFront(nil) + } +} + +// MARK: - Click Indicator Window + +private var activeClickIndicators: [ClickIndicatorWindow] = [] + +class ClickIndicatorWindow: NSWindow { + private var closeTimer: Timer? 
+ + init(x: CGFloat, y: CGFloat) { + let size: CGFloat = 40 + let frame = NSRect(x: x - size/2, y: y - size/2, width: size, height: size) + + super.init(contentRect: frame, styleMask: .borderless, backing: .buffered, defer: false) + + self.isReleasedWhenClosed = false + + self.backgroundColor = .clear + self.isOpaque = false + self.level = .screenSaver + self.ignoresMouseEvents = true + self.collectionBehavior = [.canJoinAllSpaces, .stationary] + + let indicatorView = ClickIndicatorView(frame: NSRect(x: 0, y: 0, width: size, height: size)) + self.contentView = indicatorView + + self.orderFront(nil) + + activeClickIndicators.append(self) + + self.alphaValue = 0 + self.alphaValue = 1.0 + + self.closeTimer = Timer.scheduledTimer(withTimeInterval: 0.6, repeats: false) { [weak self] timer in + guard let self = self else { return } + timer.invalidate() + + if let index = activeClickIndicators.firstIndex(where: { $0 === self }) { + activeClickIndicators.remove(at: index) + } + + self.orderOut(nil) + } + } + + deinit { + closeTimer?.invalidate() + } +} + +class ClickIndicatorView: NSView { + override func draw(_ dirtyRect: NSRect) { + super.draw(dirtyRect) + + guard let context = NSGraphicsContext.current?.cgContext else { return } + + let outerColor = NSColor(red: 0.48, green: 0.64, blue: 0.97, alpha: 0.8) + context.setStrokeColor(outerColor.cgColor) + context.setLineWidth(3) + let outerRect = bounds.insetBy(dx: 2, dy: 2) + context.strokeEllipse(in: outerRect) + + let innerColor = NSColor(red: 0.48, green: 0.64, blue: 0.97, alpha: 0.3) + context.setFillColor(innerColor.cgColor) + let innerRect = bounds.insetBy(dx: 12, dy: 12) + context.fillEllipse(in: innerRect) + } +} + +// MARK: - Move Trail Indicator Window + +private var activeMoveIndicators: [MoveTrailWindow] = [] + +class MoveTrailWindow: NSWindow { + init(fromX: CGFloat, fromY: CGFloat, toX: CGFloat, toY: CGFloat) { + let screen = NSScreen.main! 
+ let screenFrame = screen.frame + + let fromPoint = NSPoint(x: fromX, y: screenFrame.height - fromY) + let toPoint = NSPoint(x: toX, y: screenFrame.height - toY) + + super.init( + contentRect: screenFrame, + styleMask: .borderless, + backing: .buffered, + defer: false + ) + + self.isOpaque = false + self.backgroundColor = .clear + self.level = .screenSaver + self.ignoresMouseEvents = true + self.collectionBehavior = [.canJoinAllSpaces, .stationary] + + let trailView = MoveTrailView(from: fromPoint, to: toPoint) + self.contentView = trailView + + self.makeKeyAndOrderFront(nil) + + activeMoveIndicators.append(self) + + Timer.scheduledTimer(withTimeInterval: 0.8, repeats: false) { [weak self] timer in + timer.invalidate() + guard let self = self else { return } + + if let index = activeMoveIndicators.firstIndex(where: { $0 === self }) { + activeMoveIndicators.remove(at: index) + } + + self.orderOut(nil) + } + } +} + +class MoveTrailView: NSView { + let fromPoint: NSPoint + let toPoint: NSPoint + var animationProgress: CGFloat = 0.0 + var displayLink: CVDisplayLink? 
+ + init(from: NSPoint, to: NSPoint) { + self.fromPoint = from + self.toPoint = to + super.init(frame: .zero) + + animateTrail() + } + + required init?(coder: NSCoder) { + fatalError("init(coder:) not implemented") + } + + func animateTrail() { + Timer.scheduledTimer(withTimeInterval: 0.016, repeats: true) { [weak self] timer in + guard let self = self else { + timer.invalidate() + return + } + + self.animationProgress += 0.03 + if self.animationProgress >= 1.0 { + self.animationProgress = 1.0 + timer.invalidate() + } + self.needsDisplay = true + } + } + + override func draw(_ dirtyRect: NSRect) { + super.draw(dirtyRect) + + guard let context = NSGraphicsContext.current?.cgContext else { return } + + let path = NSBezierPath() + path.move(to: fromPoint) + + let currentX = fromPoint.x + (toPoint.x - fromPoint.x) * animationProgress + let currentY = fromPoint.y + (toPoint.y - fromPoint.y) * animationProgress + + path.line(to: NSPoint(x: currentX, y: currentY)) + + let trailColor = NSColor(red: 0.3, green: 0.9, blue: 0.7, alpha: 1.0 - animationProgress * 0.5) + trailColor.setStroke() + path.lineWidth = 3.0 + path.stroke() + + if animationProgress > 0.1 { + let arrowSize: CGFloat = 10.0 + + let dx = toPoint.x - fromPoint.x + let dy = toPoint.y - fromPoint.y + let angle = atan2(dy, dx) + + context.saveGState() + context.translateBy(x: currentX, y: currentY) + context.rotate(by: angle) + + let arrowPath = NSBezierPath() + arrowPath.move(to: NSPoint(x: 0, y: 0)) + arrowPath.line(to: NSPoint(x: -arrowSize, y: arrowSize/2)) + arrowPath.line(to: NSPoint(x: -arrowSize, y: -arrowSize/2)) + arrowPath.close() + + trailColor.setFill() + arrowPath.fill() + + context.restoreGState() + } + } +} + +// MARK: - Clickable View for Minimized State + +class ClickableView: NSView { + var onClicked: (() -> Void)? + var dragStartLocation: NSPoint? 
+ var hasDragged = false + + override func mouseDown(with event: NSEvent) { + dragStartLocation = event.locationInWindow + hasDragged = false + } + + override func mouseDragged(with event: NSEvent) { + guard let window = self.window, + let startLocation = dragStartLocation else { return } + + let currentLocation = event.locationInWindow + let deltaY = currentLocation.y - startLocation.y + + if abs(deltaY) > 3 { + hasDragged = true + } + + guard let screen = NSScreen.main else { return } + let screenFrame = screen.visibleFrame + + var newOrigin = window.frame.origin + newOrigin.y += deltaY + newOrigin.x = screenFrame.maxX - window.frame.width + + window.setFrameOrigin(newOrigin) + } + + override func mouseUp(with event: NSEvent) { + if !hasDragged { + onClicked?() + } + + dragStartLocation = nil + hasDragged = false + } +} + +// MARK: - Main Floating Window + +class FloatingWindow: NSPanel { + let scrollView = NSScrollView() + let textView = NSTextView() + let approvalBox = NSView() + let approveButton = NSButton(title: "✓ Approve", target: nil, action: nil) + let rejectButton = NSButton(title: "✗ Reject", target: nil, action: nil) + let autoButton = NSButton(title: "Auto-Approve", target: nil, action: nil) + + var currentCallID: String? + var isMinimized = false + var fullFrame: NSRect? + var wasApprovalVisible = false + let minimizedWidth: CGFloat = 40 + let minimizedHeight: CGFloat = 150 + var minimizedYPosition: CGFloat? 
+ + var borderWindows: [BorderWindow] = [] + + init(position: String, alwaysOnTop: Bool) { + let screenFrame = NSScreen.main!.visibleFrame + let windowWidth: CGFloat = 450 + let windowHeight: CGFloat = 600 + + var xPos: CGFloat + switch position { + case "top-left": + xPos = screenFrame.minX + 20 + case "top-right": + xPos = screenFrame.maxX - windowWidth - 20 + default: + xPos = screenFrame.maxX - windowWidth - 20 + } + + let yPos = screenFrame.maxY - windowHeight - 20 + let frame = NSRect(x: xPos, y: yPos, width: windowWidth, height: windowHeight) + + super.init(contentRect: frame, styleMask: [.titled, .resizable, .miniaturizable, .fullSizeContentView], backing: .buffered, defer: false) + + self.title = "Computer Use" + self.isFloatingPanel = true + self.level = alwaysOnTop ? .floating : .normal + self.collectionBehavior = [.canJoinAllSpaces, .fullScreenAuxiliary] + self.hidesOnDeactivate = false + + self.isOpaque = false + self.alphaValue = 0.90 + self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 0.85) + + self.contentView?.wantsLayer = true + if let layer = self.contentView?.layer { + layer.cornerRadius = 12 + layer.masksToBounds = false + } + + self.hasShadow = true + self.invalidateShadow() + + self.titlebarAppearsTransparent = true + self.titleVisibility = .visible + self.isMovableByWindowBackground = true + + self.standardWindowButton(.closeButton)?.alphaValue = 0 + self.standardWindowButton(.zoomButton)?.alphaValue = 0 + + if let minimizeButton = self.standardWindowButton(.miniaturizeButton) { + minimizeButton.target = self + minimizeButton.action = #selector(customMinimize) + } + + setupUI() + self.orderFront(nil) + } + + @objc func customMinimize() { + if isMinimized { + restoreWindow() + } else { + minimizeToSide() + } + } + + func minimizeToSide() { + guard let screen = NSScreen.main else { return } + isMinimized = true + fullFrame = self.frame + wasApprovalVisible = !self.approvalBox.isHidden + + let screenFrame = screen.visibleFrame + let 
xPos = screenFrame.maxX - minimizedWidth + + let yPos: CGFloat + if let savedY = minimizedYPosition { + yPos = savedY + } else { + yPos = screenFrame.midY - (minimizedHeight / 2) + } + + let minimizedFrame = NSRect(x: xPos, y: yPos, width: minimizedWidth, height: minimizedHeight) + + NSAnimationContext.runAnimationGroup({ context in + context.duration = 0.3 + context.timingFunction = CAMediaTimingFunction(name: .easeInEaseOut) + self.animator().setFrame(minimizedFrame, display: true) + self.animator().alphaValue = 1.0 + }, completionHandler: { + self.scrollView.isHidden = true + self.approvalBox.isHidden = true + self.titleVisibility = .hidden + self.titlebarAppearsTransparent = true + self.standardWindowButton(.closeButton)?.alphaValue = 0 + self.standardWindowButton(.miniaturizeButton)?.alphaValue = 0 + self.standardWindowButton(.zoomButton)?.alphaValue = 0 + self.isOpaque = true + self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 1.0) + + self.isMovableByWindowBackground = false + self.styleMask.remove(.resizable) + + self.updateMinimizedUI() + }) + } + + func restoreWindow() { + guard let savedFrame = fullFrame else { return } + + minimizedYPosition = self.frame.origin.y + + isMinimized = false + + self.titleVisibility = .visible + self.titlebarAppearsTransparent = true + self.standardWindowButton(.closeButton)?.alphaValue = 0 + self.standardWindowButton(.miniaturizeButton)?.alphaValue = 1.0 + self.standardWindowButton(.zoomButton)?.alphaValue = 0 + self.isOpaque = false + self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 0.85) + + self.isMovableByWindowBackground = true + self.styleMask.insert(.resizable) + + if let contentView = self.contentView { + contentView.subviews.forEach { view in + if view.identifier?.rawValue == "minimizedLabel" { + view.removeFromSuperview() + } + } + } + + self.scrollView.isHidden = false + self.approvalBox.isHidden = !wasApprovalVisible + + NSAnimationContext.runAnimationGroup({ context in + context.duration = 
0.3 + context.timingFunction = CAMediaTimingFunction(name: .easeInEaseOut) + self.animator().setFrame(savedFrame, display: true) + self.animator().alphaValue = 0.95 + }, completionHandler: nil) + } + + func updateMinimizedUI() { + guard let contentView = self.contentView else { return } + + contentView.subviews.forEach { view in + if view.identifier?.rawValue == "minimizedLabel" { + view.removeFromSuperview() + } + } + + let clickableView = ClickableView(frame: NSRect(x: 0, y: 0, width: minimizedWidth, height: minimizedHeight)) + clickableView.identifier = NSUserInterfaceItemIdentifier("minimizedLabel") + clickableView.onClicked = { [weak self] in + self?.restoreWindow() + } + contentView.addSubview(clickableView) + + let labelHeight: CGFloat = 30 + let labelY = (minimizedHeight - labelHeight) / 2 + let label = NSTextField(labelWithString: "●") + label.frame = NSRect(x: 0, y: labelY, width: minimizedWidth, height: labelHeight) + label.alignment = .center + label.font = NSFont.systemFont(ofSize: 20) + label.textColor = NSColor(red: 0.48, green: 0.64, blue: 0.97, alpha: 1.0) + label.backgroundColor = .clear + label.isBordered = false + label.isEditable = false + label.isSelectable = false + clickableView.addSubview(label) + } + + override func mouseDown(with event: NSEvent) { + if !isMinimized { + super.mouseDown(with: event) + } + } + + func setupUI() { + guard let contentView = self.contentView else { return } + + textView.frame = contentView.bounds + textView.autoresizingMask = [.width] + textView.isEditable = false + textView.isSelectable = true + textView.backgroundColor = NSColor(red: 0.10, green: 0.11, blue: 0.15, alpha: 1.0) + textView.textColor = NSColor(red: 0.66, green: 0.69, blue: 0.84, alpha: 1.0) + textView.font = NSFont.systemFont(ofSize: 13) + textView.textContainerInset = NSSize(width: 16, height: 16) + textView.isVerticallyResizable = true + textView.isHorizontallyResizable = false + textView.textContainer?.widthTracksTextView = true + 
textView.textContainer?.containerSize = NSSize(width: contentView.bounds.width, height: CGFloat.greatestFiniteMagnitude) + textView.textContainer?.lineBreakMode = .byWordWrapping + + scrollView.documentView = textView + scrollView.hasVerticalScroller = true + scrollView.autohidesScrollers = true + scrollView.frame = contentView.bounds + scrollView.autoresizingMask = [.width, .height] + contentView.addSubview(scrollView) + + approvalBox.wantsLayer = true + approvalBox.layer?.backgroundColor = NSColor(red: 0.14, green: 0.16, blue: 0.23, alpha: 1.0).cgColor + approvalBox.layer?.borderColor = NSColor(red: 0.48, green: 0.64, blue: 0.97, alpha: 1.0).cgColor + approvalBox.layer?.borderWidth = 2 + approvalBox.layer?.cornerRadius = 8 + approvalBox.frame = NSRect(x: 10, y: 10, width: contentView.bounds.width - 20, height: 50) + approvalBox.autoresizingMask = [.width, .maxYMargin] + approvalBox.isHidden = true + + approveButton.bezelStyle = .regularSquare + approveButton.target = self + approveButton.action = #selector(approveClicked) + approveButton.frame = NSRect(x: 10, y: 10, width: 120, height: 30) + approveButton.contentTintColor = NSColor(red: 0.45, green: 0.87, blue: 0.68, alpha: 1.0) + approveButton.wantsLayer = true + approveButton.layer?.backgroundColor = NSColor(red: 0.45, green: 0.87, blue: 0.68, alpha: 0.2).cgColor + approveButton.layer?.cornerRadius = 6 + + rejectButton.bezelStyle = .regularSquare + rejectButton.target = self + rejectButton.action = #selector(rejectClicked) + rejectButton.frame = NSRect(x: 140, y: 10, width: 120, height: 30) + rejectButton.contentTintColor = NSColor(red: 0.97, green: 0.46, blue: 0.56, alpha: 1.0) + rejectButton.wantsLayer = true + rejectButton.layer?.backgroundColor = NSColor(red: 0.97, green: 0.46, blue: 0.56, alpha: 0.2).cgColor + rejectButton.layer?.cornerRadius = 6 + + autoButton.bezelStyle = .regularSquare + autoButton.target = self + autoButton.action = #selector(autoClicked) + autoButton.frame = NSRect(x: 270, y: 10, 
width: 140, height: 30) + autoButton.contentTintColor = NSColor(red: 0.48, green: 0.64, blue: 0.97, alpha: 1.0) + autoButton.wantsLayer = true + autoButton.layer?.backgroundColor = NSColor(red: 0.48, green: 0.64, blue: 0.97, alpha: 0.2).cgColor + autoButton.layer?.cornerRadius = 6 + + approvalBox.addSubview(approveButton) + approvalBox.addSubview(rejectButton) + approvalBox.addSubview(autoButton) + + contentView.addSubview(approvalBox) + + fputs("UI ready for output\n", stderr) + fflush(stderr) + } + + @objc func approveClicked() { + sendApproval(action: 0) + } + + @objc func rejectClicked() { + sendApproval(action: 1) + } + + @objc func autoClicked() { + sendApproval(action: 2) + } + + func sendApproval(action: Int) { + guard let callID = currentCallID else { return } + let response = ApprovalResponse(call_id: callID, action: action) + if let jsonData = try? JSONEncoder().encode(response), + let jsonString = String(data: jsonData, encoding: .utf8) { + print(jsonString) + fflush(stdout) + } + approvalBox.isHidden = true + currentCallID = nil + } + + func appendText(_ text: String, color: NSColor? = nil) { + DispatchQueue.main.async { + let attrs: [NSAttributedString.Key: Any] = [ + .foregroundColor: color ?? self.textView.textColor!, + .font: self.textView.font! 
+ ] + let attrString = NSAttributedString(string: text, attributes: attrs) + self.textView.textStorage?.append(attrString) + self.textView.scrollToEndOfDocument(nil) + } + } + + func showApproval(callID: String, toolName: String) { + DispatchQueue.main.async { + self.currentCallID = callID + self.approvalBox.isHidden = false + } + } + + // MARK: - Border Overlay Control + + func showBorderOverlay() { + DispatchQueue.main.async { + guard self.borderWindows.isEmpty else { return } + guard let screen = NSScreen.main else { return } + + let frame = screen.visibleFrame + let borderWidth: CGFloat = 3 + let borderColor = NSColor(red: 0.3, green: 0.6, blue: 1.0, alpha: 0.95) + + self.borderWindows.append(BorderWindow( + frame: NSRect(x: frame.minX, y: frame.maxY - borderWidth, width: frame.width, height: borderWidth), + color: borderColor + )) + + self.borderWindows.append(BorderWindow( + frame: NSRect(x: frame.minX, y: frame.minY, width: frame.width, height: borderWidth), + color: borderColor + )) + + self.borderWindows.append(BorderWindow( + frame: NSRect(x: frame.minX, y: frame.minY, width: borderWidth, height: frame.height), + color: borderColor + )) + + self.borderWindows.append(BorderWindow( + frame: NSRect(x: frame.maxX - borderWidth, y: frame.minY, width: borderWidth, height: frame.height), + color: borderColor + )) + + fputs("Border overlay shown\n", stderr) + fflush(stderr) + } + } + + func hideBorderOverlay() { + DispatchQueue.main.async { + for window in self.borderWindows { + window.close() + } + self.borderWindows.removeAll() + fputs("Border overlay hidden\n", stderr) + fflush(stderr) + } + } + + // MARK: - Click Indicator Control + + func showClickIndicator(x: CGFloat, y: CGFloat) { + DispatchQueue.main.async { + _ = ClickIndicatorWindow(x: x, y: y) + fputs("Click indicator shown at (\(x), \(y))\n", stderr) + fflush(stderr) + } + } + + func showMoveIndicator(fromX: CGFloat, fromY: CGFloat, toX: CGFloat, toY: CGFloat) { + DispatchQueue.main.async { + _ = 
MoveTrailWindow(fromX: fromX, fromY: fromY, toX: toX, toY: toY) + fputs("Move trail shown from (\(fromX), \(fromY)) to (\(toX), \(toY))\n", stderr) + fflush(stderr) + } + } +} + +// MARK: - Event Reader + +class EventReader { + let window: FloatingWindow + + init(window: FloatingWindow) { + self.window = window + } + + func startReading() { + DispatchQueue.global(qos: .userInitiated).async { + let handle = FileHandle.standardInput + + while true { + var lineData = Data() + + while true { + do { + guard let byte = try handle.read(upToCount: 1), !byte.isEmpty else { + return + } + + if byte[0] == 10 { + break + } + lineData.append(byte[0]) + } catch { + fputs("Read error: \(error)\n", stderr) + return + } + } + + if let line = String(data: lineData, encoding: .utf8), !line.isEmpty { + self.handleEvent(line) + } + } + } + } + + func handleEvent(_ jsonString: String) { + guard let data = jsonString.data(using: .utf8) else { + fputs("ERROR: Failed to convert string to data\n", stderr) + return + } + + do { + guard let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] else { + return + } + + let eventType = extractEventType(from: json) + let description = extractDescription(from: json) + + switch eventType { + case "Chat Start": + let blue = NSColor(red: 0.48, green: 0.64, blue: 0.97, alpha: 1.0) + window.appendText("\n● Starting...\n\n", color: blue) + + case "Chat Chunk": + window.appendText(description) + + case "Tool Approval": + if let toolCall = json["ToolCall"] as? [String: Any], + let callID = toolCall["id"] as? String, + let function = toolCall["function"] as? [String: Any], + let toolName = function["name"] as? String { + window.showApproval(callID: callID, toolName: toolName) + } + + case "Parallel Tools Start": + if let tools = json["Tools"] as? [[String: Any]] { + let blue = NSColor(red: 0.48, green: 0.64, blue: 0.97, alpha: 1.0) + for tool in tools { + if let toolName = tool["Name"] as? String { + let args = tool["Arguments"] as? 
String ?? "" + if !args.isEmpty { + window.appendText("\n▶ \(toolName): \(args)\n", color: blue) + } else { + window.appendText("\n▶ \(toolName)\n", color: blue) + } + } + } + } + + case "Tool Execution Progress": + if let toolName = json["ToolName"] as? String, + let status = json["Status"] as? String { + if status == "completed" { + let green = NSColor(red: 0.45, green: 0.87, blue: 0.68, alpha: 1.0) + window.appendText("✓ \(toolName) completed\n", color: green) + } else if status == "failed" { + let red = NSColor(red: 0.97, green: 0.46, blue: 0.56, alpha: 1.0) + window.appendText("✗ \(toolName) failed\n", color: red) + } + } + + case "Tool Failed", "Tool Rejected": + let red = NSColor(red: 0.97, green: 0.46, blue: 0.56, alpha: 1.0) + window.appendText("\n✗ \(description)\n", color: red) + + case "Approval Cleared": + DispatchQueue.main.async { + self.window.approvalBox.isHidden = true + } + + case "Border Show": + window.showBorderOverlay() + + case "Border Hide": + window.hideBorderOverlay() + + case "Click Indicator": + if let x = json["X"] as? Double, + let y = json["Y"] as? Double { + window.showClickIndicator(x: CGFloat(x), y: CGFloat(y)) + } + + case "Move Indicator": + if let fromX = json["FromX"] as? Int, + let fromY = json["FromY"] as? Int, + let toX = json["ToX"] as? Int, + let toY = json["ToY"] as? Int { + window.showMoveIndicator( + fromX: CGFloat(fromX), + fromY: CGFloat(fromY), + toX: CGFloat(toX), + toY: CGFloat(toY) + ) + } + + default: + break + } + + } catch { + fputs("ERROR: JSON parse error: \(error)\n", stderr) + } + } + + func extractEventType(from json: [String: Any]) -> String { + if json["BorderAction"] as? String == "show" { return "Border Show" } + if json["BorderAction"] as? String == "hide" { return "Border Hide" } + if json["FromX"] != nil && json["FromY"] != nil && json["ToX"] != nil && json["ToY"] != nil && json["MoveIndicator"] as? 
Bool == true { return "Move Indicator" } + if json["X"] != nil && json["Y"] != nil && json["ClickIndicator"] as? Bool == true { return "Click Indicator" } + if json["Content"] != nil { return "Chat Chunk" } + if json["Model"] != nil { return "Chat Start" } + if json["ToolCall"] != nil { return "Tool Approval" } + if json["Tools"] != nil { return "Parallel Tools Start" } + if json["ToolName"] != nil && json["Status"] != nil { return "Tool Execution Progress" } + if json["RequestID"] != nil && json["Timestamp"] != nil && + json["Content"] == nil && json["ToolCall"] == nil { return "Approval Cleared" } + return "Unknown" + } + + func extractDescription(from json: [String: Any]) -> String { + if let content = json["Content"] as? String { return content } + if let reason = json["Reason"] as? String { return "Interrupted: \(reason)" } + return "" + } +} + +// MARK: - Main + +signal(SIGTERM, SIG_IGN) +let sigTermSource = DispatchSource.makeSignalSource(signal: SIGTERM, queue: .main) +sigTermSource.setEventHandler { exit(0) } +sigTermSource.resume() + +NSApplication.shared.setActivationPolicy(.accessory) +NSApplication.shared.activate(ignoringOtherApps: true) + +let position = CommandLine.arguments.count > 1 ? CommandLine.arguments[1] : "top-right" +let alwaysOnTop = CommandLine.arguments.count > 2 ? 
CommandLine.arguments[2] == "true" : true + +let window = FloatingWindow(position: position, alwaysOnTop: alwaysOnTop) +let reader = EventReader(window: window) +reader.startReading() + +NSApplication.shared.run() diff --git a/internal/display/macos/client_darwin.go b/internal/display/macos/client_darwin.go index 941e4827..38e738c3 100644 --- a/internal/display/macos/client_darwin.go +++ b/internal/display/macos/client_darwin.go @@ -26,7 +26,6 @@ bool activateApp(const char *bundleIdentifier) { return false; } NSRunningApplication *app = [apps firstObject]; - // Use activate instead of activateWithOptions (deprecated in macOS 14+) return [app activateWithOptions:NSApplicationActivateAllWindows]; } } @@ -34,7 +33,6 @@ bool activateApp(const char *bundleIdentifier) { // Get the terminal app bundle ID (Terminal.app, iTerm2, VS Code, etc.) const char* getTerminalApp() { @autoreleasepool { - // Common terminal applications NSArray *terminalBundles = @[ @"com.apple.Terminal", // Terminal.app @"com.googlecode.iterm2", // iTerm2 @@ -73,8 +71,9 @@ import ( // MacOSClient provides macOS screen control operations using RobotGo type MacOSClient struct { - screenWidth int - screenHeight int + screenWidth int // Physical pixels (e.g., 2880 on 2x Retina) + screenHeight int // Physical pixels (e.g., 1800 on 2x Retina) + scaleFactor float64 // Display scale factor (1.0, 2.0, or 3.0) } // Modifier and key mapping tables @@ -125,12 +124,20 @@ var ( // NewMacOSClient creates a new macOS client func NewMacOSClient() (*MacOSClient, error) { - // Get screen dimensions - width, height := robotgo.GetScreenSize() + logicalWidth, logicalHeight := robotgo.GetScreenSize() + + scaleFactor := robotgo.ScaleF() + if scaleFactor == 0.0 { + scaleFactor = 1.0 + } + + physicalWidth := int(float64(logicalWidth) * scaleFactor) + physicalHeight := int(float64(logicalHeight) * scaleFactor) return &MacOSClient{ - screenWidth: width, - screenHeight: height, + screenWidth: physicalWidth, + screenHeight: 
physicalHeight, + scaleFactor: scaleFactor, }, nil } @@ -139,21 +146,42 @@ func (c *MacOSClient) Close() { // Nothing to close for RobotGo } -// GetScreenDimensions returns the screen width and height +// GetScreenDimensions returns the screen width and height in logical pixels +// This matches the coordinate space used by RobotGo's mouse operations func (c *MacOSClient) GetScreenDimensions() (int, int) { - return c.screenWidth, c.screenHeight + return c.ScalePhysicalToLogical(c.screenWidth, c.screenHeight) +} + +// GetScaleFactor returns the display scale factor (1.0, 2.0, or 3.0) +func (c *MacOSClient) GetScaleFactor() float64 { + return c.scaleFactor +} + +// ScalePhysicalToLogical converts physical pixel coordinates to logical coordinates +// This is used when passing coordinates to RobotGo APIs (mouse movement, clicks) +// On Retina displays, physical pixels are 2x or 3x logical pixels +func (c *MacOSClient) ScalePhysicalToLogical(x, y int) (int, int) { + if c.scaleFactor == 1.0 { + return x, y + } + logicalX := int(float64(x) / c.scaleFactor) + logicalY := int(float64(y) / c.scaleFactor) + return logicalX, logicalY } // CaptureScreen captures a screenshot and returns it as an image.Image +// Coordinates are expected in logical pixels (matching RobotGo's coordinate space) func (c *MacOSClient) CaptureScreen(x, y, width, height int) (image.Image, error) { + logicalWidth, logicalHeight := c.ScalePhysicalToLogical(c.screenWidth, c.screenHeight) + if width == 0 || height == 0 { - width = c.screenWidth - height = c.screenHeight + width = logicalWidth + height = logicalHeight } - if x < 0 || y < 0 || x+width > c.screenWidth || y+height > c.screenHeight { + if x < 0 || y < 0 || x+width > logicalWidth || y+height > logicalHeight { return nil, fmt.Errorf("invalid region: (%d,%d,%d,%d) exceeds screen bounds (%d,%d)", - x, y, width, height, c.screenWidth, c.screenHeight) + x, y, width, height, logicalWidth, logicalHeight) } bitmap := robotgo.CaptureScreen(x, y, width, 
height) @@ -185,16 +213,20 @@ func (c *MacOSClient) CaptureScreenBytes(x, y, width, height int) ([]byte, error } // GetCursorPosition returns the current cursor position +// Returns coordinates in top-left origin (Y=0 at top) to match screenshot coordinates func (c *MacOSClient) GetCursorPosition() (int, int, error) { x, y := robotgo.Location() return x, y, nil } -// MoveMouse moves the cursor to the specified coordinates (smooth movement) +// MoveMouse moves the cursor to the specified coordinates +// Coordinates should be in logical pixel space (matching GetScreenDimensions) +// Input coordinates use top-left origin (Y=0 at top) func (c *MacOSClient) MoveMouse(x, y int) error { - if x < 0 || y < 0 || x > c.screenWidth || y > c.screenHeight { + logicalWidth, logicalHeight := c.GetScreenDimensions() + if x < 0 || y < 0 || x > logicalWidth || y > logicalHeight { return fmt.Errorf("invalid coordinates: (%d,%d) exceeds screen bounds (%d,%d)", - x, y, c.screenWidth, c.screenHeight) + x, y, logicalWidth, logicalHeight) } robotgo.Move(x, y) diff --git a/internal/display/macos/client_darwin_test.go b/internal/display/macos/client_darwin_test.go new file mode 100644 index 00000000..2f11ae42 --- /dev/null +++ b/internal/display/macos/client_darwin_test.go @@ -0,0 +1,207 @@ +//go:build darwin + +package macos + +import ( + "testing" +) + +func TestScalePhysicalToLogical(t *testing.T) { + tests := []struct { + name string + scaleFactor float64 + physicalX int + physicalY int + expectedX int + expectedY int + }{ + { + name: "1x display (no scaling)", + scaleFactor: 1.0, + physicalX: 800, + physicalY: 600, + expectedX: 800, + expectedY: 600, + }, + { + name: "2x Retina display - full screen", + scaleFactor: 2.0, + physicalX: 2880, + physicalY: 1800, + expectedX: 1440, + expectedY: 900, + }, + { + name: "2x Retina display - center point", + scaleFactor: 2.0, + physicalX: 1440, + physicalY: 900, + expectedX: 720, + expectedY: 450, + }, + { + name: "2x Retina display - origin", + 
scaleFactor: 2.0, + physicalX: 0, + physicalY: 0, + expectedX: 0, + expectedY: 0, + }, + { + name: "2x Retina display - arbitrary point", + scaleFactor: 2.0, + physicalX: 1800, + physicalY: 800, + expectedX: 900, + expectedY: 400, + }, + { + name: "3x Retina display", + scaleFactor: 3.0, + physicalX: 4320, + physicalY: 2700, + expectedX: 1440, + expectedY: 900, + }, + { + name: "3x Retina display - center", + scaleFactor: 3.0, + physicalX: 2160, + physicalY: 1350, + expectedX: 720, + expectedY: 450, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &MacOSClient{ + screenWidth: int(float64(1440) * tt.scaleFactor), + screenHeight: int(float64(900) * tt.scaleFactor), + scaleFactor: tt.scaleFactor, + } + + gotX, gotY := client.ScalePhysicalToLogical(tt.physicalX, tt.physicalY) + + if gotX != tt.expectedX || gotY != tt.expectedY { + t.Errorf("ScalePhysicalToLogical(%d, %d) = (%d, %d), want (%d, %d)", + tt.physicalX, tt.physicalY, gotX, gotY, tt.expectedX, tt.expectedY) + } + }) + } +} + +func TestGetScaleFactor(t *testing.T) { + tests := []struct { + name string + scaleFactor float64 + }{ + {"1x display", 1.0}, + {"2x Retina", 2.0}, + {"3x Retina", 3.0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &MacOSClient{scaleFactor: tt.scaleFactor} + if got := client.GetScaleFactor(); got != tt.scaleFactor { + t.Errorf("GetScaleFactor() = %v, want %v", got, tt.scaleFactor) + } + }) + } +} + +func TestGetScreenDimensions(t *testing.T) { + tests := []struct { + name string + scaleFactor float64 + logicalWidth int + logicalHeight int + expectedWidth int + expectedHeight int + }{ + { + name: "1x display", + scaleFactor: 1.0, + logicalWidth: 1920, + logicalHeight: 1080, + expectedWidth: 1920, + expectedHeight: 1080, + }, + { + name: "2x Retina display", + scaleFactor: 2.0, + logicalWidth: 1440, + logicalHeight: 900, + expectedWidth: 1440, + expectedHeight: 900, + }, + { + name: "3x Retina display", + 
scaleFactor: 3.0, + logicalWidth: 1440, + logicalHeight: 900, + expectedWidth: 1440, + expectedHeight: 900, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + physicalWidth := int(float64(tt.logicalWidth) * tt.scaleFactor) + physicalHeight := int(float64(tt.logicalHeight) * tt.scaleFactor) + + client := &MacOSClient{ + screenWidth: physicalWidth, + screenHeight: physicalHeight, + scaleFactor: tt.scaleFactor, + } + + gotWidth, gotHeight := client.GetScreenDimensions() + + if gotWidth != tt.expectedWidth || gotHeight != tt.expectedHeight { + t.Errorf("GetScreenDimensions() = (%d, %d), want (%d, %d)", + gotWidth, gotHeight, tt.expectedWidth, tt.expectedHeight) + } + }) + } +} + +func TestScalePhysicalToLogical_RoundingBehavior(t *testing.T) { + // Test that integer division behaves correctly for odd numbers + client := &MacOSClient{scaleFactor: 2.0} + + tests := []struct { + name string + physicalX int + physicalY int + expectedX int + expectedY int + }{ + { + name: "Even coordinates", + physicalX: 100, + physicalY: 200, + expectedX: 50, + expectedY: 100, + }, + { + name: "Odd coordinates (rounds down)", + physicalX: 101, + physicalY: 201, + expectedX: 50, // 101/2 = 50.5 → 50 (int conversion truncates) + expectedY: 100, // 201/2 = 100.5 → 100 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotX, gotY := client.ScalePhysicalToLogical(tt.physicalX, tt.physicalY) + + if gotX != tt.expectedX || gotY != tt.expectedY { + t.Errorf("ScalePhysicalToLogical(%d, %d) = (%d, %d), want (%d, %d)", + tt.physicalX, tt.physicalY, gotX, gotY, tt.expectedX, tt.expectedY) + } + }) + } +} diff --git a/internal/display/macos/manager.go b/internal/display/macos/manager.go index 082cf95e..4c6475bf 100644 --- a/internal/display/macos/manager.go +++ b/internal/display/macos/manager.go @@ -4,12 +4,16 @@ package macos import ( "bufio" + "embed" "encoding/json" "fmt" "io" + "io/fs" "os" "os/exec" + "path/filepath" "runtime" + "strings" 
"sync" "syscall" "time" @@ -19,6 +23,9 @@ import ( logger "github.com/inference-gateway/cli/internal/logger" ) +//go:embed ComputerUse/build/ComputerUse.app +var computerUseApp embed.FS + // FloatingWindowManager manages the lifecycle of the floating progress window type FloatingWindowManager struct { cfg *config.Config @@ -28,7 +35,7 @@ type FloatingWindowManager struct { enabled bool eventSub chan domain.ChatEvent stopForward chan struct{} - swiftTmpFile string + appPath string monitorWg sync.WaitGroup // IPC fields (merged from ProcessManager) stdin io.Writer @@ -72,35 +79,34 @@ func NewFloatingWindowManager(cfg *config.Config, eventBridge *EventBridge, stat mgr.monitorWg.Add(1) go mgr.monitorProcess() + if cfg.ComputerUse.Screenshot.ShowOverlay { + time.Sleep(200 * time.Millisecond) + if err := mgr.ShowBorderOverlay(); err != nil { + logger.Warn("Failed to show border overlay", "error", err) + } + } + return mgr, nil } // launchWindow starts the Swift window process func (mgr *FloatingWindowManager) launchWindow() error { - swiftScript := mgr.generateSwiftScript() + appDir := filepath.Join(mgr.cfg.GetConfigDir(), "tmp", "ComputerUse.app") + mgr.appPath = appDir - tmpDir := mgr.cfg.GetConfigDir() + "/tmp" - if err := os.MkdirAll(tmpDir, 0755); err != nil { - return fmt.Errorf("failed to create temp directory: %w", err) - } - - tmpFile, err := os.CreateTemp(tmpDir, "floating_window_*.swift") - if err != nil { - return fmt.Errorf("failed to create temp file: %w", err) - } - - if _, err := tmpFile.Write([]byte(swiftScript)); err != nil { - _ = tmpFile.Close() - _ = os.Remove(tmpFile.Name()) - return fmt.Errorf("failed to write Swift script: %w", err) - } - if err := tmpFile.Close(); err != nil { - return fmt.Errorf("failed to close temp file: %w", err) + if _, err := os.Stat(appDir); os.IsNotExist(err) { + logger.Debug("Extracting ComputerUse.app from embedded binary", "path", appDir) + if err := mgr.extractApp(appDir); err != nil { + return fmt.Errorf("failed to 
extract embedded app: %w", err) + } + logger.Info("ComputerUse.app extracted successfully", "path", appDir) } - mgr.swiftTmpFile = tmpFile.Name() + position := mgr.cfg.ComputerUse.FloatingWindow.Position + alwaysOnTop := fmt.Sprintf("%t", mgr.cfg.ComputerUse.FloatingWindow.AlwaysOnTop) - cmd := exec.Command("swift", tmpFile.Name()) + executablePath := filepath.Join(appDir, "Contents", "MacOS", "ComputerUse") + cmd := exec.Command(executablePath, position, alwaysOnTop) stdin, err := cmd.StdinPipe() if err != nil { @@ -118,16 +124,13 @@ func (mgr *FloatingWindowManager) launchWindow() error { } if err := cmd.Start(); err != nil { - return fmt.Errorf("failed to start Swift process: %w", err) + return fmt.Errorf("failed to start ComputerUse.app: %w", err) } go func() { buf := make([]byte, 1024) for { - n, err := stderr.Read(buf) - if n > 0 { - logger.Debug("Swift stderr", "output", string(buf[:n])) - } + _, err := stderr.Read(buf) if err != nil { break } @@ -143,6 +146,53 @@ func (mgr *FloatingWindowManager) launchWindow() error { return nil } +// extractApp extracts the embedded .app bundle to the target directory +func (mgr *FloatingWindowManager) extractApp(targetDir string) error { + const appPrefix = "ComputerUse/build/ComputerUse.app" + + return fs.WalkDir(computerUseApp, ".", func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + + if path == "." 
{ + return nil + } + + if !strings.HasPrefix(path, appPrefix) { + return nil + } + + if path == appPrefix { + return nil + } + + relPath := path[len(appPrefix)+1:] + + targetPath := filepath.Join(targetDir, relPath) + + if d.IsDir() { + return os.MkdirAll(targetPath, 0755) + } + + data, err := computerUseApp.ReadFile(path) + if err != nil { + return fmt.Errorf("failed to read embedded file %s: %w", path, err) + } + + perm := os.FileMode(0644) + if filepath.Base(filepath.Dir(targetPath)) == "MacOS" { + perm = 0755 + } + + if err := os.WriteFile(targetPath, data, perm); err != nil { + return fmt.Errorf("failed to write file %s: %w", targetPath, err) + } + + return nil + }) +} + // forwardEvents forwards chat events from the EventBridge to the Swift window func (mgr *FloatingWindowManager) forwardEvents() { for { @@ -157,7 +207,6 @@ func (mgr *FloatingWindowManager) forwardEvents() { } case <-mgr.stopForward: - logger.Debug("Event forwarding stopped") return } } @@ -213,14 +262,8 @@ func (mgr *FloatingWindowManager) Shutdown() error { mgr.monitorWg.Wait() - if mgr.swiftTmpFile != "" { - if err := os.Remove(mgr.swiftTmpFile); err != nil { - logger.Debug("Failed to remove temp Swift file", "error", err, "path", mgr.swiftTmpFile) - } else { - logger.Debug("Removed temp Swift file", "path", mgr.swiftTmpFile) - } - mgr.swiftTmpFile = "" - } + // Note: We don't delete the .app - it persists in .infer/ for future runs + // This avoids re-extracting the embedded .app on every launch logger.Info("Floating window manager shutdown complete") @@ -239,7 +282,6 @@ func (mgr *FloatingWindowManager) shutdownProcess() error { // sendTermSignal sends SIGTERM to the Swift process, falls back to SIGKILL if needed func (mgr *FloatingWindowManager) sendTermSignal() error { if err := mgr.cmd.Process.Signal(syscall.SIGTERM); err != nil { - logger.Debug("Failed to send SIGTERM, using SIGKILL", "error", err) if killErr := mgr.cmd.Process.Kill(); killErr != nil { logger.Warn("Failed to kill 
Swift process", "error", killErr) return fmt.Errorf("failed to kill process: %w", killErr) @@ -248,8 +290,6 @@ func (mgr *FloatingWindowManager) sendTermSignal() error { return nil } -// IPC Methods (merged from ProcessManager) - // writeEvent sends an event to the Swift process via stdin func (mgr *FloatingWindowManager) writeEvent(event domain.ChatEvent) error { mgr.stdinMutex.Lock() @@ -274,7 +314,6 @@ func (mgr *FloatingWindowManager) startApprovalListener() { for scanner.Scan() { select { case <-mgr.stopListener: - logger.Debug("Approval listener stopped") return default: } @@ -308,8 +347,6 @@ func (mgr *FloatingWindowManager) handleApprovalResponse(resp ApprovalResponse) mgr.approvalMutex.Lock() defer mgr.approvalMutex.Unlock() - logger.Debug("handleApprovalResponse called", "call_id", resp.CallID, "registered_channels", len(mgr.approvalChans)) - ch, exists := mgr.approvalChans[resp.CallID] if !exists { logger.Warn("Received approval for unknown call ID", "call_id", resp.CallID, "known_call_ids", mgr.getCallIDs()) @@ -321,11 +358,8 @@ func (mgr *FloatingWindowManager) handleApprovalResponse(resp ApprovalResponse) select { case ch <- resp.Action: delete(mgr.approvalChans, resp.CallID) - logger.Debug("Approval processed", "call_id", resp.CallID, "action", resp.Action) - if mgr.stateManager != nil { mgr.stateManager.ClearApprovalUIState() - logger.Debug("Cleared approval UI state from floating window") } default: logger.Warn("Approval channel blocked", "call_id", resp.CallID) @@ -338,7 +372,6 @@ func (mgr *FloatingWindowManager) registerApprovalChannel(callID string, ch chan defer mgr.approvalMutex.Unlock() mgr.approvalChans[callID] = ch - logger.Debug("Registered approval channel", "call_id", callID) } // getCallIDs returns a list of registered call IDs (for debugging) @@ -358,742 +391,58 @@ func (mgr *FloatingWindowManager) stopApprovalListener() { if !mgr.listenerStopped { close(mgr.stopListener) mgr.listenerStopped = true - logger.Debug("Approval listener 
stopped") } } -// generateSwiftScript generates the Swift script for the floating window -// -//nolint:funlen // Swift script embedding requires long function -func (mgr *FloatingWindowManager) generateSwiftScript() string { - position := mgr.cfg.ComputerUse.FloatingWindow.Position - alwaysOnTop := mgr.cfg.ComputerUse.FloatingWindow.AlwaysOnTop - - return fmt.Sprintf(` -import Cocoa -import Foundation -import WebKit - -// MARK: - Models - -struct ApprovalResponse: Codable { - let call_id: String - let action: Int // 0=Approve, 1=Reject, 2=AutoAccept -} +// ShowBorderOverlay sends an event to show the blue border around the screen +func (mgr *FloatingWindowManager) ShowBorderOverlay() error { + if !mgr.enabled { + return nil + } -// MARK: - Window Setup - -class AgentProgressWindow: NSPanel { - let webView = WKWebView() - var isTerminalReady = false - var isMinimized = false - var fullFrame: NSRect? - let minimizedWidth: CGFloat = 40 - let minimizedHeight: CGFloat = 150 - - init() { - let screenFrame = NSScreen.main!.visibleFrame - let windowWidth: CGFloat = 450 - let windowHeight: CGFloat = 600 - - // Position based on configuration - var xPos: CGFloat - let position = "%s" - switch position { - case "top-left": - xPos = screenFrame.minX + 20 - case "top-right": - xPos = screenFrame.maxX - windowWidth - 20 - default: - xPos = screenFrame.maxX - windowWidth - 20 - } - - let yPos = screenFrame.maxY - windowHeight - 20 - let frame = NSRect(x: xPos, y: yPos, width: windowWidth, height: windowHeight) - - let styleMask: NSWindow.StyleMask = [.titled, .resizable, .miniaturizable, .fullSizeContentView] - super.init(contentRect: frame, styleMask: styleMask, backing: .buffered, defer: false) - - self.title = "Computer Use" - self.isFloatingPanel = true - self.level = %t ? 
.floating : .normal - self.collectionBehavior = [.canJoinAllSpaces, .fullScreenAuxiliary] - self.hidesOnDeactivate = false - - self.isOpaque = false - self.alphaValue = 0.90 - self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 0.85) - - self.contentView?.wantsLayer = true - if let layer = self.contentView?.layer { - layer.cornerRadius = 12 - layer.masksToBounds = false - } - - self.hasShadow = true - self.invalidateShadow() - - self.titlebarAppearsTransparent = true - self.titleVisibility = .visible - - self.isMovableByWindowBackground = true - - self.standardWindowButton(.closeButton)?.alphaValue = 0 - self.standardWindowButton(.zoomButton)?.alphaValue = 0 - - if let minimizeButton = self.standardWindowButton(.miniaturizeButton) { - minimizeButton.target = self - minimizeButton.action = #selector(customMinimize) - } - - setupUI() - - self.orderFront(nil) - } - - @objc func customMinimize() { - if isMinimized { - restoreWindow() - } else { - minimizeToSide() - } - } - - func minimizeToSide() { - guard let screen = NSScreen.main else { return } - isMinimized = true - fullFrame = self.frame - - let screenFrame = screen.visibleFrame - let xPos = screenFrame.maxX - minimizedWidth - let yPos = screenFrame.midY - (minimizedHeight / 2) - let minimizedFrame = NSRect(x: xPos, y: yPos, width: minimizedWidth, height: minimizedHeight) - - NSAnimationContext.runAnimationGroup({ context in - context.duration = 0.3 - context.timingFunction = CAMediaTimingFunction(name: .easeInEaseOut) - self.animator().setFrame(minimizedFrame, display: true) - self.animator().alphaValue = 1.0 - }, completionHandler: { - self.webView.isHidden = true - self.titleVisibility = .hidden - self.titlebarAppearsTransparent = true - self.standardWindowButton(.closeButton)?.alphaValue = 0 - self.standardWindowButton(.miniaturizeButton)?.alphaValue = 0 - self.standardWindowButton(.zoomButton)?.alphaValue = 0 - self.isOpaque = true - self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 1.0) - 
self.updateMinimizedUI() - }) - } - - func restoreWindow() { - guard let savedFrame = fullFrame else { return } - isMinimized = false - - self.titleVisibility = .visible - self.titlebarAppearsTransparent = true - self.standardWindowButton(.closeButton)?.alphaValue = 0 - self.standardWindowButton(.miniaturizeButton)?.alphaValue = 1.0 - self.standardWindowButton(.zoomButton)?.alphaValue = 0 - self.isOpaque = false - self.backgroundColor = NSColor(calibratedWhite: 0.1, alpha: 0.85) - - if let contentView = self.contentView { - contentView.subviews.forEach { view in - if view.identifier?.rawValue == "minimizedLabel" { - view.removeFromSuperview() - } - } - } - - self.webView.isHidden = false - - NSAnimationContext.runAnimationGroup({ context in - context.duration = 0.3 - context.timingFunction = CAMediaTimingFunction(name: .easeInEaseOut) - self.animator().setFrame(savedFrame, display: true) - self.animator().alphaValue = 0.95 - }, completionHandler: nil) - } - - func updateMinimizedUI() { - guard let contentView = self.contentView else { return } - - contentView.subviews.forEach { view in - if view.identifier?.rawValue == "minimizedLabel" { - view.removeFromSuperview() - } - } - - // Create a simple dot indicator, vertically centered - let labelHeight: CGFloat = 30 - let labelY = (minimizedHeight - labelHeight) / 2 - let label = NSTextField(labelWithString: "●") - label.identifier = NSUserInterfaceItemIdentifier("minimizedLabel") - label.frame = NSRect(x: 0, y: labelY, width: minimizedWidth, height: labelHeight) - label.alignment = .center - label.font = NSFont.systemFont(ofSize: 20) - label.textColor = NSColor(calibratedRed: 0.48, green: 0.64, blue: 0.97, alpha: 1.0) // Blue accent color - label.backgroundColor = .clear - label.isBordered = false - label.isEditable = false - label.isSelectable = false - contentView.addSubview(label) - } - - override func mouseDown(with event: NSEvent) { - super.mouseDown(with: event) - if isMinimized { - customMinimize() - } - } - - 
func setupUI() { - guard let contentView = self.contentView else { return } - - // WebView leaves 30px at top for draggable title bar - let titleBarHeight: CGFloat = 30 - webView.frame = NSRect(x: 0, y: 0, width: contentView.bounds.width, height: contentView.bounds.height - titleBarHeight) - webView.autoresizingMask = [.width, .height] - webView.setValue(false, forKey: "drawsBackground") - - let html = """ - - - - - - - - - -
-
-
- - - -
-
- - - - """ - - let userController = webView.configuration.userContentController - userController.add(self, name: "terminalReady") - userController.add(self, name: "approval") - - let consoleScript = """ - console.log = function(msg) { - window.webkit.messageHandlers.consoleLog.postMessage(String(msg)); - }; - """ - let consoleUserScript = WKUserScript(source: consoleScript, injectionTime: .atDocumentStart, forMainFrameOnly: true) - userController.addUserScript(consoleUserScript) - userController.add(self, name: "consoleLog") - - webView.loadHTMLString(html, baseURL: nil) - contentView.addSubview(webView) - } - - func escapeForJS(_ text: String) -> String { - return text.replacingOccurrences(of: "\\", with: "\\\\") - .replacingOccurrences(of: "'", with: "\\'") - .replacingOccurrences(of: "\n", with: "\\n") - .replacingOccurrences(of: "\r", with: "") - } - - func writeToTerminal(_ text: String) { - guard isTerminalReady else { return } - let escaped = escapeForJS(text) - let js = "window.term.write('\(escaped)');" - webView.evaluateJavaScript(js, completionHandler: nil) - } - - func writeLineToTerminal(_ text: String) { - guard isTerminalReady else { return } - let escaped = escapeForJS(text) - let js = "window.term.writeln('\(escaped)');" - webView.evaluateJavaScript(js, completionHandler: nil) - } - - func formatToolArguments(_ jsonString: String) -> String { - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] else { - return jsonString.count > 100 ? String(jsonString.prefix(100)) + "..." : jsonString - } - - var lines: [String] = [] - for (key, value) in json.sorted(by: { $0.key < $1.key }) { - let valueStr: String - if let str = value as? String { - valueStr = str.count > 60 ? String(str.prefix(60)) + "..." : str - } else if let num = value as? 
NSNumber { - valueStr = "\(num)" - } else { - valueStr = "\(value)" - } - lines.append("\(key): \(valueStr)") - } - return lines.joined(separator: "\n") - } - - func addEvent(type: String, description: String, callID: String? = nil, toolName: String? = nil, toolArgs: String? = nil) { - DispatchQueue.main.async { - let esc = "\u{001B}" - let cyan = "\(esc)[36m" - let yellow = "\(esc)[33m" - let magenta = "\(esc)[35m" - let gray = "\(esc)[90m" - let reset = "\(esc)[0m" - - switch type { - case "Chat Start": - self.writeLineToTerminal("") - self.writeLineToTerminal("\(cyan)●\(reset) Starting...") - self.writeLineToTerminal("") - case "Chat Chunk": - self.writeToTerminal(description) - case "Tool Approval": - if let cid = callID, let tool = toolName { - self.showApprovalButtons(callID: cid, toolName: tool) - } - case "Tool Execution": - if let tool = toolName { - let green = "\(esc)[32m" - let blue = "\(esc)[34m" - let bold = "\(esc)[1m" - let dim = "\(esc)[2m" - - self.writeLineToTerminal("") - self.writeLineToTerminal("\(blue)▶\(reset) \(bold)\(tool)\(reset)") - - // Format arguments nicely - if let args = toolArgs, !args.isEmpty && args != "{}" { - let formattedArgs = self.formatToolArguments(args) - for line in formattedArgs.split(separator: "\n") { - self.writeLineToTerminal(" \(dim)\(line)\(reset)") - } - } - } else { - self.writeLineToTerminal("\(gray) \(description)\(reset)") - } - case "Tool Failed", "Tool Rejected": - let red = "\(esc)[31m" - let bold = "\(esc)[1m" - self.writeLineToTerminal("") - self.writeLineToTerminal("\(red)✗ \(bold)\(description)\(reset)") - self.writeLineToTerminal("") - case "Approval Cleared": - self.hideApprovalBox() - case "Cancelled": - let red = "\(esc)[31m" - let bold = "\(esc)[1m" - self.writeLineToTerminal("") - self.writeLineToTerminal("\(red)✗ \(bold)\(description)\(reset)") - self.writeLineToTerminal("") - case "Event": - // Skip generic events to reduce noise - break - case "Optimization": - self.writeLineToTerminal("") - 
self.writeLineToTerminal("\(magenta)⚡ \(description)\(reset)") - default: - self.writeLineToTerminal("\(gray)[\(type)] \(description)\(reset)") - } - } - } - - func hideApprovalBox() { - guard isTerminalReady else { return } - let js = "document.getElementById('approvalBox').classList.remove('visible'); window.currentCallID = null;" - webView.evaluateJavaScript(js, completionHandler: nil) - } - - func showApprovalButtons(callID: String, toolName: String) { - guard isTerminalReady else { - return - } - let escapedCallID = escapeForJS(callID) - let escapedToolName = escapeForJS(toolName) - let js = "window.showApproval('\(escapedCallID)', '\(escapedToolName)');" - webView.evaluateJavaScript(js, completionHandler: nil) - } - - func sendApproval(callID: String, action: Int) { - let response = ApprovalResponse(call_id: callID, action: action) - if let jsonData = try? JSONEncoder().encode(response), - let jsonString = String(data: jsonData, encoding: .utf8) { - print(jsonString) // Send to stdout - fflush(stdout) - } - } -} + event := domain.BorderOverlayEvent{ + BaseChatEvent: domain.BaseChatEvent{ + RequestID: "border-show", + Timestamp: time.Now(), + }, + BorderAction: "show", + } -// MARK: - WebKit Message Handler - -extension AgentProgressWindow: WKScriptMessageHandler { - func userContentController(_ userContentController: WKUserContentController, didReceive message: WKScriptMessage) { - if message.name == "terminalReady" { - isTerminalReady = true - fputs("Terminal ready for output\n", stderr) - } else if message.name == "approval", - let data = message.body as? [String: Any], - let callID = data["call_id"] as? String, - let action = data["action"] as? 
Int { - fputs("Received approval from UI: callID=\(callID), action=\(action)\n", stderr) - sendApproval(callID: callID, action: action) - } else if message.name == "consoleLog" { - fputs("JS console: \(message.body)\n", stderr) - } - } + return mgr.writeEvent(event) } -// MARK: - Event Reading - -class EventReader { - let window: AgentProgressWindow - - init(window: AgentProgressWindow) { - self.window = window - } - - func startReading() { - DispatchQueue.global(qos: .userInitiated).async { - let handle = FileHandle.standardInput - - while true { - var lineData = Data() - - while true { - do { - guard let byte = try handle.read(upToCount: 1), !byte.isEmpty else { - return - } - - if byte[0] == 10 { - break - } - lineData.append(byte[0]) - } catch { - fputs("Read error: \(error)\n", stderr) - return - } - } - - if let line = String(data: lineData, encoding: .utf8), !line.isEmpty { - self.handleEvent(line) - } - } - } - } - - func handleEvent(_ jsonString: String) { - guard let data = jsonString.data(using: .utf8) else { - fputs("ERROR: Failed to convert string to data\n", stderr) - return - } - - do { - if let json = try JSONSerialization.jsonObject(with: data) as? [String: Any] { - - let eventType = self.extractEventType(from: json) - let description = self.extractDescription(from: json) - - var callID: String? = nil - var toolName: String? = nil - var toolArgs: String? = nil - - if eventType == "Parallel Tools Start" { - if let tools = json["Tools"] as? [[String: Any]] { - for tool in tools { - let tName = tool["Name"] as? String - let tArgs = tool["Arguments"] as? String - - DispatchQueue.main.async { - self.window.addEvent(type: "Tool Execution", description: "", callID: nil, toolName: tName, toolArgs: tArgs) - } - } - } - } else if eventType == "Chat Complete" { - if let toolCalls = json["ToolCalls"] as? [[String: Any]] { - for toolCall in toolCalls { - if let function = toolCall["function"] as? [String: Any] { - let tName = function["name"] as? 
String - let tArgs = function["arguments"] as? String - - DispatchQueue.main.async { - self.window.addEvent(type: "Tool Execution", description: "", callID: nil, toolName: tName, toolArgs: tArgs) - } - } - } - } - } else if eventType == "Tool Execution" { - toolName = json["ToolName"] as? String - toolArgs = json["Arguments"] as? String - } else if eventType == "Tool Execution Progress" { - if let status = json["Status"] as? String, status == "failed" { - if let tName = json["ToolName"] as? String { - let failureMsg = "Tool: \(tName) failed" - DispatchQueue.main.async { - self.window.addEvent(type: "Tool Failed", description: failureMsg, callID: nil, toolName: nil, toolArgs: nil) - } - } - } - return - } else if eventType == "Tool Approval" { - if let toolCall = json["ToolCall"] as? [String: Any] { - callID = toolCall["id"] as? String - if let function = toolCall["function"] as? [String: Any] { - toolName = function["name"] as? String - } - } - } - - if eventType != "Chat Chunk" { - fputs("Parsed event: type=\(eventType), desc=\(description), callID=\(callID ?? 
"nil")\n", stderr) - } - - // Only call addEvent for events we haven't already handled - if eventType != "Parallel Tools Start" && eventType != "Chat Complete" { - DispatchQueue.main.async { - self.window.addEvent(type: eventType, description: description, callID: callID, toolName: toolName, toolArgs: toolArgs) - } - } - } - } catch { - fputs("ERROR: JSON parse error: \(error)\n", stderr) - } - } - - func extractEventType(from json: [String: Any]) -> String { - - if json["Tools"] != nil { - return "Parallel Tools Start" - } - - if json["Content"] != nil { - return "Chat Chunk" - } - - if json["Message"] != nil && json["IsActive"] != nil { - return "Optimization" - } - - if json["ToolCalls"] != nil { - return "Chat Complete" - } - - if json["ToolCallID"] != nil && json["ToolName"] != nil && json["Status"] != nil { - return "Tool Execution Progress" - } - - if json["ToolName"] != nil && json["Arguments"] != nil { - return "Tool Execution" - } - - if json["ToolCall"] != nil { - return "Tool Approval" - } - - if json["Reason"] != nil { - return "Cancelled" - } - - if json["Model"] != nil { - return "Chat Start" - } - - if json["RequestID"] != nil && json["Timestamp"] != nil && - json["Content"] == nil && json["ToolCall"] == nil && json["Model"] == nil && json["Message"] == nil { - return "Approval Cleared" - } - - if json["RequestID"] != nil { - return "Event" - } - return "Unknown" - } - - func extractDescription(from json: [String: Any]) -> String { - // For Content field (ChatChunkEvent), preserve ALL whitespace including spaces and newlines - if let content = json["Content"] as? String { - return content // Don't trim - spaces and newlines are important! - } - - if let reason = json["Reason"] as? String { - return "Interrupted: \(reason)" - } - - if let message = json["Message"] as? String { - return message - } - - if let model = json["Model"] as? String { - return "Model: \(model)" - } - - if let toolCall = json["ToolCall"] as? 
[String: Any], - let function = toolCall["function"] as? [String: Any], - let toolName = function["name"] as? String { - return "Tool approval: \(toolName)" - } - - if let status = json["Status"] as? String { - return status - } - - if let error = json["Error"] as? String { - return "Error: \(error)" - } - - if let toolName = json["ToolName"] as? String { - if let status = json["Status"] as? String { - return "\(toolName): \(status)" - } - return "Tool: \(toolName)" - } - - if json["RequestID"] != nil { - return "Event received" - } - return "No description" - } -} +// HideBorderOverlay sends an event to hide the blue border around the screen +func (mgr *FloatingWindowManager) HideBorderOverlay() error { + if !mgr.enabled { + return nil + } -// MARK: - Main + event := domain.BorderOverlayEvent{ + BaseChatEvent: domain.BaseChatEvent{ + RequestID: "border-hide", + Timestamp: time.Now(), + }, + BorderAction: "hide", + } -signal(SIGTERM, SIG_IGN) -let sigTermSource = DispatchSource.makeSignalSource(signal: SIGTERM, queue: .main) -sigTermSource.setEventHandler { - exit(0) + return mgr.writeEvent(event) } -sigTermSource.resume() -NSApplication.shared.setActivationPolicy(.accessory) -NSApplication.shared.activate(ignoringOtherApps: true) - -let window = AgentProgressWindow() +// ShowClickIndicator sends an event to show a visual click indicator at the given coordinates +func (mgr *FloatingWindowManager) ShowClickIndicator(x, y int) error { + if !mgr.enabled { + return nil + } -let reader = EventReader(window: window) -reader.startReading() + event := domain.ClickIndicatorEvent{ + BaseChatEvent: domain.BaseChatEvent{ + RequestID: "click-indicator", + Timestamp: time.Now(), + }, + X: x, + Y: y, + ClickIndicator: true, + } -NSApplication.shared.run() -`, position, alwaysOnTop) + return mgr.writeEvent(event) } diff --git a/internal/display/macos/overlay_darwin.go b/internal/display/macos/overlay_darwin.go deleted file mode 100644 index 8820407d..00000000 --- 
a/internal/display/macos/overlay_darwin.go +++ /dev/null @@ -1,155 +0,0 @@ -//go:build darwin - -package macos - -import ( - "fmt" - "os/exec" - "syscall" - "time" - - logger "github.com/inference-gateway/cli/internal/logger" -) - -// OverlayWindow represents a macOS overlay indicator using osascript -type OverlayWindow struct { - cmd *exec.Cmd - visible bool -} - -// NewOverlayWindow creates a new macOS overlay window using persistent alert -func NewOverlayWindow() (*OverlayWindow, error) { - logger.Info("Creating macOS overlay window") - return &OverlayWindow{ - visible: false, - }, nil -} - -// Show displays a screen border overlay using Swift -func (w *OverlayWindow) Show() error { - logger.Info("Attempting to show macOS screen border overlay") - - swiftScript := ` -import Cocoa -import Foundation - -class BorderWindow: NSWindow { - init(frame: NSRect, color: NSColor) { - super.init(contentRect: frame, styleMask: .borderless, backing: .buffered, defer: false) - self.backgroundColor = color - self.isOpaque = false - self.level = .floating - self.ignoresMouseEvents = true - self.collectionBehavior = [.canJoinAllSpaces, .stationary] - self.orderFront(nil) - } -} - -signal(SIGTERM, SIG_IGN) -let sigTermSource = DispatchSource.makeSignalSource(signal: SIGTERM, queue: .main) -sigTermSource.setEventHandler { - exit(0) -} -sigTermSource.resume() - -let screen = NSScreen.main! 
-let frame = screen.visibleFrame -let borderWidth: CGFloat = 3 -let borderColor = NSColor(red: 0.3, green: 0.6, blue: 1.0, alpha: 0.95) - -_ = BorderWindow(frame: NSRect(x: frame.minX, y: frame.maxY - borderWidth, width: frame.width, height: borderWidth), color: borderColor) -_ = BorderWindow(frame: NSRect(x: frame.minX, y: frame.minY, width: frame.width, height: borderWidth), color: borderColor) -_ = BorderWindow(frame: NSRect(x: frame.minX, y: frame.minY, width: borderWidth, height: frame.height), color: borderColor) -_ = BorderWindow(frame: NSRect(x: frame.maxX - borderWidth, y: frame.minY, width: borderWidth, height: frame.height), color: borderColor) - -RunLoop.main.run() -` - - logger.Debug("Compiling and running Swift screen border overlay") - - cmd := exec.Command("swift", "-") - stdin, err := cmd.StdinPipe() - if err != nil { - logger.Error("Failed to create stdin pipe", "error", err) - return fmt.Errorf("failed to create stdin pipe: %w", err) - } - - if err := cmd.Start(); err != nil { - logger.Error("Failed to start Swift process", "error", err) - return fmt.Errorf("failed to start Swift process: %w", err) - } - - go func() { - defer func() { _ = stdin.Close() }() - if _, err := stdin.Write([]byte(swiftScript)); err != nil { - logger.Error("Failed to write Swift script to stdin", "error", err) - } - }() - - w.cmd = cmd - w.visible = true - logger.Info("Screen border overlay shown successfully") - return nil -} - -// Hide hides the screen border overlay by terminating the process gracefully -func (w *OverlayWindow) Hide() error { - logger.Info("Hiding macOS screen border overlay") - - if w.cmd == nil { - w.visible = false - return nil - } - - if w.cmd.Process == nil { - w.cmd = nil - w.visible = false - return nil - } - - cmd := w.cmd - w.cmd = nil - w.visible = false - - if err := cmd.Process.Signal(syscall.SIGTERM); err != nil { - logger.Debug("Failed to send SIGTERM, using SIGKILL", "error", err) - if err := cmd.Process.Kill(); err != nil { - 
logger.Warn("Failed to kill overlay process", "error", err) - return fmt.Errorf("failed to kill overlay process: %w", err) - } - } - - done := make(chan error, 1) - go func() { - done <- cmd.Wait() - }() - - select { - case err := <-done: - if err != nil { - logger.Debug("Overlay process exited with error", "error", err) - } else { - logger.Debug("Overlay process exited cleanly") - } - case <-time.After(5 * time.Second): - logger.Warn("Overlay process did not exit within timeout, force killing") - if err := cmd.Process.Kill(); err != nil { - logger.Error("Failed to force kill overlay process", "error", err) - return fmt.Errorf("failed to force kill overlay process: %w", err) - } - <-done - } - - return nil -} - -// IsVisible returns whether the overlay is currently visible -func (w *OverlayWindow) IsVisible() bool { - return w.visible -} - -// Destroy cleans up the screen border overlay by killing the process -func (w *OverlayWindow) Destroy() error { - logger.Info("Destroying macOS screen border overlay") - return w.Hide() -} diff --git a/internal/display/macos/overlay_darwin_test.go b/internal/display/macos/overlay_darwin_test.go deleted file mode 100644 index 5de72ac2..00000000 --- a/internal/display/macos/overlay_darwin_test.go +++ /dev/null @@ -1,149 +0,0 @@ -//go:build darwin - -package macos - -import ( - "runtime" - "testing" - "time" - - assert "github.com/stretchr/testify/assert" - require "github.com/stretchr/testify/require" -) - -func TestOverlayWindow_Creation(t *testing.T) { - if runtime.GOOS != "darwin" { - t.Skip("macOS only test") - } - - overlay, err := NewOverlayWindow() - require.NoError(t, err) - require.NotNil(t, overlay) - require.False(t, overlay.visible) - defer func() { - if overlay != nil { - _ = overlay.Destroy() - } - }() -} - -func TestOverlayWindow_ShowHide(t *testing.T) { - if runtime.GOOS != "darwin" { - t.Skip("macOS only test") - } - - overlay, err := NewOverlayWindow() - require.NoError(t, err) - require.NotNil(t, overlay) - 
defer func() { _ = overlay.Destroy() }() - - err = overlay.Show() - require.NoError(t, err) - assert.True(t, overlay.visible) - - time.Sleep(100 * time.Millisecond) - - assert.True(t, overlay.IsVisible()) - - err = overlay.Hide() - require.NoError(t, err) - assert.False(t, overlay.visible) - - time.Sleep(100 * time.Millisecond) - - assert.False(t, overlay.IsVisible()) -} - -func TestOverlayWindow_Lifecycle(t *testing.T) { - if runtime.GOOS != "darwin" { - t.Skip("macOS only test") - } - - overlay, err := NewOverlayWindow() - require.NoError(t, err) - require.NotNil(t, overlay) - - err = overlay.Show() - require.NoError(t, err) - assert.True(t, overlay.visible) - - time.Sleep(100 * time.Millisecond) - - err = overlay.Hide() - require.NoError(t, err) - assert.False(t, overlay.visible) - - err = overlay.Destroy() - require.NoError(t, err) - assert.False(t, overlay.visible) -} - -func TestOverlayWindow_DestroyWithoutShow(t *testing.T) { - if runtime.GOOS != "darwin" { - t.Skip("macOS only test") - } - - overlay, err := NewOverlayWindow() - require.NoError(t, err) - require.NotNil(t, overlay) - - err = overlay.Destroy() - require.NoError(t, err) -} - -func TestOverlayWindow_OperationsOnEmptyWindow(t *testing.T) { - if runtime.GOOS != "darwin" { - t.Skip("macOS only test") - } - - overlay := &OverlayWindow{cmd: nil, visible: false} - - err := overlay.Hide() - assert.NoError(t, err) - - err = overlay.Destroy() - assert.NoError(t, err) - - assert.False(t, overlay.IsVisible()) -} - -func TestOverlayWindow_MultipleShowCalls(t *testing.T) { - if runtime.GOOS != "darwin" { - t.Skip("macOS only test") - } - - overlay, err := NewOverlayWindow() - require.NoError(t, err) - require.NotNil(t, overlay) - defer func() { _ = overlay.Destroy() }() - - err = overlay.Show() - require.NoError(t, err) - - err = overlay.Show() - require.NoError(t, err) - - assert.True(t, overlay.visible) -} - -func TestOverlayWindow_MultipleHideCalls(t *testing.T) { - if runtime.GOOS != "darwin" { - 
t.Skip("macOS only test") - } - - overlay, err := NewOverlayWindow() - require.NoError(t, err) - require.NotNil(t, overlay) - defer func() { _ = overlay.Destroy() }() - - err = overlay.Show() - require.NoError(t, err) - - err = overlay.Hide() - require.NoError(t, err) - - err = overlay.Hide() - require.NoError(t, err) - - assert.False(t, overlay.visible) -} diff --git a/internal/display/macos/overlay_stub.go b/internal/display/macos/overlay_stub.go deleted file mode 100644 index 85e281bc..00000000 --- a/internal/display/macos/overlay_stub.go +++ /dev/null @@ -1,33 +0,0 @@ -//go:build !darwin - -package macos - -import "fmt" - -// OverlayWindow is a stub for non-macOS platforms -type OverlayWindow struct{} - -// NewOverlayWindow returns an error on non-macOS platforms -func NewOverlayWindow() (*OverlayWindow, error) { - return nil, fmt.Errorf("overlay window only supported on macOS") -} - -// Show is a no-op on non-macOS platforms -func (w *OverlayWindow) Show() error { - return fmt.Errorf("overlay window only supported on macOS") -} - -// Hide is a no-op on non-macOS platforms -func (w *OverlayWindow) Hide() error { - return fmt.Errorf("overlay window only supported on macOS") -} - -// IsVisible always returns false on non-macOS platforms -func (w *OverlayWindow) IsVisible() bool { - return false -} - -// Destroy is a no-op on non-macOS platforms -func (w *OverlayWindow) Destroy() error { - return fmt.Errorf("overlay window only supported on macOS") -} diff --git a/internal/display/x11/client.go b/internal/display/x11/client.go index 88f5449d..da69795e 100644 --- a/internal/display/x11/client.go +++ b/internal/display/x11/client.go @@ -73,8 +73,6 @@ func NewX11Client(display string) (*X11Client, error) { keybind.Initialize(xu) - logger.Debug("Successfully connected to X11 display", "display", display) - return &X11Client{ xu: xu, conn: xu.Conn(), diff --git a/internal/domain/chat_events.go b/internal/domain/chat_events.go index 9cfe8e43..31bca1ce 100644 --- 
a/internal/domain/chat_events.go +++ b/internal/domain/chat_events.go @@ -56,3 +56,27 @@ type TodoUpdateChatEvent struct { BaseChatEvent Todos []TodoItem } + +// BorderOverlayEvent indicates the screen border overlay should be shown or hidden +type BorderOverlayEvent struct { + BaseChatEvent + BorderAction string +} + +// ClickIndicatorEvent indicates a visual click indicator should be shown at coordinates +type ClickIndicatorEvent struct { + BaseChatEvent + X int `json:"X"` + Y int `json:"Y"` + ClickIndicator bool `json:"ClickIndicator"` +} + +// MoveIndicatorEvent indicates a visual move indicator should be shown at coordinates +type MoveIndicatorEvent struct { + BaseChatEvent + FromX int `json:"FromX"` + FromY int `json:"FromY"` + ToX int `json:"ToX"` + ToY int `json:"ToY"` + MoveIndicator bool `json:"MoveIndicator"` +} diff --git a/internal/domain/context.go b/internal/domain/context.go index 2a1035f6..a192f9eb 100644 --- a/internal/domain/context.go +++ b/internal/domain/context.go @@ -28,3 +28,9 @@ const ChatHandlerKey ContextKey = "chat_handler" // SessionIDKey is the context key for the current conversation session ID // This allows shortcuts to access the session ID when they need it (e.g., /export) const SessionIDKey ContextKey = "session_id" + +// DirectExecutionKey is the context key for direct tool execution +// When this key is set to true in the context, it indicates that the tool +// was invoked directly by the user (e.g., via !! 
command) rather than by the LLM +// This allows tools to adjust behavior (e.g., skip coordinate scaling for mouse operations) +const DirectExecutionKey ContextKey = "direct_execution" diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 4a50f643..3ab8b6cc 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -40,13 +40,15 @@ type ScreenRegion struct { // Screenshot represents a captured screenshot with metadata type Screenshot struct { - ID string `json:"id"` - Timestamp time.Time `json:"timestamp"` - Data string `json:"data"` // base64 encoded image - Width int `json:"width"` - Height int `json:"height"` - Format string `json:"format"` // "png" or "jpeg" - Method string `json:"method"` // "x11" or "wayland" + ID string `json:"id"` + Timestamp time.Time `json:"timestamp"` + Data string `json:"data"` // base64 encoded image + Width int `json:"width"` // Final image width (after scaling) + Height int `json:"height"` // Final image height (after scaling) + Format string `json:"format"` // "png" or "jpeg" + Method string `json:"method"` // "x11" or "wayland" + OriginalWidth int `json:"original_width"` // Screen width before scaling + OriginalHeight int `json:"original_height"` // Screen height before scaling } // ScreenshotProvider defines the interface for getting screenshots from a buffer @@ -330,6 +332,14 @@ type StateManager interface { GetMessageEditState() *MessageEditState ClearMessageEditState() IsEditingMessage() bool + + // Focus management (macOS computer-use tools) + SetLastFocusedApp(appID string) + GetLastFocusedApp() string + ClearLastFocusedApp() + SetLastClickCoordinates(x, y int) + GetLastClickCoordinates() (x, y int) + ClearLastClickCoordinates() } // FileService handles file operations diff --git a/internal/domain/state.go b/internal/domain/state.go index e75b192e..0def6c7a 100644 --- a/internal/domain/state.go +++ b/internal/domain/state.go @@ -43,6 +43,13 @@ type ApplicationState struct { // 
Message Edit State messageEditState *MessageEditState + // Focus Management (macOS computer-use tools) + // Stores the bundle ID of the app that was clicked (for restoring focus before keyboard operations) + lastFocusedAppID string + // Stores the coordinates of the last click (for re-clicking before keyboard operations) + lastClickX int + lastClickY int + // Debugging debugMode bool } @@ -854,6 +861,41 @@ func (s *ApplicationState) IsEditingMessage() bool { return s.messageEditState != nil } +// Focus Management Methods (macOS computer-use tools) + +// SetLastFocusedApp stores the bundle ID of the last focused application +// This is used to restore focus before keyboard operations +func (s *ApplicationState) SetLastFocusedApp(appID string) { + s.lastFocusedAppID = appID +} + +// GetLastFocusedApp returns the bundle ID of the last focused application +func (s *ApplicationState) GetLastFocusedApp() string { + return s.lastFocusedAppID +} + +// ClearLastFocusedApp clears the stored focused app +func (s *ApplicationState) ClearLastFocusedApp() { + s.lastFocusedAppID = "" +} + +// SetLastClickCoordinates stores the coordinates of the last click +func (s *ApplicationState) SetLastClickCoordinates(x, y int) { + s.lastClickX = x + s.lastClickY = y +} + +// GetLastClickCoordinates returns the coordinates of the last click +func (s *ApplicationState) GetLastClickCoordinates() (x, y int) { + return s.lastClickX, s.lastClickY +} + +// ClearLastClickCoordinates clears the stored click coordinates +func (s *ApplicationState) ClearLastClickCoordinates() { + s.lastClickX = 0 + s.lastClickY = 0 +} + // StateSnapshot represents a point-in-time snapshot of application state type StateSnapshot struct { CurrentView string `json:"current_view"` diff --git a/internal/handlers/chat_command_handler.go b/internal/handlers/chat_command_handler.go index df26b669..023f7f57 100644 --- a/internal/handlers/chat_command_handler.go +++ b/internal/handlers/chat_command_handler.go @@ -197,6 +197,7 
@@ func (c *ChatCommandHandler) executeBashCommandAsync(command string, toolCallID ctx := context.WithValue(context.Background(), domain.ToolApprovedKey, true) ctx = context.WithValue(ctx, domain.BashOutputCallbackKey, domain.BashOutputCallback(bashCallback)) ctx = context.WithValue(ctx, domain.BashDetachChannelKey, (<-chan struct{})(detachChan)) + ctx = context.WithValue(ctx, domain.DirectExecutionKey, true) result, err := c.handler.toolService.ExecuteToolDirect(ctx, toolCallFunc) if err != nil { @@ -537,6 +538,7 @@ func (c *ChatCommandHandler) executeToolCommandAsync(toolName, argsJSON, toolCal } ctx := context.WithValue(context.Background(), domain.ToolApprovedKey, true) + ctx = context.WithValue(ctx, domain.DirectExecutionKey, true) result, err := c.handler.toolService.ExecuteToolDirect(ctx, toolCallFunc) if err != nil { eventChan <- domain.ShowErrorEvent{ diff --git a/internal/services/circular_screenshot_buffer.go b/internal/services/circular_screenshot_buffer.go index f2bd5632..0a144fe2 100644 --- a/internal/services/circular_screenshot_buffer.go +++ b/internal/services/circular_screenshot_buffer.go @@ -177,7 +177,11 @@ func (b *CircularScreenshotBuffer) writeToDisk(screenshot *domain.Screenshot) er return fmt.Errorf("failed to decode base64 data: %w", err) } - filename := filepath.Join(b.tempDir, fmt.Sprintf("screenshot-%s.png", screenshot.ID)) + extension := screenshot.Format + if extension == "" { + extension = "png" + } + filename := filepath.Join(b.tempDir, fmt.Sprintf("screenshot-%s.%s", screenshot.ID, extension)) if err := os.WriteFile(filename, imageData, 0644); err != nil { return fmt.Errorf("failed to write file: %w", err) } @@ -187,8 +191,10 @@ func (b *CircularScreenshotBuffer) writeToDisk(screenshot *domain.Screenshot) er // deleteFromDisk removes a screenshot file from disk func (b *CircularScreenshotBuffer) deleteFromDisk(id string) { - filename := filepath.Join(b.tempDir, fmt.Sprintf("screenshot-%s.png", id)) - if err := os.Remove(filename); 
err != nil && !os.IsNotExist(err) { - logger.Warn("Failed to delete screenshot file", "error", err, "filename", filename) + for _, ext := range []string{"png", "jpeg", "jpg"} { + filename := filepath.Join(b.tempDir, fmt.Sprintf("screenshot-%s.%s", id, ext)) + if err := os.Remove(filename); err == nil { + return + } } } diff --git a/internal/services/mcp_manager.go b/internal/services/mcp_manager.go index 606506c1..81078a9c 100644 --- a/internal/services/mcp_manager.go +++ b/internal/services/mcp_manager.go @@ -440,7 +440,7 @@ func (m *MCPManager) handleDiscoveryFailure(client *mcpClient, maxRetries int, e } client.mu.Unlock() - logger.Error("MCP server health check failed (tool discovery failed)", + logger.Debug("MCP server health check failed (tool discovery failed)", "server", client.serverName, "retryAttempt", retryCount, "maxRetries", maxRetries, diff --git a/internal/services/screenshot_server.go b/internal/services/screenshot_server.go index b84b0e15..fb17b4d1 100644 --- a/internal/services/screenshot_server.go +++ b/internal/services/screenshot_server.go @@ -1,20 +1,23 @@ package services import ( + "bytes" "context" "encoding/json" "fmt" + "image" + "image/color" + "image/jpeg" + _ "image/png" "net" "net/http" "path/filepath" - "runtime" "strconv" "sync" "time" config "github.com/inference-gateway/cli/config" display "github.com/inference-gateway/cli/internal/display" - "github.com/inference-gateway/cli/internal/display/macos" domain "github.com/inference-gateway/cli/internal/domain" logger "github.com/inference-gateway/cli/internal/logger" @@ -24,17 +27,16 @@ import ( // ScreenshotServer provides an HTTP API for screenshot streaming type ScreenshotServer struct { - cfg *config.Config - port int - server *http.Server - buffer *CircularScreenshotBuffer - captureCtx context.Context - captureStop context.CancelFunc - mu sync.RWMutex - sessionID string - imageSvc domain.ImageService - running bool - overlayWindow *macos.OverlayWindow + cfg *config.Config + port 
int + server *http.Server + buffer *CircularScreenshotBuffer + captureCtx context.Context + captureStop context.CancelFunc + mu sync.RWMutex + sessionID string + imageSvc domain.ImageService + running bool } // NewScreenshotServer creates a new screenshot server @@ -104,7 +106,7 @@ func (s *ScreenshotServer) Start() error { s.running = true - s.showOverlayIfEnabled() + // Note: Border overlay is now managed by FloatingWindow.app via BorderOverlayEvent interval := s.cfg.ComputerUse.Screenshot.CaptureInterval if interval <= 0 { @@ -148,17 +150,6 @@ func (s *ScreenshotServer) Stop() error { } } - if s.overlayWindow != nil { - if err := s.overlayWindow.Hide(); err != nil { - logger.Warn("Failed to hide overlay window", "error", err) - } - if err := s.overlayWindow.Destroy(); err != nil { - logger.Warn("Failed to destroy overlay window", "error", err) - } - s.overlayWindow = nil - logger.Info("Screenshot overlay window destroyed") - } - s.running = false return nil @@ -171,31 +162,6 @@ func (s *ScreenshotServer) Port() int { return s.port } -// showOverlayIfEnabled shows the overlay window if configured -func (s *ScreenshotServer) showOverlayIfEnabled() { - if !s.cfg.ComputerUse.Screenshot.ShowOverlay { - return - } - - if runtime.GOOS != "darwin" { - return - } - - overlay, err := macos.NewOverlayWindow() - if err != nil { - logger.Warn("Failed to create overlay window", "error", err) - return - } - - s.overlayWindow = overlay - if err := s.overlayWindow.Show(); err != nil { - logger.Warn("Failed to show overlay window", "error", err) - return - } - - logger.Info("Screenshot overlay window shown") -} - // startCaptureLoop runs the background screenshot capture loop func (s *ScreenshotServer) startCaptureLoop() { interval := s.cfg.ComputerUse.Screenshot.CaptureInterval @@ -213,8 +179,6 @@ func (s *ScreenshotServer) startCaptureLoop() { case <-ticker.C: if err := s.captureScreenshot(); err != nil { logger.Warn("Screenshot capture failed", "error", err) - } else if 
s.cfg.ComputerUse.Screenshot.LogCaptures { - logger.Debug("Screenshot captured") } } } @@ -247,18 +211,74 @@ func (s *ScreenshotServer) captureScreenshot() error { return fmt.Errorf("failed to capture screenshot: %w", err) } - imageAttachment, err := s.imageSvc.ReadImageFromBinary(imageBytes, "screenshot.png") + imgConfig, _, err := image.DecodeConfig(bytes.NewReader(imageBytes)) + if err != nil { + logger.Warn("Failed to decode image config, using controller dimensions", "error", err) + } else { + actualWidth := imgConfig.Width + actualHeight := imgConfig.Height + + if actualWidth != width || actualHeight != height { + width = actualWidth + height = actualHeight + } + } + + img, _, err := image.Decode(bytes.NewReader(imageBytes)) + if err != nil { + return fmt.Errorf("failed to decode screenshot: %w", err) + } + + logicalWidth, logicalHeight, err := controller.GetScreenDimensions(s.captureCtx) + if err != nil { + logger.Warn("Failed to get logical dimensions", "error", err) + } else if width != logicalWidth || height != logicalHeight { + img = resizeImage(img, logicalWidth, logicalHeight) + width = logicalWidth + height = logicalHeight + } + + originalWidth := width + originalHeight := height + + targetW := s.cfg.ComputerUse.Screenshot.TargetWidth + targetH := s.cfg.ComputerUse.Screenshot.TargetHeight + + if targetW > 0 && targetH > 0 { + img = resizeImage(img, targetW, targetH) + width = targetW + height = targetH + + logger.Info("Screenshot force-resized to target dimensions", + "from", fmt.Sprintf("%dx%d", originalWidth, originalHeight), + "to", fmt.Sprintf("%dx%d", width, height)) + } + + quality := s.cfg.ComputerUse.Screenshot.Quality + if quality <= 0 || quality > 100 { + quality = 60 + } + + var buf bytes.Buffer + if err := jpeg.Encode(&buf, img, &jpeg.Options{Quality: quality}); err != nil { + return fmt.Errorf("failed to encode JPEG: %w", err) + } + imageBytes = buf.Bytes() + + imageAttachment, err := s.imageSvc.ReadImageFromBinary(imageBytes, 
"screenshot.jpeg") if err != nil { return fmt.Errorf("failed to process image: %w", err) } screenshot := &domain.Screenshot{ - Timestamp: time.Now(), - Data: imageAttachment.Data, - Width: width, - Height: height, - Format: "png", - Method: displayProvider.GetDisplayInfo().Name, + Timestamp: time.Now(), + Data: imageAttachment.Data, + Width: width, + Height: height, + Format: s.cfg.ComputerUse.Screenshot.Format, + Method: displayProvider.GetDisplayInfo().Name, + OriginalWidth: originalWidth, + OriginalHeight: originalHeight, } return s.buffer.Add(screenshot) @@ -337,7 +357,6 @@ func (s *ScreenshotServer) handleGetStatus(w http.ResponseWriter, r *http.Reques } // GetLatestScreenshot retrieves the latest screenshot from the buffer -// Implements the ScreenshotProvider interface for use by GetLatestScreenshotTool func (s *ScreenshotServer) GetLatestScreenshot() (*domain.Screenshot, error) { s.mu.RLock() defer s.mu.RUnlock() @@ -348,3 +367,65 @@ func (s *ScreenshotServer) GetLatestScreenshot() (*domain.Screenshot, error) { return s.buffer.GetLatest() } + +// resizeImage resizes an image to target dimensions using bilinear interpolation +// This provides better quality than nearest neighbor for LLM visual understanding +func resizeImage(src image.Image, targetWidth, targetHeight int) image.Image { + srcBounds := src.Bounds() + srcWidth := srcBounds.Dx() + srcHeight := srcBounds.Dy() + + dst := image.NewRGBA(image.Rect(0, 0, targetWidth, targetHeight)) + + xRatio := float64(srcWidth-1) / float64(targetWidth-1) + yRatio := float64(srcHeight-1) / float64(targetHeight-1) + + for dstY := range targetHeight { + for dstX := range targetWidth { + srcXFloat := float64(dstX) * xRatio + srcYFloat := float64(dstY) * yRatio + + srcX := int(srcXFloat) + srcY := int(srcYFloat) + + fracX := srcXFloat - float64(srcX) + fracY := srcYFloat - float64(srcY) + + srcX1 := srcX + srcY1 := srcY + srcX2 := srcX + 1 + srcY2 := srcY + 1 + + if srcX2 >= srcWidth { + srcX2 = srcWidth - 1 + } + if 
srcY2 >= srcHeight { + srcY2 = srcHeight - 1 + } + + c11 := src.At(srcBounds.Min.X+srcX1, srcBounds.Min.Y+srcY1) + c21 := src.At(srcBounds.Min.X+srcX2, srcBounds.Min.Y+srcY1) + c12 := src.At(srcBounds.Min.X+srcX1, srcBounds.Min.Y+srcY2) + c22 := src.At(srcBounds.Min.X+srcX2, srcBounds.Min.Y+srcY2) + + r11, g11, b11, a11 := c11.RGBA() + r21, g21, b21, a21 := c21.RGBA() + r12, g12, b12, a12 := c12.RGBA() + r22, g22, b22, a22 := c22.RGBA() + + w1 := (1 - fracX) * (1 - fracY) + w2 := fracX * (1 - fracY) + w3 := (1 - fracX) * fracY + w4 := fracX * fracY + + r := uint8((float64(r11)*w1 + float64(r21)*w2 + float64(r12)*w3 + float64(r22)*w4) / 257) + g := uint8((float64(g11)*w1 + float64(g21)*w2 + float64(g12)*w3 + float64(g22)*w4) / 257) + b := uint8((float64(b11)*w1 + float64(b21)*w2 + float64(b12)*w3 + float64(b22)*w4) / 257) + a := uint8((float64(a11)*w1 + float64(a21)*w2 + float64(a12)*w3 + float64(a22)*w4) / 257) + + dst.SetRGBA(dstX, dstY, color.RGBA{R: r, G: g, B: b, A: a}) + } + } + + return dst +} diff --git a/internal/services/state_manager.go b/internal/services/state_manager.go index 8c2dd73c..d7de4eed 100644 --- a/internal/services/state_manager.go +++ b/internal/services/state_manager.go @@ -780,3 +780,47 @@ type HealthStatus struct { LastStateChange time.Time `json:"last_state_change"` MemoryUsageKB int `json:"memory_usage_kb"` } + +// Focus management methods (macOS computer-use tools) + +// SetLastFocusedApp stores the bundle ID of the last focused application +func (sm *StateManager) SetLastFocusedApp(appID string) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.state.SetLastFocusedApp(appID) +} + +// GetLastFocusedApp returns the bundle ID of the last focused application +func (sm *StateManager) GetLastFocusedApp() string { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + return sm.state.GetLastFocusedApp() +} + +// ClearLastFocusedApp clears the stored focused app +func (sm *StateManager) ClearLastFocusedApp() { + sm.mutex.Lock() + defer 
sm.mutex.Unlock() + sm.state.ClearLastFocusedApp() +} + +// SetLastClickCoordinates stores the coordinates of the last click +func (sm *StateManager) SetLastClickCoordinates(x, y int) { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.state.SetLastClickCoordinates(x, y) +} + +// GetLastClickCoordinates returns the coordinates of the last click +func (sm *StateManager) GetLastClickCoordinates() (x, y int) { + sm.mutex.RLock() + defer sm.mutex.RUnlock() + return sm.state.GetLastClickCoordinates() +} + +// ClearLastClickCoordinates clears the stored click coordinates +func (sm *StateManager) ClearLastClickCoordinates() { + sm.mutex.Lock() + defer sm.mutex.Unlock() + sm.state.ClearLastClickCoordinates() +} diff --git a/internal/services/tools/coordinate_scaler.go b/internal/services/tools/coordinate_scaler.go new file mode 100644 index 00000000..7b26a268 --- /dev/null +++ b/internal/services/tools/coordinate_scaler.go @@ -0,0 +1,54 @@ +package tools + +import "math" + +// ScaleAPIToScreen converts coordinates from API space (Claude's screenshot) +// to screen space (actual display) using Anthropic's proportional scaling approach. 
+// +// This implementation follows the official Anthropic computer-use-demo strategy: +// - Simple proportional scaling with separate X/Y factors +// - No letterboxing - screenshots are force-resized to exact dimensions +// - Handles aspect ratio mismatches automatically through independent scaling +// +// Parameters: +// - apiX, apiY: Coordinates from Claude's response (in screenshot space) +// - apiWidth, apiHeight: Target screenshot dimensions (e.g., 1024x768) +// - screenWidth, screenHeight: Actual logical screen dimensions +// +// Returns: +// - screenX, screenY: Coordinates in actual screen space for mouse operations +// +// Example: +// +// API screenshot: 1024x768 +// Actual screen: 2048x1536 +// API coordinate (512, 384) → Screen coordinate (1024, 768) +func ScaleAPIToScreen(apiX, apiY, apiWidth, apiHeight, screenWidth, screenHeight int) (int, int) { + xScale := float64(apiWidth) / float64(screenWidth) + yScale := float64(apiHeight) / float64(screenHeight) + + screenX := int(math.Round(float64(apiX) / xScale)) + screenY := int(math.Round(float64(apiY) / yScale)) + + return screenX, screenY +} + +// ScaleScreenToAPI converts coordinates from screen space to API space. +// This is the inverse of ScaleAPIToScreen and is used when capturing screenshots. 
+// +// Parameters: +// - screenX, screenY: Coordinates in actual screen space +// - screenWidth, screenHeight: Actual logical screen dimensions +// - apiWidth, apiHeight: Target screenshot dimensions +// +// Returns: +// - apiX, apiY: Coordinates in API/screenshot space +func ScaleScreenToAPI(screenX, screenY, screenWidth, screenHeight, apiWidth, apiHeight int) (int, int) { + xScale := float64(apiWidth) / float64(screenWidth) + yScale := float64(apiHeight) / float64(screenHeight) + + apiX := int(math.Round(float64(screenX) * xScale)) + apiY := int(math.Round(float64(screenY) * yScale)) + + return apiX, apiY +} diff --git a/internal/services/tools/keyboard_type.go b/internal/services/tools/keyboard_type.go index 0b269520..89dc4f0b 100644 --- a/internal/services/tools/keyboard_type.go +++ b/internal/services/tools/keyboard_type.go @@ -19,16 +19,18 @@ type KeyboardTypeTool struct { formatter domain.BaseFormatter rateLimiter domain.RateLimiter displayProvider display.Provider + stateManager domain.StateManager } // NewKeyboardTypeTool creates a new keyboard type tool -func NewKeyboardTypeTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider) *KeyboardTypeTool { +func NewKeyboardTypeTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider, stateManager domain.StateManager) *KeyboardTypeTool { return &KeyboardTypeTool{ config: cfg, enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.KeyboardType.Enabled, formatter: domain.NewBaseFormatter("KeyboardType"), rateLimiter: rateLimiter, displayProvider: displayProvider, + stateManager: stateManager, } } @@ -120,6 +122,10 @@ func (t *KeyboardTypeTool) Execute(ctx context.Context, args map[string]any) (*d } }() + if t.stateManager != nil { + t.restoreInputFocus(ctx, controller) + } + var execErr error if hasText { execErr = controller.TypeText(ctx, text, t.config.ComputerUse.KeyboardType.TypingDelayMs) @@ -243,3 +249,46 @@ func (t *KeyboardTypeTool) 
ShouldCollapseArg(key string) bool { func (t *KeyboardTypeTool) ShouldAlwaysExpand() bool { return false } + +func (t *KeyboardTypeTool) restoreInputFocus(ctx context.Context, controller display.DisplayController) { + clickX, clickY := t.stateManager.GetLastClickCoordinates() + if clickX <= 0 && clickY <= 0 { + return + } + + lastFocusedApp := t.stateManager.GetLastFocusedApp() + if lastFocusedApp != "" { + t.activateLastFocusedApp(ctx, controller, lastFocusedApp) + } + + t.reClickInputField(ctx, controller, clickX, clickY) +} + +func (t *KeyboardTypeTool) activateLastFocusedApp(ctx context.Context, controller display.DisplayController, appID string) { + focusManager, ok := controller.(display.FocusManager) + if !ok { + return + } + + if err := focusManager.ActivateApp(ctx, appID); err != nil { + logger.Warn("Failed to restore app focus", "app_id", appID, "error", err) + return + } + + time.Sleep(100 * time.Millisecond) +} + +func (t *KeyboardTypeTool) reClickInputField(ctx context.Context, controller display.DisplayController, x, y int) { + if err := controller.MoveMouse(ctx, x, y); err != nil { + logger.Warn("Failed to move mouse to stored coordinates", "x", x, "y", y, "error", err) + return + } + + mouseButton := display.ParseMouseButton("left") + if err := controller.ClickMouse(ctx, mouseButton, 1); err != nil { + logger.Warn("Failed to re-click input field", "error", err) + return + } + + time.Sleep(100 * time.Millisecond) +} diff --git a/internal/services/tools/keyboard_type_test.go b/internal/services/tools/keyboard_type_test.go index 0b7f0df1..d36934c4 100644 --- a/internal/services/tools/keyboard_type_test.go +++ b/internal/services/tools/keyboard_type_test.go @@ -58,7 +58,7 @@ func TestKeyboardTypeTool_TypingDelay(t *testing.T) { }, } - tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil) + tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil, nil) if 
tool.config.ComputerUse.KeyboardType.TypingDelayMs != tt.delayMs { t.Errorf("Expected delay %d ms, got %d ms", tt.delayMs, tool.config.ComputerUse.KeyboardType.TypingDelayMs) @@ -70,7 +70,7 @@ func TestKeyboardTypeTool_TypingDelay(t *testing.T) { func TestKeyboardTypeTool_ConfigDefault(t *testing.T) { cfg := config.DefaultConfig() - expectedDelay := 200 + expectedDelay := 100 actualDelay := cfg.ComputerUse.KeyboardType.TypingDelayMs if actualDelay != expectedDelay { @@ -95,7 +95,7 @@ func TestKeyboardTypeTool_Validation(t *testing.T) { }, } - tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil) + tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil, nil) tests := []struct { name string @@ -206,7 +206,7 @@ func TestKeyboardTypeTool_FormatResult(t *testing.T) { }, } - tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil) + tool := NewKeyboardTypeTool(cfg, utils.NewRateLimiter(cfg.ComputerUse.RateLimit), nil, nil) result := &domain.ToolExecutionResult{ ToolName: "KeyboardType", diff --git a/internal/services/tools/mouse_click.go b/internal/services/tools/mouse_click.go index ab5be236..7b76e608 100644 --- a/internal/services/tools/mouse_click.go +++ b/internal/services/tools/mouse_click.go @@ -19,16 +19,18 @@ type MouseClickTool struct { formatter domain.BaseFormatter rateLimiter domain.RateLimiter displayProvider display.Provider + stateManager domain.StateManager } // NewMouseClickTool creates a new mouse click tool -func NewMouseClickTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider) *MouseClickTool { +func NewMouseClickTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider, stateManager domain.StateManager) *MouseClickTool { return &MouseClickTool{ config: cfg, enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.MouseClick.Enabled, formatter: domain.NewBaseFormatter("MouseClick"), rateLimiter: 
rateLimiter, displayProvider: displayProvider, + stateManager: stateManager, } } @@ -75,55 +77,20 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom start := time.Now() if err := t.rateLimiter.CheckAndRecord("MouseClick"); err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: err.Error(), - }, nil + return t.errorResult(args, start, err.Error()), nil } - button, ok := args["button"].(string) - if !ok { - button = "left" - } - - clicks := 1 - if clicksArg, ok := args["clicks"].(float64); ok { - clicks = int(clicksArg) - } - - var finalX, finalY int - shouldMove := false - - if xArg, xOk := args["x"].(float64); xOk { - if yArg, yOk := args["y"].(float64); yOk { - finalX = int(xArg) - finalY = int(yArg) - shouldMove = true - } - } + button := t.getButton(args) + clicks := t.getClicks(args) + finalX, finalY, shouldMove := t.getCoordinates(args) if t.displayProvider == nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: "no compatible display platform detected", - }, nil + return t.errorResult(args, start, "no compatible display platform detected"), nil } controller, err := t.displayProvider.GetController() if err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: fmt.Sprintf("failed to get platform controller: %v", err), - }, nil + return t.errorResult(args, start, fmt.Sprintf("failed to get platform controller: %v", err)), nil } defer func() { if closeErr := controller.Close(); closeErr != nil { @@ -131,32 +98,18 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom } }() - if shouldMove { - if err := controller.MoveMouse(ctx, finalX, finalY); err != nil { - return &domain.ToolExecutionResult{ - ToolName: 
"MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: fmt.Sprintf("failed to move mouse: %v", err), - }, nil - } - } else { - x, y, _ := controller.GetCursorPosition(ctx) - finalX, finalY = x, y + finalX, finalY, err = t.handleMovement(ctx, controller, shouldMove, finalX, finalY) + if err != nil { + return t.errorResult(args, start, err.Error()), nil } mouseButton := display.ParseMouseButton(button) if err := controller.ClickMouse(ctx, mouseButton, clicks); err != nil { - return &domain.ToolExecutionResult{ - ToolName: "MouseClick", - Arguments: args, - Success: false, - Duration: time.Since(start), - Error: fmt.Sprintf("failed to click mouse: %v", err), - }, nil + return t.errorResult(args, start, fmt.Sprintf("failed to click mouse: %v", err)), nil } + t.updateStateAfterClick(ctx, controller, finalX, finalY) + result := domain.MouseClickToolResult{ Button: button, Clicks: clicks, @@ -174,6 +127,137 @@ func (t *MouseClickTool) Execute(ctx context.Context, args map[string]any) (*dom }, nil } +func (t *MouseClickTool) getButton(args map[string]any) string { + button, ok := args["button"].(string) + if !ok { + return "left" + } + return button +} + +func (t *MouseClickTool) getClicks(args map[string]any) int { + if clicksArg, ok := args["clicks"].(float64); ok { + return int(clicksArg) + } + return 1 +} + +func (t *MouseClickTool) getCoordinates(args map[string]any) (int, int, bool) { + if xArg, xOk := args["x"].(float64); xOk { + if yArg, yOk := args["y"].(float64); yOk { + return int(xArg), int(yArg), true + } + } + return 0, 0, false +} + +func (t *MouseClickTool) errorResult(args map[string]any, start time.Time, errorMsg string) *domain.ToolExecutionResult { + return &domain.ToolExecutionResult{ + ToolName: "MouseClick", + Arguments: args, + Success: false, + Duration: time.Since(start), + Error: errorMsg, + } +} + +func (t *MouseClickTool) handleMovement(ctx context.Context, controller display.DisplayController, shouldMove 
bool, x, y int) (int, int, error) { + if !shouldMove { + cursorX, cursorY, _ := controller.GetCursorPosition(ctx) + return cursorX, cursorY, nil + } + + targetX, targetY := t.scaleCoordinates(ctx, controller, x, y) + + if err := controller.MoveMouse(ctx, targetX, targetY); err != nil { + return 0, 0, fmt.Errorf("failed to move mouse: %w", err) + } + + return targetX, targetY, nil +} + +// scaleCoordinates converts API coordinates to screen coordinates using Anthropic's proportional scaling. +// This follows the official computer-use-demo implementation strategy. +func (t *MouseClickTool) scaleCoordinates(ctx context.Context, controller display.DisplayController, x, y int) (int, int) { + if isDirectExec := ctx.Value(domain.DirectExecutionKey); isDirectExec != nil && isDirectExec.(bool) { + return x, y + } + + screenWidth, screenHeight, err := controller.GetScreenDimensions(ctx) + if err != nil { + logger.Warn("Failed to get screen dimensions, no scaling", "error", err) + return x, y + } + + apiWidth := t.config.ComputerUse.Screenshot.TargetWidth + apiHeight := t.config.ComputerUse.Screenshot.TargetHeight + + if apiWidth == 0 || apiHeight == 0 { + return x, y + } + + screenX, screenY := ScaleAPIToScreen(x, y, apiWidth, apiHeight, screenWidth, screenHeight) + + return screenX, screenY +} + +func (t *MouseClickTool) updateStateAfterClick(ctx context.Context, controller display.DisplayController, x, y int) { + if t.stateManager == nil { + return + } + + t.storeFocusedApp(ctx, controller) + t.stateManager.SetLastClickCoordinates(x, y) + t.broadcastClickEvent(x, y) +} + +func (t *MouseClickTool) storeFocusedApp(ctx context.Context, controller display.DisplayController) { + focusManager, ok := controller.(display.FocusManager) + if !ok { + return + } + + appID, err := focusManager.GetFrontmostApp(ctx) + if err != nil { + logger.Warn("Failed to get frontmost app after click", "error", err) + return + } + + t.stateManager.SetLastFocusedApp(appID) +} + +func (t 
*MouseClickTool) broadcastClickEvent(x, y int) { + controller, err := t.displayProvider.GetController() + if err != nil { + logger.Warn("Failed to get controller for click indicator", "error", err) + return + } + defer func() { + if closeErr := controller.Close(); closeErr != nil { + logger.Warn("Failed to close controller", "error", closeErr) + } + }() + + _, screenHeight, err := controller.GetScreenDimensions(context.Background()) + if err != nil { + logger.Warn("Failed to get screen dimensions for click indicator", "error", err) + screenHeight = 1117 + } + + macosY := screenHeight - y + + clickEvent := domain.ClickIndicatorEvent{ + BaseChatEvent: domain.BaseChatEvent{ + RequestID: "click-indicator", + Timestamp: time.Now(), + }, + X: x, + Y: macosY, + ClickIndicator: true, + } + t.stateManager.BroadcastEvent(clickEvent) +} + // Validate checks if the tool arguments are valid func (t *MouseClickTool) Validate(args map[string]any) error { button, ok := args["button"].(string) diff --git a/internal/services/tools/mouse_move.go b/internal/services/tools/mouse_move.go index 48f890fd..1accb3ac 100644 --- a/internal/services/tools/mouse_move.go +++ b/internal/services/tools/mouse_move.go @@ -19,16 +19,18 @@ type MouseMoveTool struct { formatter domain.BaseFormatter rateLimiter domain.RateLimiter displayProvider display.Provider + stateManager domain.StateManager } // NewMouseMoveTool creates a new mouse move tool -func NewMouseMoveTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider) *MouseMoveTool { +func NewMouseMoveTool(cfg *config.Config, rateLimiter domain.RateLimiter, displayProvider display.Provider, stateManager domain.StateManager) *MouseMoveTool { return &MouseMoveTool{ config: cfg, enabled: cfg.ComputerUse.Enabled && cfg.ComputerUse.MouseMove.Enabled, formatter: domain.NewBaseFormatter("MouseMove"), rateLimiter: rateLimiter, displayProvider: displayProvider, + stateManager: stateManager, } } @@ -111,9 +113,11 @@ func (t 
*MouseMoveTool) Execute(ctx context.Context, args map[string]any) (*doma } }() + targetX, targetY := t.scaleCoordinates(ctx, controller, int(x), int(y)) + fromX, fromY, _ := controller.GetCursorPosition(ctx) - if err := controller.MoveMouse(ctx, int(x), int(y)); err != nil { + if err := controller.MoveMouse(ctx, targetX, targetY); err != nil { return &domain.ToolExecutionResult{ ToolName: "MouseMove", Arguments: args, @@ -123,11 +127,13 @@ func (t *MouseMoveTool) Execute(ctx context.Context, args map[string]any) (*doma }, nil } + t.broadcastMoveEvent(fromX, fromY, targetX, targetY) + result := domain.MouseMoveToolResult{ FromX: fromX, FromY: fromY, - ToX: int(x), - ToY: int(y), + ToX: targetX, + ToY: targetY, Method: t.displayProvider.GetDisplayInfo().Name, } @@ -212,3 +218,69 @@ func (t *MouseMoveTool) ShouldCollapseArg(key string) bool { func (t *MouseMoveTool) ShouldAlwaysExpand() bool { return false } + +// scaleCoordinates converts API coordinates to screen coordinates using Anthropic's proportional scaling. +// This follows the official computer-use-demo implementation strategy. 
+func (t *MouseMoveTool) scaleCoordinates(ctx context.Context, controller display.DisplayController, x, y int) (int, int) { + if isDirectExec := ctx.Value(domain.DirectExecutionKey); isDirectExec != nil && isDirectExec.(bool) { + return x, y + } + + screenWidth, screenHeight, err := controller.GetScreenDimensions(ctx) + if err != nil { + logger.Warn("Failed to get screen dimensions", "error", err) + return x, y + } + + apiWidth := t.config.ComputerUse.Screenshot.TargetWidth + apiHeight := t.config.ComputerUse.Screenshot.TargetHeight + + if apiWidth == 0 || apiHeight == 0 { + return x, y + } + + screenX, screenY := ScaleAPIToScreen(x, y, apiWidth, apiHeight, screenWidth, screenHeight) + + return screenX, screenY +} + +// broadcastMoveEvent broadcasts a visual move indicator event for user feedback +func (t *MouseMoveTool) broadcastMoveEvent(fromX, fromY, toX, toY int) { + if t.stateManager == nil { + return + } + + controller, err := t.displayProvider.GetController() + if err != nil { + logger.Warn("Failed to get controller for move indicator", "error", err) + return + } + defer func() { + if closeErr := controller.Close(); closeErr != nil { + logger.Warn("Failed to close controller", "error", closeErr) + } + }() + + _, screenHeight, err := controller.GetScreenDimensions(context.Background()) + if err != nil { + logger.Warn("Failed to get screen dimensions for move indicator", "error", err) + screenHeight = 1117 + } + + macosFromY := screenHeight - fromY + macosToY := screenHeight - toY + + moveEvent := domain.MoveIndicatorEvent{ + BaseChatEvent: domain.BaseChatEvent{ + RequestID: "move-indicator", + Timestamp: time.Now(), + }, + FromX: fromX, + FromY: macosFromY, + ToX: toX, + ToY: macosToY, + MoveIndicator: true, + } + + t.stateManager.BroadcastEvent(moveEvent) +} diff --git a/internal/services/tools/mouse_move_test.go b/internal/services/tools/mouse_move_test.go new file mode 100644 index 00000000..1ebfc4b9 --- /dev/null +++ 
b/internal/services/tools/mouse_move_test.go @@ -0,0 +1,215 @@ +package tools + +import ( + "context" + "testing" + + config "github.com/inference-gateway/cli/config" + display "github.com/inference-gateway/cli/internal/display" + domain "github.com/inference-gateway/cli/internal/domain" + displayMocks "github.com/inference-gateway/cli/tests/mocks/display" + domainMocks "github.com/inference-gateway/cli/tests/mocks/domain" +) + +func TestMouseMoveTool_CoordinateScaling(t *testing.T) { + tests := []struct { + name string + apiWidth int + apiHeight int + logicalWidth int + logicalHeight int + inputX float64 + inputY float64 + directExec bool + expectedX int + expectedY int + description string + }{ + { + name: "Direct execution - no scaling", + apiWidth: 1024, + apiHeight: 768, + logicalWidth: 1728, + logicalHeight: 1117, + inputX: 500, + inputY: 400, + directExec: true, + expectedX: 500, + expectedY: 400, + description: "Direct execution should use coordinates as-is without scaling", + }, + { + name: "LLM execution - scale from XGA to 2x Retina", + apiWidth: 1024, + apiHeight: 768, + logicalWidth: 2048, + logicalHeight: 1536, + inputX: 512, + inputY: 384, + directExec: false, + expectedX: 1024, + expectedY: 768, + description: "LLM execution should scale from XGA (1024x768) to 2x Retina (2048x1536)", + }, + { + name: "LLM execution - top-left corner", + apiWidth: 1024, + apiHeight: 768, + logicalWidth: 1728, + logicalHeight: 1117, + inputX: 0, + inputY: 0, + directExec: false, + expectedX: 0, + expectedY: 0, + description: "Top-left corner should map to (0,0) in both spaces", + }, + { + name: "LLM execution - mismatched aspect ratios (proportional scaling)", + apiWidth: 1024, + apiHeight: 768, + logicalWidth: 1728, + logicalHeight: 1117, + inputX: 1024, + inputY: 768, + directExec: false, + expectedX: 1728, + expectedY: 1117, + description: "Bottom-right corner of API space maps to bottom-right corner of screen space", + }, + { + name: "LLM execution - 3x scaling", 
+ apiWidth: 1024, + apiHeight: 768, + logicalWidth: 3072, + logicalHeight: 2304, + inputX: 256, + inputY: 192, + directExec: false, + expectedX: 768, + expectedY: 576, + description: "3x scaling should work proportionally with matching aspect ratios", + }, + { + name: "No scaling when dimensions match", + apiWidth: 1728, + apiHeight: 1117, + logicalWidth: 1728, + logicalHeight: 1117, + inputX: 500, + inputY: 400, + directExec: false, + expectedX: 500, + expectedY: 400, + description: "When API and logical dimensions match, no scaling should occur", + }, + { + name: "LLM execution - config has zero dimensions", + apiWidth: 0, + apiHeight: 0, + logicalWidth: 1728, + logicalHeight: 1117, + inputX: 500, + inputY: 400, + directExec: false, + expectedX: 500, + expectedY: 400, + description: "When config dimensions are 0, no scaling should occur", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tc := coordinateScalingTestCase{ + apiWidth: tt.apiWidth, + apiHeight: tt.apiHeight, + logicalWidth: tt.logicalWidth, + logicalHeight: tt.logicalHeight, + inputX: tt.inputX, + inputY: tt.inputY, + directExec: tt.directExec, + expectedX: tt.expectedX, + expectedY: tt.expectedY, + description: tt.description, + } + runCoordinateScalingTest(t, tc) + }) + } +} + +type coordinateScalingTestCase struct { + apiWidth int + apiHeight int + logicalWidth int + logicalHeight int + inputX float64 + inputY float64 + directExec bool + expectedX int + expectedY int + description string +} + +func runCoordinateScalingTest(t *testing.T, tt coordinateScalingTestCase) { + t.Helper() + + mockCtrl := &displayMocks.FakeDisplayController{} + mockCtrl.GetScreenDimensionsReturns(tt.logicalWidth, tt.logicalHeight, nil) + mockCtrl.GetCursorPositionReturns(100, 100, nil) + mockCtrl.MoveMouseReturns(nil) + mockCtrl.CloseReturns(nil) + + mockProv := &displayMocks.FakeProvider{} + mockProv.GetControllerReturns(mockCtrl, nil) + 
mockProv.GetDisplayInfoReturns(display.DisplayInfo{Name: "mock"}) + + cfg := &config.Config{ + ComputerUse: config.ComputerUseConfig{ + Screenshot: config.ScreenshotToolConfig{ + TargetWidth: tt.apiWidth, + TargetHeight: tt.apiHeight, + }, + }, + } + + rateLimiter := &domainMocks.FakeRateLimiter{} + rateLimiter.CheckAndRecordReturns(nil) + + stateManager := &domainMocks.FakeStateManager{} + + tool := NewMouseMoveTool(cfg, rateLimiter, mockProv, stateManager) + + ctx := context.Background() + if tt.directExec { + ctx = context.WithValue(ctx, domain.DirectExecutionKey, true) + } + + args := map[string]any{ + "x": tt.inputX, + "y": tt.inputY, + } + + result, err := tool.Execute(ctx, args) + + if err != nil { + t.Fatalf("Execute failed: %v", err) + } + + if !result.Success { + t.Fatalf("Execute was not successful: %s", result.Error) + } + + if mockCtrl.MoveMouseCallCount() != 1 { + t.Fatalf("Expected MoveMouse to be called once, got %d calls", mockCtrl.MoveMouseCallCount()) + } + + _, actualX, actualY := mockCtrl.MoveMouseArgsForCall(0) + + if actualX != tt.expectedX { + t.Errorf("%s\nExpected X coordinate: %d, got: %d", tt.description, tt.expectedX, actualX) + } + + if actualY != tt.expectedY { + t.Errorf("%s\nExpected Y coordinate: %d, got: %d", tt.description, tt.expectedY, actualY) + } +} diff --git a/internal/services/tools/registry.go b/internal/services/tools/registry.go index e909103c..b89c44af 100644 --- a/internal/services/tools/registry.go +++ b/internal/services/tools/registry.go @@ -19,31 +19,49 @@ import ( // Registry manages all available tools type Registry struct { - config domain.ConfigService - tools map[string]domain.Tool - readToolUsed bool - taskTracker domain.TaskTracker - imageService domain.ImageService - mcpManager domain.MCPManager - shellService domain.BackgroundShellService + config domain.ConfigService + tools map[string]domain.Tool + readToolUsed bool + taskTracker domain.TaskTracker + imageService domain.ImageService + mcpManager 
domain.MCPManager + shellService domain.BackgroundShellService + stateManager domain.StateManager + screenshotProvider domain.ScreenshotProvider } // NewRegistry creates a new tool registry with self-contained tools -func NewRegistry(cfg domain.ConfigService, imageService domain.ImageService, mcpManager domain.MCPManager, shellService domain.BackgroundShellService) *Registry { +func NewRegistry(cfg domain.ConfigService, imageService domain.ImageService, mcpManager domain.MCPManager, shellService domain.BackgroundShellService, stateManager domain.StateManager, screenshotProvider domain.ScreenshotProvider) *Registry { registry := &Registry{ - config: cfg, - tools: make(map[string]domain.Tool), - shellService: shellService, - readToolUsed: false, - taskTracker: utils.NewTaskTracker(), - imageService: imageService, - mcpManager: mcpManager, + config: cfg, + tools: make(map[string]domain.Tool), + shellService: shellService, + readToolUsed: false, + taskTracker: utils.NewTaskTracker(), + imageService: imageService, + mcpManager: mcpManager, + stateManager: stateManager, + screenshotProvider: screenshotProvider, } registry.registerTools() return registry } +// SetScreenshotProvider updates the screenshot provider for tools that need it +func (r *Registry) SetScreenshotProvider(provider domain.ScreenshotProvider) { + r.screenshotProvider = provider + + cfg := r.config.GetConfig() + if cfg.ComputerUse.Enabled { + displayProvider, err := display.DetectDisplay() + if err == nil { + rateLimiter := utils.NewRateLimiter(cfg.ComputerUse.RateLimit) + r.tools["MouseClick"] = NewMouseClickTool(cfg, rateLimiter, displayProvider, r.stateManager) + } + } +} + // registerTools initializes and registers all available tools func (r *Registry) registerTools() { cfg := r.config.GetConfig() @@ -90,10 +108,10 @@ func (r *Registry) registerTools() { logger.Warn("No compatible display platform detected, computer use tools will be disabled", "error", err) } else { rateLimiter := 
utils.NewRateLimiter(cfg.ComputerUse.RateLimit) - r.tools["MouseMove"] = NewMouseMoveTool(cfg, rateLimiter, displayProvider) - r.tools["MouseClick"] = NewMouseClickTool(cfg, rateLimiter, displayProvider) + r.tools["MouseMove"] = NewMouseMoveTool(cfg, rateLimiter, displayProvider, r.stateManager) + r.tools["MouseClick"] = NewMouseClickTool(cfg, rateLimiter, displayProvider, r.stateManager) r.tools["MouseScroll"] = NewMouseScrollTool(cfg, rateLimiter, displayProvider) - r.tools["KeyboardType"] = NewKeyboardTypeTool(cfg, rateLimiter, displayProvider) + r.tools["KeyboardType"] = NewKeyboardTypeTool(cfg, rateLimiter, displayProvider, r.stateManager) r.tools["GetFocusedApp"] = NewGetFocusedAppTool(r.config) r.tools["ActivateApp"] = NewActivateAppTool(r.config) } @@ -264,6 +282,8 @@ func (r *Registry) SetScreenshotServer(provider domain.ScreenshotProvider) { return } + r.SetScreenshotProvider(provider) + getLatestTool := NewGetLatestScreenshotTool(cfg, provider) r.tools["GetLatestScreenshot"] = getLatestTool diff --git a/internal/services/tools/registry_test.go b/internal/services/tools/registry_test.go index 195838fd..448711c9 100644 --- a/internal/services/tools/registry_test.go +++ b/internal/services/tools/registry_test.go @@ -90,7 +90,7 @@ func createTestRegistry() *Registry { }, } - return NewRegistry(newTestConfigService(cfg), nil, nil, nil) + return NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil) } func TestRegistry_GetTool_Unknown(t *testing.T) { @@ -122,7 +122,7 @@ func TestRegistry_DisabledTools(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil) tools := registry.ListAvailableTools() @@ -175,7 +175,7 @@ func TestRegistry_NewRegistry(t *testing.T) { } configService := newTestConfigService(cfg) - registry := NewRegistry(configService, nil, nil, nil) + registry := NewRegistry(configService, nil, nil, nil, nil, nil) if registry == nil 
{ t.Fatal("Expected non-nil registry") @@ -210,7 +210,7 @@ func TestRegistry_GetTool(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil) tests := []struct { name string @@ -361,7 +361,7 @@ func TestRegistry_ListAvailableTools(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - registry := NewRegistry(newTestConfigService(tt.config), nil, nil, nil) + registry := NewRegistry(newTestConfigService(tt.config), nil, nil, nil, nil, nil) tools := registry.ListAvailableTools() if len(tools) < tt.expectedMin || len(tools) > tt.expectedMax { @@ -414,7 +414,7 @@ func TestRegistry_GetToolDefinitions(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil) definitions := registry.GetToolDefinitions() if len(definitions) < 5 || len(definitions) > 15 { @@ -464,7 +464,7 @@ func TestRegistry_IsToolEnabled(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil) tests := []struct { name string @@ -517,7 +517,7 @@ func TestRegistry_WithMockedTool(t *testing.T) { }, } - registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil) + registry := NewRegistry(newTestConfigService(cfg), nil, nil, nil, nil, nil) fakeTool := &mocks.FakeTool{} fakeTool.IsEnabledReturns(true) diff --git a/internal/shortcuts/custom.go b/internal/shortcuts/custom.go index b0a20c12..6f3566fe 100644 --- a/internal/shortcuts/custom.go +++ b/internal/shortcuts/custom.go @@ -388,6 +388,7 @@ func (c *CustomShortcut) executeWithTool(ctx context.Context, _ []string) (Short Arguments: string(argsJSON), } + ctx = context.WithValue(ctx, domain.DirectExecutionKey, true) result, err := c.toolService.ExecuteToolDirect(ctx, toolCall) if err != nil { return 
ShortcutResult{ diff --git a/tests/mocks/display/fake_display_controller.go b/tests/mocks/display/fake_display_controller.go new file mode 100644 index 00000000..c2b640fb --- /dev/null +++ b/tests/mocks/display/fake_display_controller.go @@ -0,0 +1,802 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package display + +import ( + "context" + "image" + "sync" + + "github.com/inference-gateway/cli/internal/display" +) + +type FakeDisplayController struct { + CaptureScreenStub func(context.Context, *display.Region) (image.Image, error) + captureScreenMutex sync.RWMutex + captureScreenArgsForCall []struct { + arg1 context.Context + arg2 *display.Region + } + captureScreenReturns struct { + result1 image.Image + result2 error + } + captureScreenReturnsOnCall map[int]struct { + result1 image.Image + result2 error + } + CaptureScreenBytesStub func(context.Context, *display.Region) ([]byte, error) + captureScreenBytesMutex sync.RWMutex + captureScreenBytesArgsForCall []struct { + arg1 context.Context + arg2 *display.Region + } + captureScreenBytesReturns struct { + result1 []byte + result2 error + } + captureScreenBytesReturnsOnCall map[int]struct { + result1 []byte + result2 error + } + ClickMouseStub func(context.Context, display.MouseButton, int) error + clickMouseMutex sync.RWMutex + clickMouseArgsForCall []struct { + arg1 context.Context + arg2 display.MouseButton + arg3 int + } + clickMouseReturns struct { + result1 error + } + clickMouseReturnsOnCall map[int]struct { + result1 error + } + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct { + } + closeReturns struct { + result1 error + } + closeReturnsOnCall map[int]struct { + result1 error + } + GetCursorPositionStub func(context.Context) (int, int, error) + getCursorPositionMutex sync.RWMutex + getCursorPositionArgsForCall []struct { + arg1 context.Context + } + getCursorPositionReturns struct { + result1 int + result2 int + result3 error + } + getCursorPositionReturnsOnCall 
map[int]struct { + result1 int + result2 int + result3 error + } + GetScreenDimensionsStub func(context.Context) (int, int, error) + getScreenDimensionsMutex sync.RWMutex + getScreenDimensionsArgsForCall []struct { + arg1 context.Context + } + getScreenDimensionsReturns struct { + result1 int + result2 int + result3 error + } + getScreenDimensionsReturnsOnCall map[int]struct { + result1 int + result2 int + result3 error + } + MoveMouseStub func(context.Context, int, int) error + moveMouseMutex sync.RWMutex + moveMouseArgsForCall []struct { + arg1 context.Context + arg2 int + arg3 int + } + moveMouseReturns struct { + result1 error + } + moveMouseReturnsOnCall map[int]struct { + result1 error + } + ScrollMouseStub func(context.Context, int, string) error + scrollMouseMutex sync.RWMutex + scrollMouseArgsForCall []struct { + arg1 context.Context + arg2 int + arg3 string + } + scrollMouseReturns struct { + result1 error + } + scrollMouseReturnsOnCall map[int]struct { + result1 error + } + SendKeyComboStub func(context.Context, string) error + sendKeyComboMutex sync.RWMutex + sendKeyComboArgsForCall []struct { + arg1 context.Context + arg2 string + } + sendKeyComboReturns struct { + result1 error + } + sendKeyComboReturnsOnCall map[int]struct { + result1 error + } + TypeTextStub func(context.Context, string, int) error + typeTextMutex sync.RWMutex + typeTextArgsForCall []struct { + arg1 context.Context + arg2 string + arg3 int + } + typeTextReturns struct { + result1 error + } + typeTextReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeDisplayController) CaptureScreen(arg1 context.Context, arg2 *display.Region) (image.Image, error) { + fake.captureScreenMutex.Lock() + ret, specificReturn := fake.captureScreenReturnsOnCall[len(fake.captureScreenArgsForCall)] + fake.captureScreenArgsForCall = append(fake.captureScreenArgsForCall, struct { + arg1 context.Context + arg2 
*display.Region + }{arg1, arg2}) + stub := fake.CaptureScreenStub + fakeReturns := fake.captureScreenReturns + fake.recordInvocation("CaptureScreen", []interface{}{arg1, arg2}) + fake.captureScreenMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeDisplayController) CaptureScreenCallCount() int { + fake.captureScreenMutex.RLock() + defer fake.captureScreenMutex.RUnlock() + return len(fake.captureScreenArgsForCall) +} + +func (fake *FakeDisplayController) CaptureScreenCalls(stub func(context.Context, *display.Region) (image.Image, error)) { + fake.captureScreenMutex.Lock() + defer fake.captureScreenMutex.Unlock() + fake.CaptureScreenStub = stub +} + +func (fake *FakeDisplayController) CaptureScreenArgsForCall(i int) (context.Context, *display.Region) { + fake.captureScreenMutex.RLock() + defer fake.captureScreenMutex.RUnlock() + argsForCall := fake.captureScreenArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeDisplayController) CaptureScreenReturns(result1 image.Image, result2 error) { + fake.captureScreenMutex.Lock() + defer fake.captureScreenMutex.Unlock() + fake.CaptureScreenStub = nil + fake.captureScreenReturns = struct { + result1 image.Image + result2 error + }{result1, result2} +} + +func (fake *FakeDisplayController) CaptureScreenReturnsOnCall(i int, result1 image.Image, result2 error) { + fake.captureScreenMutex.Lock() + defer fake.captureScreenMutex.Unlock() + fake.CaptureScreenStub = nil + if fake.captureScreenReturnsOnCall == nil { + fake.captureScreenReturnsOnCall = make(map[int]struct { + result1 image.Image + result2 error + }) + } + fake.captureScreenReturnsOnCall[i] = struct { + result1 image.Image + result2 error + }{result1, result2} +} + +func (fake *FakeDisplayController) CaptureScreenBytes(arg1 context.Context, arg2 *display.Region) ([]byte, error) { + 
fake.captureScreenBytesMutex.Lock() + ret, specificReturn := fake.captureScreenBytesReturnsOnCall[len(fake.captureScreenBytesArgsForCall)] + fake.captureScreenBytesArgsForCall = append(fake.captureScreenBytesArgsForCall, struct { + arg1 context.Context + arg2 *display.Region + }{arg1, arg2}) + stub := fake.CaptureScreenBytesStub + fakeReturns := fake.captureScreenBytesReturns + fake.recordInvocation("CaptureScreenBytes", []interface{}{arg1, arg2}) + fake.captureScreenBytesMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeDisplayController) CaptureScreenBytesCallCount() int { + fake.captureScreenBytesMutex.RLock() + defer fake.captureScreenBytesMutex.RUnlock() + return len(fake.captureScreenBytesArgsForCall) +} + +func (fake *FakeDisplayController) CaptureScreenBytesCalls(stub func(context.Context, *display.Region) ([]byte, error)) { + fake.captureScreenBytesMutex.Lock() + defer fake.captureScreenBytesMutex.Unlock() + fake.CaptureScreenBytesStub = stub +} + +func (fake *FakeDisplayController) CaptureScreenBytesArgsForCall(i int) (context.Context, *display.Region) { + fake.captureScreenBytesMutex.RLock() + defer fake.captureScreenBytesMutex.RUnlock() + argsForCall := fake.captureScreenBytesArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeDisplayController) CaptureScreenBytesReturns(result1 []byte, result2 error) { + fake.captureScreenBytesMutex.Lock() + defer fake.captureScreenBytesMutex.Unlock() + fake.CaptureScreenBytesStub = nil + fake.captureScreenBytesReturns = struct { + result1 []byte + result2 error + }{result1, result2} +} + +func (fake *FakeDisplayController) CaptureScreenBytesReturnsOnCall(i int, result1 []byte, result2 error) { + fake.captureScreenBytesMutex.Lock() + defer fake.captureScreenBytesMutex.Unlock() + fake.CaptureScreenBytesStub = nil + if 
fake.captureScreenBytesReturnsOnCall == nil { + fake.captureScreenBytesReturnsOnCall = make(map[int]struct { + result1 []byte + result2 error + }) + } + fake.captureScreenBytesReturnsOnCall[i] = struct { + result1 []byte + result2 error + }{result1, result2} +} + +func (fake *FakeDisplayController) ClickMouse(arg1 context.Context, arg2 display.MouseButton, arg3 int) error { + fake.clickMouseMutex.Lock() + ret, specificReturn := fake.clickMouseReturnsOnCall[len(fake.clickMouseArgsForCall)] + fake.clickMouseArgsForCall = append(fake.clickMouseArgsForCall, struct { + arg1 context.Context + arg2 display.MouseButton + arg3 int + }{arg1, arg2, arg3}) + stub := fake.ClickMouseStub + fakeReturns := fake.clickMouseReturns + fake.recordInvocation("ClickMouse", []interface{}{arg1, arg2, arg3}) + fake.clickMouseMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDisplayController) ClickMouseCallCount() int { + fake.clickMouseMutex.RLock() + defer fake.clickMouseMutex.RUnlock() + return len(fake.clickMouseArgsForCall) +} + +func (fake *FakeDisplayController) ClickMouseCalls(stub func(context.Context, display.MouseButton, int) error) { + fake.clickMouseMutex.Lock() + defer fake.clickMouseMutex.Unlock() + fake.ClickMouseStub = stub +} + +func (fake *FakeDisplayController) ClickMouseArgsForCall(i int) (context.Context, display.MouseButton, int) { + fake.clickMouseMutex.RLock() + defer fake.clickMouseMutex.RUnlock() + argsForCall := fake.clickMouseArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeDisplayController) ClickMouseReturns(result1 error) { + fake.clickMouseMutex.Lock() + defer fake.clickMouseMutex.Unlock() + fake.ClickMouseStub = nil + fake.clickMouseReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) ClickMouseReturnsOnCall(i int, result1 error) { + 
fake.clickMouseMutex.Lock() + defer fake.clickMouseMutex.Unlock() + fake.ClickMouseStub = nil + if fake.clickMouseReturnsOnCall == nil { + fake.clickMouseReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.clickMouseReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) Close() error { + fake.closeMutex.Lock() + ret, specificReturn := fake.closeReturnsOnCall[len(fake.closeArgsForCall)] + fake.closeArgsForCall = append(fake.closeArgsForCall, struct { + }{}) + stub := fake.CloseStub + fakeReturns := fake.closeReturns + fake.recordInvocation("Close", []interface{}{}) + fake.closeMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDisplayController) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeDisplayController) CloseCalls(stub func() error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = stub +} + +func (fake *FakeDisplayController) CloseReturns(result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) CloseReturnsOnCall(i int, result1 error) { + fake.closeMutex.Lock() + defer fake.closeMutex.Unlock() + fake.CloseStub = nil + if fake.closeReturnsOnCall == nil { + fake.closeReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.closeReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) GetCursorPosition(arg1 context.Context) (int, int, error) { + fake.getCursorPositionMutex.Lock() + ret, specificReturn := fake.getCursorPositionReturnsOnCall[len(fake.getCursorPositionArgsForCall)] + fake.getCursorPositionArgsForCall = append(fake.getCursorPositionArgsForCall, struct { + arg1 context.Context 
+ }{arg1}) + stub := fake.GetCursorPositionStub + fakeReturns := fake.getCursorPositionReturns + fake.recordInvocation("GetCursorPosition", []interface{}{arg1}) + fake.getCursorPositionMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 +} + +func (fake *FakeDisplayController) GetCursorPositionCallCount() int { + fake.getCursorPositionMutex.RLock() + defer fake.getCursorPositionMutex.RUnlock() + return len(fake.getCursorPositionArgsForCall) +} + +func (fake *FakeDisplayController) GetCursorPositionCalls(stub func(context.Context) (int, int, error)) { + fake.getCursorPositionMutex.Lock() + defer fake.getCursorPositionMutex.Unlock() + fake.GetCursorPositionStub = stub +} + +func (fake *FakeDisplayController) GetCursorPositionArgsForCall(i int) context.Context { + fake.getCursorPositionMutex.RLock() + defer fake.getCursorPositionMutex.RUnlock() + argsForCall := fake.getCursorPositionArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDisplayController) GetCursorPositionReturns(result1 int, result2 int, result3 error) { + fake.getCursorPositionMutex.Lock() + defer fake.getCursorPositionMutex.Unlock() + fake.GetCursorPositionStub = nil + fake.getCursorPositionReturns = struct { + result1 int + result2 int + result3 error + }{result1, result2, result3} +} + +func (fake *FakeDisplayController) GetCursorPositionReturnsOnCall(i int, result1 int, result2 int, result3 error) { + fake.getCursorPositionMutex.Lock() + defer fake.getCursorPositionMutex.Unlock() + fake.GetCursorPositionStub = nil + if fake.getCursorPositionReturnsOnCall == nil { + fake.getCursorPositionReturnsOnCall = make(map[int]struct { + result1 int + result2 int + result3 error + }) + } + fake.getCursorPositionReturnsOnCall[i] = struct { + result1 int + result2 int + result3 error + }{result1, result2, result3} +} + +func (fake *FakeDisplayController) 
GetScreenDimensions(arg1 context.Context) (int, int, error) { + fake.getScreenDimensionsMutex.Lock() + ret, specificReturn := fake.getScreenDimensionsReturnsOnCall[len(fake.getScreenDimensionsArgsForCall)] + fake.getScreenDimensionsArgsForCall = append(fake.getScreenDimensionsArgsForCall, struct { + arg1 context.Context + }{arg1}) + stub := fake.GetScreenDimensionsStub + fakeReturns := fake.getScreenDimensionsReturns + fake.recordInvocation("GetScreenDimensions", []interface{}{arg1}) + fake.getScreenDimensionsMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1, ret.result2, ret.result3 + } + return fakeReturns.result1, fakeReturns.result2, fakeReturns.result3 +} + +func (fake *FakeDisplayController) GetScreenDimensionsCallCount() int { + fake.getScreenDimensionsMutex.RLock() + defer fake.getScreenDimensionsMutex.RUnlock() + return len(fake.getScreenDimensionsArgsForCall) +} + +func (fake *FakeDisplayController) GetScreenDimensionsCalls(stub func(context.Context) (int, int, error)) { + fake.getScreenDimensionsMutex.Lock() + defer fake.getScreenDimensionsMutex.Unlock() + fake.GetScreenDimensionsStub = stub +} + +func (fake *FakeDisplayController) GetScreenDimensionsArgsForCall(i int) context.Context { + fake.getScreenDimensionsMutex.RLock() + defer fake.getScreenDimensionsMutex.RUnlock() + argsForCall := fake.getScreenDimensionsArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeDisplayController) GetScreenDimensionsReturns(result1 int, result2 int, result3 error) { + fake.getScreenDimensionsMutex.Lock() + defer fake.getScreenDimensionsMutex.Unlock() + fake.GetScreenDimensionsStub = nil + fake.getScreenDimensionsReturns = struct { + result1 int + result2 int + result3 error + }{result1, result2, result3} +} + +func (fake *FakeDisplayController) GetScreenDimensionsReturnsOnCall(i int, result1 int, result2 int, result3 error) { + fake.getScreenDimensionsMutex.Lock() + defer 
fake.getScreenDimensionsMutex.Unlock() + fake.GetScreenDimensionsStub = nil + if fake.getScreenDimensionsReturnsOnCall == nil { + fake.getScreenDimensionsReturnsOnCall = make(map[int]struct { + result1 int + result2 int + result3 error + }) + } + fake.getScreenDimensionsReturnsOnCall[i] = struct { + result1 int + result2 int + result3 error + }{result1, result2, result3} +} + +func (fake *FakeDisplayController) MoveMouse(arg1 context.Context, arg2 int, arg3 int) error { + fake.moveMouseMutex.Lock() + ret, specificReturn := fake.moveMouseReturnsOnCall[len(fake.moveMouseArgsForCall)] + fake.moveMouseArgsForCall = append(fake.moveMouseArgsForCall, struct { + arg1 context.Context + arg2 int + arg3 int + }{arg1, arg2, arg3}) + stub := fake.MoveMouseStub + fakeReturns := fake.moveMouseReturns + fake.recordInvocation("MoveMouse", []interface{}{arg1, arg2, arg3}) + fake.moveMouseMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDisplayController) MoveMouseCallCount() int { + fake.moveMouseMutex.RLock() + defer fake.moveMouseMutex.RUnlock() + return len(fake.moveMouseArgsForCall) +} + +func (fake *FakeDisplayController) MoveMouseCalls(stub func(context.Context, int, int) error) { + fake.moveMouseMutex.Lock() + defer fake.moveMouseMutex.Unlock() + fake.MoveMouseStub = stub +} + +func (fake *FakeDisplayController) MoveMouseArgsForCall(i int) (context.Context, int, int) { + fake.moveMouseMutex.RLock() + defer fake.moveMouseMutex.RUnlock() + argsForCall := fake.moveMouseArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeDisplayController) MoveMouseReturns(result1 error) { + fake.moveMouseMutex.Lock() + defer fake.moveMouseMutex.Unlock() + fake.MoveMouseStub = nil + fake.moveMouseReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) MoveMouseReturnsOnCall(i int, result1 error) { + 
fake.moveMouseMutex.Lock() + defer fake.moveMouseMutex.Unlock() + fake.MoveMouseStub = nil + if fake.moveMouseReturnsOnCall == nil { + fake.moveMouseReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.moveMouseReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) ScrollMouse(arg1 context.Context, arg2 int, arg3 string) error { + fake.scrollMouseMutex.Lock() + ret, specificReturn := fake.scrollMouseReturnsOnCall[len(fake.scrollMouseArgsForCall)] + fake.scrollMouseArgsForCall = append(fake.scrollMouseArgsForCall, struct { + arg1 context.Context + arg2 int + arg3 string + }{arg1, arg2, arg3}) + stub := fake.ScrollMouseStub + fakeReturns := fake.scrollMouseReturns + fake.recordInvocation("ScrollMouse", []interface{}{arg1, arg2, arg3}) + fake.scrollMouseMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDisplayController) ScrollMouseCallCount() int { + fake.scrollMouseMutex.RLock() + defer fake.scrollMouseMutex.RUnlock() + return len(fake.scrollMouseArgsForCall) +} + +func (fake *FakeDisplayController) ScrollMouseCalls(stub func(context.Context, int, string) error) { + fake.scrollMouseMutex.Lock() + defer fake.scrollMouseMutex.Unlock() + fake.ScrollMouseStub = stub +} + +func (fake *FakeDisplayController) ScrollMouseArgsForCall(i int) (context.Context, int, string) { + fake.scrollMouseMutex.RLock() + defer fake.scrollMouseMutex.RUnlock() + argsForCall := fake.scrollMouseArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeDisplayController) ScrollMouseReturns(result1 error) { + fake.scrollMouseMutex.Lock() + defer fake.scrollMouseMutex.Unlock() + fake.ScrollMouseStub = nil + fake.scrollMouseReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) ScrollMouseReturnsOnCall(i int, result1 error) { + 
fake.scrollMouseMutex.Lock() + defer fake.scrollMouseMutex.Unlock() + fake.ScrollMouseStub = nil + if fake.scrollMouseReturnsOnCall == nil { + fake.scrollMouseReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.scrollMouseReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) SendKeyCombo(arg1 context.Context, arg2 string) error { + fake.sendKeyComboMutex.Lock() + ret, specificReturn := fake.sendKeyComboReturnsOnCall[len(fake.sendKeyComboArgsForCall)] + fake.sendKeyComboArgsForCall = append(fake.sendKeyComboArgsForCall, struct { + arg1 context.Context + arg2 string + }{arg1, arg2}) + stub := fake.SendKeyComboStub + fakeReturns := fake.sendKeyComboReturns + fake.recordInvocation("SendKeyCombo", []interface{}{arg1, arg2}) + fake.sendKeyComboMutex.Unlock() + if stub != nil { + return stub(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDisplayController) SendKeyComboCallCount() int { + fake.sendKeyComboMutex.RLock() + defer fake.sendKeyComboMutex.RUnlock() + return len(fake.sendKeyComboArgsForCall) +} + +func (fake *FakeDisplayController) SendKeyComboCalls(stub func(context.Context, string) error) { + fake.sendKeyComboMutex.Lock() + defer fake.sendKeyComboMutex.Unlock() + fake.SendKeyComboStub = stub +} + +func (fake *FakeDisplayController) SendKeyComboArgsForCall(i int) (context.Context, string) { + fake.sendKeyComboMutex.RLock() + defer fake.sendKeyComboMutex.RUnlock() + argsForCall := fake.sendKeyComboArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeDisplayController) SendKeyComboReturns(result1 error) { + fake.sendKeyComboMutex.Lock() + defer fake.sendKeyComboMutex.Unlock() + fake.SendKeyComboStub = nil + fake.sendKeyComboReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) SendKeyComboReturnsOnCall(i int, result1 error) { + fake.sendKeyComboMutex.Lock() + defer 
fake.sendKeyComboMutex.Unlock() + fake.SendKeyComboStub = nil + if fake.sendKeyComboReturnsOnCall == nil { + fake.sendKeyComboReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.sendKeyComboReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) TypeText(arg1 context.Context, arg2 string, arg3 int) error { + fake.typeTextMutex.Lock() + ret, specificReturn := fake.typeTextReturnsOnCall[len(fake.typeTextArgsForCall)] + fake.typeTextArgsForCall = append(fake.typeTextArgsForCall, struct { + arg1 context.Context + arg2 string + arg3 int + }{arg1, arg2, arg3}) + stub := fake.TypeTextStub + fakeReturns := fake.typeTextReturns + fake.recordInvocation("TypeText", []interface{}{arg1, arg2, arg3}) + fake.typeTextMutex.Unlock() + if stub != nil { + return stub(arg1, arg2, arg3) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeDisplayController) TypeTextCallCount() int { + fake.typeTextMutex.RLock() + defer fake.typeTextMutex.RUnlock() + return len(fake.typeTextArgsForCall) +} + +func (fake *FakeDisplayController) TypeTextCalls(stub func(context.Context, string, int) error) { + fake.typeTextMutex.Lock() + defer fake.typeTextMutex.Unlock() + fake.TypeTextStub = stub +} + +func (fake *FakeDisplayController) TypeTextArgsForCall(i int) (context.Context, string, int) { + fake.typeTextMutex.RLock() + defer fake.typeTextMutex.RUnlock() + argsForCall := fake.typeTextArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3 +} + +func (fake *FakeDisplayController) TypeTextReturns(result1 error) { + fake.typeTextMutex.Lock() + defer fake.typeTextMutex.Unlock() + fake.TypeTextStub = nil + fake.typeTextReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) TypeTextReturnsOnCall(i int, result1 error) { + fake.typeTextMutex.Lock() + defer fake.typeTextMutex.Unlock() + fake.TypeTextStub = nil + if fake.typeTextReturnsOnCall == 
nil { + fake.typeTextReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.typeTextReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeDisplayController) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeDisplayController) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ display.DisplayController = new(FakeDisplayController) diff --git a/tests/mocks/display/fake_provider.go b/tests/mocks/display/fake_provider.go new file mode 100644 index 00000000..e885c1ed --- /dev/null +++ b/tests/mocks/display/fake_provider.go @@ -0,0 +1,231 @@ +// Code generated by counterfeiter. DO NOT EDIT. 
+package display + +import ( + "sync" + + "github.com/inference-gateway/cli/internal/display" +) + +type FakeProvider struct { + GetControllerStub func() (display.DisplayController, error) + getControllerMutex sync.RWMutex + getControllerArgsForCall []struct { + } + getControllerReturns struct { + result1 display.DisplayController + result2 error + } + getControllerReturnsOnCall map[int]struct { + result1 display.DisplayController + result2 error + } + GetDisplayInfoStub func() display.DisplayInfo + getDisplayInfoMutex sync.RWMutex + getDisplayInfoArgsForCall []struct { + } + getDisplayInfoReturns struct { + result1 display.DisplayInfo + } + getDisplayInfoReturnsOnCall map[int]struct { + result1 display.DisplayInfo + } + IsAvailableStub func() bool + isAvailableMutex sync.RWMutex + isAvailableArgsForCall []struct { + } + isAvailableReturns struct { + result1 bool + } + isAvailableReturnsOnCall map[int]struct { + result1 bool + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeProvider) GetController() (display.DisplayController, error) { + fake.getControllerMutex.Lock() + ret, specificReturn := fake.getControllerReturnsOnCall[len(fake.getControllerArgsForCall)] + fake.getControllerArgsForCall = append(fake.getControllerArgsForCall, struct { + }{}) + stub := fake.GetControllerStub + fakeReturns := fake.getControllerReturns + fake.recordInvocation("GetController", []interface{}{}) + fake.getControllerMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeProvider) GetControllerCallCount() int { + fake.getControllerMutex.RLock() + defer fake.getControllerMutex.RUnlock() + return len(fake.getControllerArgsForCall) +} + +func (fake *FakeProvider) GetControllerCalls(stub func() (display.DisplayController, error)) { + fake.getControllerMutex.Lock() + defer fake.getControllerMutex.Unlock() + 
fake.GetControllerStub = stub +} + +func (fake *FakeProvider) GetControllerReturns(result1 display.DisplayController, result2 error) { + fake.getControllerMutex.Lock() + defer fake.getControllerMutex.Unlock() + fake.GetControllerStub = nil + fake.getControllerReturns = struct { + result1 display.DisplayController + result2 error + }{result1, result2} +} + +func (fake *FakeProvider) GetControllerReturnsOnCall(i int, result1 display.DisplayController, result2 error) { + fake.getControllerMutex.Lock() + defer fake.getControllerMutex.Unlock() + fake.GetControllerStub = nil + if fake.getControllerReturnsOnCall == nil { + fake.getControllerReturnsOnCall = make(map[int]struct { + result1 display.DisplayController + result2 error + }) + } + fake.getControllerReturnsOnCall[i] = struct { + result1 display.DisplayController + result2 error + }{result1, result2} +} + +func (fake *FakeProvider) GetDisplayInfo() display.DisplayInfo { + fake.getDisplayInfoMutex.Lock() + ret, specificReturn := fake.getDisplayInfoReturnsOnCall[len(fake.getDisplayInfoArgsForCall)] + fake.getDisplayInfoArgsForCall = append(fake.getDisplayInfoArgsForCall, struct { + }{}) + stub := fake.GetDisplayInfoStub + fakeReturns := fake.getDisplayInfoReturns + fake.recordInvocation("GetDisplayInfo", []interface{}{}) + fake.getDisplayInfoMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeProvider) GetDisplayInfoCallCount() int { + fake.getDisplayInfoMutex.RLock() + defer fake.getDisplayInfoMutex.RUnlock() + return len(fake.getDisplayInfoArgsForCall) +} + +func (fake *FakeProvider) GetDisplayInfoCalls(stub func() display.DisplayInfo) { + fake.getDisplayInfoMutex.Lock() + defer fake.getDisplayInfoMutex.Unlock() + fake.GetDisplayInfoStub = stub +} + +func (fake *FakeProvider) GetDisplayInfoReturns(result1 display.DisplayInfo) { + fake.getDisplayInfoMutex.Lock() + defer fake.getDisplayInfoMutex.Unlock() + 
fake.GetDisplayInfoStub = nil + fake.getDisplayInfoReturns = struct { + result1 display.DisplayInfo + }{result1} +} + +func (fake *FakeProvider) GetDisplayInfoReturnsOnCall(i int, result1 display.DisplayInfo) { + fake.getDisplayInfoMutex.Lock() + defer fake.getDisplayInfoMutex.Unlock() + fake.GetDisplayInfoStub = nil + if fake.getDisplayInfoReturnsOnCall == nil { + fake.getDisplayInfoReturnsOnCall = make(map[int]struct { + result1 display.DisplayInfo + }) + } + fake.getDisplayInfoReturnsOnCall[i] = struct { + result1 display.DisplayInfo + }{result1} +} + +func (fake *FakeProvider) IsAvailable() bool { + fake.isAvailableMutex.Lock() + ret, specificReturn := fake.isAvailableReturnsOnCall[len(fake.isAvailableArgsForCall)] + fake.isAvailableArgsForCall = append(fake.isAvailableArgsForCall, struct { + }{}) + stub := fake.IsAvailableStub + fakeReturns := fake.isAvailableReturns + fake.recordInvocation("IsAvailable", []interface{}{}) + fake.isAvailableMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeProvider) IsAvailableCallCount() int { + fake.isAvailableMutex.RLock() + defer fake.isAvailableMutex.RUnlock() + return len(fake.isAvailableArgsForCall) +} + +func (fake *FakeProvider) IsAvailableCalls(stub func() bool) { + fake.isAvailableMutex.Lock() + defer fake.isAvailableMutex.Unlock() + fake.IsAvailableStub = stub +} + +func (fake *FakeProvider) IsAvailableReturns(result1 bool) { + fake.isAvailableMutex.Lock() + defer fake.isAvailableMutex.Unlock() + fake.IsAvailableStub = nil + fake.isAvailableReturns = struct { + result1 bool + }{result1} +} + +func (fake *FakeProvider) IsAvailableReturnsOnCall(i int, result1 bool) { + fake.isAvailableMutex.Lock() + defer fake.isAvailableMutex.Unlock() + fake.IsAvailableStub = nil + if fake.isAvailableReturnsOnCall == nil { + fake.isAvailableReturnsOnCall = make(map[int]struct { + result1 bool + }) + } + 
fake.isAvailableReturnsOnCall[i] = struct { + result1 bool + }{result1} +} + +func (fake *FakeProvider) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeProvider) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ display.Provider = new(FakeProvider) diff --git a/tests/mocks/domain/fake_rate_limiter.go b/tests/mocks/domain/fake_rate_limiter.go new file mode 100644 index 00000000..a6061618 --- /dev/null +++ b/tests/mocks/domain/fake_rate_limiter.go @@ -0,0 +1,200 @@ +// Code generated by counterfeiter. DO NOT EDIT. 
+package domain + +import ( + "sync" + + "github.com/inference-gateway/cli/internal/domain" +) + +type FakeRateLimiter struct { + CheckAndRecordStub func(string) error + checkAndRecordMutex sync.RWMutex + checkAndRecordArgsForCall []struct { + arg1 string + } + checkAndRecordReturns struct { + result1 error + } + checkAndRecordReturnsOnCall map[int]struct { + result1 error + } + GetCurrentCountStub func() int + getCurrentCountMutex sync.RWMutex + getCurrentCountArgsForCall []struct { + } + getCurrentCountReturns struct { + result1 int + } + getCurrentCountReturnsOnCall map[int]struct { + result1 int + } + ResetStub func() + resetMutex sync.RWMutex + resetArgsForCall []struct { + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRateLimiter) CheckAndRecord(arg1 string) error { + fake.checkAndRecordMutex.Lock() + ret, specificReturn := fake.checkAndRecordReturnsOnCall[len(fake.checkAndRecordArgsForCall)] + fake.checkAndRecordArgsForCall = append(fake.checkAndRecordArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.CheckAndRecordStub + fakeReturns := fake.checkAndRecordReturns + fake.recordInvocation("CheckAndRecord", []interface{}{arg1}) + fake.checkAndRecordMutex.Unlock() + if stub != nil { + return stub(arg1) + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRateLimiter) CheckAndRecordCallCount() int { + fake.checkAndRecordMutex.RLock() + defer fake.checkAndRecordMutex.RUnlock() + return len(fake.checkAndRecordArgsForCall) +} + +func (fake *FakeRateLimiter) CheckAndRecordCalls(stub func(string) error) { + fake.checkAndRecordMutex.Lock() + defer fake.checkAndRecordMutex.Unlock() + fake.CheckAndRecordStub = stub +} + +func (fake *FakeRateLimiter) CheckAndRecordArgsForCall(i int) string { + fake.checkAndRecordMutex.RLock() + defer fake.checkAndRecordMutex.RUnlock() + argsForCall := fake.checkAndRecordArgsForCall[i] + return argsForCall.arg1 +} + +func (fake 
*FakeRateLimiter) CheckAndRecordReturns(result1 error) { + fake.checkAndRecordMutex.Lock() + defer fake.checkAndRecordMutex.Unlock() + fake.CheckAndRecordStub = nil + fake.checkAndRecordReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRateLimiter) CheckAndRecordReturnsOnCall(i int, result1 error) { + fake.checkAndRecordMutex.Lock() + defer fake.checkAndRecordMutex.Unlock() + fake.CheckAndRecordStub = nil + if fake.checkAndRecordReturnsOnCall == nil { + fake.checkAndRecordReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.checkAndRecordReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRateLimiter) GetCurrentCount() int { + fake.getCurrentCountMutex.Lock() + ret, specificReturn := fake.getCurrentCountReturnsOnCall[len(fake.getCurrentCountArgsForCall)] + fake.getCurrentCountArgsForCall = append(fake.getCurrentCountArgsForCall, struct { + }{}) + stub := fake.GetCurrentCountStub + fakeReturns := fake.getCurrentCountReturns + fake.recordInvocation("GetCurrentCount", []interface{}{}) + fake.getCurrentCountMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeRateLimiter) GetCurrentCountCallCount() int { + fake.getCurrentCountMutex.RLock() + defer fake.getCurrentCountMutex.RUnlock() + return len(fake.getCurrentCountArgsForCall) +} + +func (fake *FakeRateLimiter) GetCurrentCountCalls(stub func() int) { + fake.getCurrentCountMutex.Lock() + defer fake.getCurrentCountMutex.Unlock() + fake.GetCurrentCountStub = stub +} + +func (fake *FakeRateLimiter) GetCurrentCountReturns(result1 int) { + fake.getCurrentCountMutex.Lock() + defer fake.getCurrentCountMutex.Unlock() + fake.GetCurrentCountStub = nil + fake.getCurrentCountReturns = struct { + result1 int + }{result1} +} + +func (fake *FakeRateLimiter) GetCurrentCountReturnsOnCall(i int, result1 int) { + fake.getCurrentCountMutex.Lock() + defer 
fake.getCurrentCountMutex.Unlock() + fake.GetCurrentCountStub = nil + if fake.getCurrentCountReturnsOnCall == nil { + fake.getCurrentCountReturnsOnCall = make(map[int]struct { + result1 int + }) + } + fake.getCurrentCountReturnsOnCall[i] = struct { + result1 int + }{result1} +} + +func (fake *FakeRateLimiter) Reset() { + fake.resetMutex.Lock() + fake.resetArgsForCall = append(fake.resetArgsForCall, struct { + }{}) + stub := fake.ResetStub + fake.recordInvocation("Reset", []interface{}{}) + fake.resetMutex.Unlock() + if stub != nil { + fake.ResetStub() + } +} + +func (fake *FakeRateLimiter) ResetCallCount() int { + fake.resetMutex.RLock() + defer fake.resetMutex.RUnlock() + return len(fake.resetArgsForCall) +} + +func (fake *FakeRateLimiter) ResetCalls(stub func()) { + fake.resetMutex.Lock() + defer fake.resetMutex.Unlock() + fake.ResetStub = stub +} + +func (fake *FakeRateLimiter) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRateLimiter) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ domain.RateLimiter = new(FakeRateLimiter) diff --git a/tests/mocks/domain/fake_state_manager.go b/tests/mocks/domain/fake_state_manager.go index e2748439..dc00e631 100644 --- a/tests/mocks/domain/fake_state_manager.go +++ b/tests/mocks/domain/fake_state_manager.go @@ -36,6 +36,14 @@ type FakeStateManager struct { clearFileSelectionStateMutex sync.RWMutex clearFileSelectionStateArgsForCall []struct { } + 
ClearLastClickCoordinatesStub func() + clearLastClickCoordinatesMutex sync.RWMutex + clearLastClickCoordinatesArgsForCall []struct { + } + ClearLastFocusedAppStub func() + clearLastFocusedAppMutex sync.RWMutex + clearLastFocusedAppArgsForCall []struct { + } ClearMessageEditStateStub func() clearMessageEditStateMutex sync.RWMutex clearMessageEditStateArgsForCall []struct { @@ -156,6 +164,28 @@ type FakeStateManager struct { getFileSelectionStateReturnsOnCall map[int]struct { result1 *domain.FileSelectionState } + GetLastClickCoordinatesStub func() (int, int) + getLastClickCoordinatesMutex sync.RWMutex + getLastClickCoordinatesArgsForCall []struct { + } + getLastClickCoordinatesReturns struct { + result1 int + result2 int + } + getLastClickCoordinatesReturnsOnCall map[int]struct { + result1 int + result2 int + } + GetLastFocusedAppStub func() string + getLastFocusedAppMutex sync.RWMutex + getLastFocusedAppArgsForCall []struct { + } + getLastFocusedAppReturns struct { + result1 string + } + getLastFocusedAppReturnsOnCall map[int]struct { + result1 string + } GetMessageEditStateStub func() *domain.MessageEditState getMessageEditStateMutex sync.RWMutex getMessageEditStateArgsForCall []struct { @@ -272,6 +302,17 @@ type FakeStateManager struct { setFileSelectedIndexArgsForCall []struct { arg1 int } + SetLastClickCoordinatesStub func(int, int) + setLastClickCoordinatesMutex sync.RWMutex + setLastClickCoordinatesArgsForCall []struct { + arg1 int + arg2 int + } + SetLastFocusedAppStub func(string) + setLastFocusedAppMutex sync.RWMutex + setLastFocusedAppArgsForCall []struct { + arg1 string + } SetMessageEditStateStub func(*domain.MessageEditState) setMessageEditStateMutex sync.RWMutex setMessageEditStateArgsForCall []struct { @@ -525,6 +566,54 @@ func (fake *FakeStateManager) ClearFileSelectionStateCalls(stub func()) { fake.ClearFileSelectionStateStub = stub } +func (fake *FakeStateManager) ClearLastClickCoordinates() { + fake.clearLastClickCoordinatesMutex.Lock() + 
fake.clearLastClickCoordinatesArgsForCall = append(fake.clearLastClickCoordinatesArgsForCall, struct { + }{}) + stub := fake.ClearLastClickCoordinatesStub + fake.recordInvocation("ClearLastClickCoordinates", []interface{}{}) + fake.clearLastClickCoordinatesMutex.Unlock() + if stub != nil { + fake.ClearLastClickCoordinatesStub() + } +} + +func (fake *FakeStateManager) ClearLastClickCoordinatesCallCount() int { + fake.clearLastClickCoordinatesMutex.RLock() + defer fake.clearLastClickCoordinatesMutex.RUnlock() + return len(fake.clearLastClickCoordinatesArgsForCall) +} + +func (fake *FakeStateManager) ClearLastClickCoordinatesCalls(stub func()) { + fake.clearLastClickCoordinatesMutex.Lock() + defer fake.clearLastClickCoordinatesMutex.Unlock() + fake.ClearLastClickCoordinatesStub = stub +} + +func (fake *FakeStateManager) ClearLastFocusedApp() { + fake.clearLastFocusedAppMutex.Lock() + fake.clearLastFocusedAppArgsForCall = append(fake.clearLastFocusedAppArgsForCall, struct { + }{}) + stub := fake.ClearLastFocusedAppStub + fake.recordInvocation("ClearLastFocusedApp", []interface{}{}) + fake.clearLastFocusedAppMutex.Unlock() + if stub != nil { + fake.ClearLastFocusedAppStub() + } +} + +func (fake *FakeStateManager) ClearLastFocusedAppCallCount() int { + fake.clearLastFocusedAppMutex.RLock() + defer fake.clearLastFocusedAppMutex.RUnlock() + return len(fake.clearLastFocusedAppArgsForCall) +} + +func (fake *FakeStateManager) ClearLastFocusedAppCalls(stub func()) { + fake.clearLastFocusedAppMutex.Lock() + defer fake.clearLastFocusedAppMutex.Unlock() + fake.ClearLastFocusedAppStub = stub +} + func (fake *FakeStateManager) ClearMessageEditState() { fake.clearMessageEditStateMutex.Lock() fake.clearMessageEditStateArgsForCall = append(fake.clearMessageEditStateArgsForCall, struct { @@ -1170,6 +1259,115 @@ func (fake *FakeStateManager) GetFileSelectionStateReturnsOnCall(i int, result1 }{result1} } +func (fake *FakeStateManager) GetLastClickCoordinates() (int, int) { + 
fake.getLastClickCoordinatesMutex.Lock() + ret, specificReturn := fake.getLastClickCoordinatesReturnsOnCall[len(fake.getLastClickCoordinatesArgsForCall)] + fake.getLastClickCoordinatesArgsForCall = append(fake.getLastClickCoordinatesArgsForCall, struct { + }{}) + stub := fake.GetLastClickCoordinatesStub + fakeReturns := fake.getLastClickCoordinatesReturns + fake.recordInvocation("GetLastClickCoordinates", []interface{}{}) + fake.getLastClickCoordinatesMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1, ret.result2 + } + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeStateManager) GetLastClickCoordinatesCallCount() int { + fake.getLastClickCoordinatesMutex.RLock() + defer fake.getLastClickCoordinatesMutex.RUnlock() + return len(fake.getLastClickCoordinatesArgsForCall) +} + +func (fake *FakeStateManager) GetLastClickCoordinatesCalls(stub func() (int, int)) { + fake.getLastClickCoordinatesMutex.Lock() + defer fake.getLastClickCoordinatesMutex.Unlock() + fake.GetLastClickCoordinatesStub = stub +} + +func (fake *FakeStateManager) GetLastClickCoordinatesReturns(result1 int, result2 int) { + fake.getLastClickCoordinatesMutex.Lock() + defer fake.getLastClickCoordinatesMutex.Unlock() + fake.GetLastClickCoordinatesStub = nil + fake.getLastClickCoordinatesReturns = struct { + result1 int + result2 int + }{result1, result2} +} + +func (fake *FakeStateManager) GetLastClickCoordinatesReturnsOnCall(i int, result1 int, result2 int) { + fake.getLastClickCoordinatesMutex.Lock() + defer fake.getLastClickCoordinatesMutex.Unlock() + fake.GetLastClickCoordinatesStub = nil + if fake.getLastClickCoordinatesReturnsOnCall == nil { + fake.getLastClickCoordinatesReturnsOnCall = make(map[int]struct { + result1 int + result2 int + }) + } + fake.getLastClickCoordinatesReturnsOnCall[i] = struct { + result1 int + result2 int + }{result1, result2} +} + +func (fake *FakeStateManager) GetLastFocusedApp() string { + 
fake.getLastFocusedAppMutex.Lock() + ret, specificReturn := fake.getLastFocusedAppReturnsOnCall[len(fake.getLastFocusedAppArgsForCall)] + fake.getLastFocusedAppArgsForCall = append(fake.getLastFocusedAppArgsForCall, struct { + }{}) + stub := fake.GetLastFocusedAppStub + fakeReturns := fake.getLastFocusedAppReturns + fake.recordInvocation("GetLastFocusedApp", []interface{}{}) + fake.getLastFocusedAppMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeStateManager) GetLastFocusedAppCallCount() int { + fake.getLastFocusedAppMutex.RLock() + defer fake.getLastFocusedAppMutex.RUnlock() + return len(fake.getLastFocusedAppArgsForCall) +} + +func (fake *FakeStateManager) GetLastFocusedAppCalls(stub func() string) { + fake.getLastFocusedAppMutex.Lock() + defer fake.getLastFocusedAppMutex.Unlock() + fake.GetLastFocusedAppStub = stub +} + +func (fake *FakeStateManager) GetLastFocusedAppReturns(result1 string) { + fake.getLastFocusedAppMutex.Lock() + defer fake.getLastFocusedAppMutex.Unlock() + fake.GetLastFocusedAppStub = nil + fake.getLastFocusedAppReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeStateManager) GetLastFocusedAppReturnsOnCall(i int, result1 string) { + fake.getLastFocusedAppMutex.Lock() + defer fake.getLastFocusedAppMutex.Unlock() + fake.GetLastFocusedAppStub = nil + if fake.getLastFocusedAppReturnsOnCall == nil { + fake.getLastFocusedAppReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.getLastFocusedAppReturnsOnCall[i] = struct { + result1 string + }{result1} +} + func (fake *FakeStateManager) GetMessageEditState() *domain.MessageEditState { fake.getMessageEditStateMutex.Lock() ret, specificReturn := fake.getMessageEditStateReturnsOnCall[len(fake.getMessageEditStateArgsForCall)] @@ -1823,6 +2021,71 @@ func (fake *FakeStateManager) SetFileSelectedIndexArgsForCall(i int) int { return argsForCall.arg1 } +func (fake 
*FakeStateManager) SetLastClickCoordinates(arg1 int, arg2 int) { + fake.setLastClickCoordinatesMutex.Lock() + fake.setLastClickCoordinatesArgsForCall = append(fake.setLastClickCoordinatesArgsForCall, struct { + arg1 int + arg2 int + }{arg1, arg2}) + stub := fake.SetLastClickCoordinatesStub + fake.recordInvocation("SetLastClickCoordinates", []interface{}{arg1, arg2}) + fake.setLastClickCoordinatesMutex.Unlock() + if stub != nil { + fake.SetLastClickCoordinatesStub(arg1, arg2) + } +} + +func (fake *FakeStateManager) SetLastClickCoordinatesCallCount() int { + fake.setLastClickCoordinatesMutex.RLock() + defer fake.setLastClickCoordinatesMutex.RUnlock() + return len(fake.setLastClickCoordinatesArgsForCall) +} + +func (fake *FakeStateManager) SetLastClickCoordinatesCalls(stub func(int, int)) { + fake.setLastClickCoordinatesMutex.Lock() + defer fake.setLastClickCoordinatesMutex.Unlock() + fake.SetLastClickCoordinatesStub = stub +} + +func (fake *FakeStateManager) SetLastClickCoordinatesArgsForCall(i int) (int, int) { + fake.setLastClickCoordinatesMutex.RLock() + defer fake.setLastClickCoordinatesMutex.RUnlock() + argsForCall := fake.setLastClickCoordinatesArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeStateManager) SetLastFocusedApp(arg1 string) { + fake.setLastFocusedAppMutex.Lock() + fake.setLastFocusedAppArgsForCall = append(fake.setLastFocusedAppArgsForCall, struct { + arg1 string + }{arg1}) + stub := fake.SetLastFocusedAppStub + fake.recordInvocation("SetLastFocusedApp", []interface{}{arg1}) + fake.setLastFocusedAppMutex.Unlock() + if stub != nil { + fake.SetLastFocusedAppStub(arg1) + } +} + +func (fake *FakeStateManager) SetLastFocusedAppCallCount() int { + fake.setLastFocusedAppMutex.RLock() + defer fake.setLastFocusedAppMutex.RUnlock() + return len(fake.setLastFocusedAppArgsForCall) +} + +func (fake *FakeStateManager) SetLastFocusedAppCalls(stub func(string)) { + fake.setLastFocusedAppMutex.Lock() + defer 
fake.setLastFocusedAppMutex.Unlock() + fake.SetLastFocusedAppStub = stub +} + +func (fake *FakeStateManager) SetLastFocusedAppArgsForCall(i int) string { + fake.setLastFocusedAppMutex.RLock() + defer fake.setLastFocusedAppMutex.RUnlock() + argsForCall := fake.setLastFocusedAppArgsForCall[i] + return argsForCall.arg1 +} + func (fake *FakeStateManager) SetMessageEditState(arg1 *domain.MessageEditState) { fake.setMessageEditStateMutex.Lock() fake.setMessageEditStateArgsForCall = append(fake.setMessageEditStateArgsForCall, struct { From ed38ee707cb058dcd20f7f627cc154c9d97d3f9a Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Thu, 8 Jan 2026 11:16:27 +0200 Subject: [PATCH 11/14] refactor: Improve UI layout and mouse click logic --- internal/display/macos/ComputerUse/main.swift | 4 ++-- internal/display/macos/client_darwin.go | 12 ++++++++---- internal/services/screenshot_server.go | 4 ---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/internal/display/macos/ComputerUse/main.swift b/internal/display/macos/ComputerUse/main.swift index 21537b9f..0a285d47 100644 --- a/internal/display/macos/ComputerUse/main.swift +++ b/internal/display/macos/ComputerUse/main.swift @@ -277,7 +277,7 @@ class FloatingWindow: NSPanel { init(position: String, alwaysOnTop: Bool) { let screenFrame = NSScreen.main!.visibleFrame let windowWidth: CGFloat = 450 - let windowHeight: CGFloat = 600 + let windowHeight: CGFloat = 350 var xPos: CGFloat switch position { @@ -706,7 +706,7 @@ class EventReader { if let toolName = tool["Name"] as? String { let args = tool["Arguments"] as? String ?? 
"" if !args.isEmpty { - window.appendText("\n▶ \(toolName): \(args)\n", color: blue) + window.appendText("\n▶ \(toolName)\n \(args)\n", color: blue) } else { window.appendText("\n▶ \(toolName)\n", color: blue) } diff --git a/internal/display/macos/client_darwin.go b/internal/display/macos/client_darwin.go index 38e738c3..db1604cc 100644 --- a/internal/display/macos/client_darwin.go +++ b/internal/display/macos/client_darwin.go @@ -248,10 +248,14 @@ func (c *MacOSClient) ClickMouse(button string, clicks int) error { return fmt.Errorf("invalid click count: %d (must be 1-3)", clicks) } - for i := range clicks { - if i > 0 { - time.Sleep(100 * time.Millisecond) - } + switch clicks { + case 1: + robotgo.Click(robotButton, false) + case 2: + robotgo.Click(robotButton, true) + case 3: + robotgo.Click(robotButton, true) + time.Sleep(100 * time.Millisecond) robotgo.Click(robotButton, false) } diff --git a/internal/services/screenshot_server.go b/internal/services/screenshot_server.go index fb17b4d1..de160ef4 100644 --- a/internal/services/screenshot_server.go +++ b/internal/services/screenshot_server.go @@ -248,10 +248,6 @@ func (s *ScreenshotServer) captureScreenshot() error { img = resizeImage(img, targetW, targetH) width = targetW height = targetH - - logger.Info("Screenshot force-resized to target dimensions", - "from", fmt.Sprintf("%dx%d", originalWidth, originalHeight), - "to", fmt.Sprintf("%dx%d", width, height)) } quality := s.cfg.ComputerUse.Screenshot.Quality From 157bc38de0770d5904af16c2666cba566b19a0eb Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Thu, 8 Jan 2026 16:41:56 +0200 Subject: [PATCH 12/14] refactor: Remove horizontal scrolling from ComputerUse dialog Signed-off-by: Eden Reich --- internal/display/macos/ComputerUse/main.swift | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/internal/display/macos/ComputerUse/main.swift b/internal/display/macos/ComputerUse/main.swift index 0a285d47..27df2d83 100644 --- 
a/internal/display/macos/ComputerUse/main.swift +++ b/internal/display/macos/ComputerUse/main.swift @@ -329,6 +329,10 @@ class FloatingWindow: NSPanel { self.orderFront(nil) } + deinit { + NotificationCenter.default.removeObserver(self, name: NSWindow.didResizeNotification, object: nil) + } + @objc func customMinimize() { if isMinimized { restoreWindow() @@ -412,7 +416,32 @@ class FloatingWindow: NSPanel { context.timingFunction = CAMediaTimingFunction(name: .easeInEaseOut) self.animator().setFrame(savedFrame, display: true) self.animator().alphaValue = 0.95 - }, completionHandler: nil) + }, completionHandler: { + self.updateTextContainerWidth() + }) + } + + func updateTextContainerWidth() { + guard !isMinimized else { return } + + DispatchQueue.main.async { + let visibleWidth = self.scrollView.contentView.bounds.width + + var newFrame = self.textView.frame + newFrame.size.width = visibleWidth + self.textView.frame = newFrame + + let textInset: CGFloat = 16 + let availableWidth = visibleWidth - (textInset * 2) + + self.textView.textContainer?.containerSize = NSSize( + width: availableWidth, + height: CGFloat.greatestFiniteMagnitude + ) + + self.textView.layoutManager?.ensureLayout(for: self.textView.textContainer!) 
+ self.textView.setNeedsDisplay(self.textView.bounds) + } } func updateMinimizedUI() { @@ -470,6 +499,7 @@ class FloatingWindow: NSPanel { scrollView.documentView = textView scrollView.hasVerticalScroller = true + scrollView.hasHorizontalScroller = false scrollView.autohidesScrollers = true scrollView.frame = contentView.bounds scrollView.autoresizingMask = [.width, .height] @@ -517,6 +547,16 @@ class FloatingWindow: NSPanel { contentView.addSubview(approvalBox) + NotificationCenter.default.addObserver( + forName: NSWindow.didResizeNotification, + object: self, + queue: .main + ) { [weak self] _ in + self?.updateTextContainerWidth() + } + + updateTextContainerWidth() + fputs("UI ready for output\n", stderr) fflush(stderr) } From 0ed136146ac055d2f03cefb6ccec887e30ed1f83 Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Thu, 8 Jan 2026 16:56:22 +0200 Subject: [PATCH 13/14] fix: Window title overlapping the conversation history Signed-off-by: Eden Reich --- internal/display/macos/ComputerUse/main.swift | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/internal/display/macos/ComputerUse/main.swift b/internal/display/macos/ComputerUse/main.swift index 27df2d83..8e0b4965 100644 --- a/internal/display/macos/ComputerUse/main.swift +++ b/internal/display/macos/ComputerUse/main.swift @@ -483,6 +483,9 @@ class FloatingWindow: NSPanel { func setupUI() { guard let contentView = self.contentView else { return } + let titleBarHeight = self.frame.height - contentView.frame.height + let topPadding = titleBarHeight > 0 ? 
titleBarHeight : 28 + textView.frame = contentView.bounds textView.autoresizingMask = [.width] textView.isEditable = false @@ -497,11 +500,15 @@ class FloatingWindow: NSPanel { textView.textContainer?.containerSize = NSSize(width: contentView.bounds.width, height: CGFloat.greatestFiniteMagnitude) textView.textContainer?.lineBreakMode = .byWordWrapping + var scrollFrame = contentView.bounds + scrollFrame.origin.y = 0 + scrollFrame.size.height = contentView.bounds.height - topPadding + scrollView.documentView = textView scrollView.hasVerticalScroller = true scrollView.hasHorizontalScroller = false scrollView.autohidesScrollers = true - scrollView.frame = contentView.bounds + scrollView.frame = scrollFrame scrollView.autoresizingMask = [.width, .height] contentView.addSubview(scrollView) From c69d5b42da925c1e2d8e42ecab8538568955a3c1 Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Thu, 8 Jan 2026 17:35:38 +0200 Subject: [PATCH 14/14] refactor: Add some space when approval actions appear Signed-off-by: Eden Reich --- internal/display/macos/ComputerUse/main.swift | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/internal/display/macos/ComputerUse/main.swift b/internal/display/macos/ComputerUse/main.swift index 8e0b4965..f12f2661 100644 --- a/internal/display/macos/ComputerUse/main.swift +++ b/internal/display/macos/ComputerUse/main.swift @@ -239,6 +239,11 @@ class ClickableView: NSView { var newOrigin = window.frame.origin newOrigin.y += deltaY + + let minY = screenFrame.minY + let maxY = screenFrame.maxY - window.frame.height + newOrigin.y = max(minY, min(maxY, newOrigin.y)) + newOrigin.x = screenFrame.maxX - window.frame.width window.setFrameOrigin(newOrigin) @@ -292,7 +297,7 @@ class FloatingWindow: NSPanel { let yPos = screenFrame.maxY - windowHeight - 20 let frame = NSRect(x: xPos, y: yPos, width: windowWidth, height: windowHeight) - super.init(contentRect: frame, styleMask: [.titled, .resizable, .miniaturizable, 
.fullSizeContentView], backing: .buffered, defer: false) + super.init(contentRect: frame, styleMask: [.titled, .resizable, .miniaturizable], backing: .buffered, defer: false) self.title = "Computer Use" self.isFloatingPanel = true @@ -313,7 +318,7 @@ class FloatingWindow: NSPanel { self.hasShadow = true self.invalidateShadow() - self.titlebarAppearsTransparent = true + self.titlebarAppearsTransparent = false self.titleVisibility = .visible self.isMovableByWindowBackground = true @@ -390,7 +395,7 @@ class FloatingWindow: NSPanel { isMinimized = false self.titleVisibility = .visible - self.titlebarAppearsTransparent = true + self.titlebarAppearsTransparent = false self.standardWindowButton(.closeButton)?.alphaValue = 0 self.standardWindowButton(.miniaturizeButton)?.alphaValue = 1.0 self.standardWindowButton(.zoomButton)?.alphaValue = 0 @@ -418,6 +423,7 @@ class FloatingWindow: NSPanel { self.animator().alphaValue = 0.95 }, completionHandler: { self.updateTextContainerWidth() + self.updateScrollViewInsets() }) } @@ -483,9 +489,6 @@ class FloatingWindow: NSPanel { func setupUI() { guard let contentView = self.contentView else { return } - let titleBarHeight = self.frame.height - contentView.frame.height - let topPadding = titleBarHeight > 0 ? 
titleBarHeight : 28 - textView.frame = contentView.bounds textView.autoresizingMask = [.width] textView.isEditable = false @@ -500,15 +503,11 @@ class FloatingWindow: NSPanel { textView.textContainer?.containerSize = NSSize(width: contentView.bounds.width, height: CGFloat.greatestFiniteMagnitude) textView.textContainer?.lineBreakMode = .byWordWrapping - var scrollFrame = contentView.bounds - scrollFrame.origin.y = 0 - scrollFrame.size.height = contentView.bounds.height - topPadding - scrollView.documentView = textView scrollView.hasVerticalScroller = true scrollView.hasHorizontalScroller = false scrollView.autohidesScrollers = true - scrollView.frame = scrollFrame + scrollView.frame = contentView.bounds scrollView.autoresizingMask = [.width, .height] contentView.addSubview(scrollView) @@ -590,6 +589,7 @@ class FloatingWindow: NSPanel { } approvalBox.isHidden = true currentCallID = nil + updateScrollViewInsets() } func appendText(_ text: String, color: NSColor? = nil) { @@ -608,6 +608,29 @@ class FloatingWindow: NSPanel { DispatchQueue.main.async { self.currentCallID = callID self.approvalBox.isHidden = false + self.updateScrollViewInsets() + } + } + + func updateScrollViewInsets() { + guard let contentView = self.contentView else { return } + + let buttonAreaHeight: CGFloat = 70 + + if approvalBox.isHidden { + scrollView.frame = contentView.bounds + } else { + let scrollFrame = NSRect( + x: 0, + y: buttonAreaHeight, + width: contentView.bounds.width, + height: contentView.bounds.height - buttonAreaHeight + ) + scrollView.frame = scrollFrame + + DispatchQueue.main.async { + self.textView.scrollToEndOfDocument(nil) + } } } @@ -780,6 +803,7 @@ class EventReader { case "Approval Cleared": DispatchQueue.main.async { self.window.approvalBox.isHidden = true + self.window.updateScrollViewInsets() } case "Border Show":