From 586b357cb63a5f42e939344950a20bfb54ce8b2f Mon Sep 17 00:00:00 2001 From: Eden Reich Date: Mon, 22 Dec 2025 17:39:16 +0200 Subject: [PATCH] feat: Implement a UI using the AG-UI standard --- .flox/env/manifest.lock | 128 + .flox/env/manifest.toml | 2 + .infer/agents.yaml | 2 +- .infer/config.yaml | 26 + Taskfile.yml | 2 +- cmd/agents.go | 2 +- cmd/chat.go | 47 +- cmd/root.go | 10 + cmd/serve.go | 331 + config/api.go | 28 + config/config.go | 22 +- go.mod | 1 + go.sum | 2 + internal/container/container.go | 36 +- internal/domain/chat_events.go | 8 +- internal/domain/interfaces.go | 1 + internal/handlers/api_handlers.go | 464 + internal/handlers/chat_command_handler.go | 16 +- internal/handlers/headless_handler.go | 824 ++ .../headless_handler_progress_test.go | 404 + internal/handlers/session_manager.go | 314 + internal/handlers/websocket_handler.go | 284 + internal/infra/storage/jsonl.go | 8 +- internal/services/agent.go | 68 +- internal/services/agent_manager.go | 6 + internal/services/conversation.go | 5 + internal/services/ui_manager.go | 186 + internal/utils/browser.go | 23 + .../domain/fake_conversation_repository.go | 63 + ui/.gitignore | 2 + ui/README.md | 289 + ui/app/globals.css | 124 + ui/app/layout.tsx | 22 + ui/app/page.tsx | 968 +++ ui/app/providers.tsx | 27 + ui/components.json | 22 + ui/components/status-bar.tsx | 81 + ui/components/theme-toggle.tsx | 24 + ui/components/tool-call-display.tsx | 198 + ui/components/ui/button.tsx | 57 + ui/components/ui/command.tsx | 153 + ui/components/ui/dialog.tsx | 122 + ui/components/ui/popover.tsx | 33 + ui/components/ui/select.tsx | 159 + ui/eslint.config.mjs | 14 + ui/lib/agui-types.ts | 73 + ui/lib/api/client.ts | 227 + ui/lib/chat/websocket-client.ts | 263 + ui/lib/storage/factory.ts | 20 + ui/lib/storage/hooks.ts | 206 + ui/lib/storage/http/http-storage.ts | 78 + ui/lib/storage/interfaces.ts | 192 + ui/lib/theme-provider.tsx | 51 + ui/lib/utils.ts | 6 + ui/next-env.d.ts | 6 + ui/next.config.ts | 12 + 
ui/package-lock.json | 7578 +++++++++++++++++ ui/package.json | 41 + ui/postcss.config.mjs | 8 + ui/tailwind.config.ts | 9 + ui/tsconfig.json | 41 + 61 files changed, 14347 insertions(+), 72 deletions(-) create mode 100644 cmd/serve.go create mode 100644 config/api.go create mode 100644 internal/handlers/api_handlers.go create mode 100644 internal/handlers/headless_handler.go create mode 100644 internal/handlers/headless_handler_progress_test.go create mode 100644 internal/handlers/session_manager.go create mode 100644 internal/handlers/websocket_handler.go create mode 100644 internal/services/ui_manager.go create mode 100644 internal/utils/browser.go create mode 100644 ui/.gitignore create mode 100644 ui/README.md create mode 100644 ui/app/globals.css create mode 100644 ui/app/layout.tsx create mode 100644 ui/app/page.tsx create mode 100644 ui/app/providers.tsx create mode 100644 ui/components.json create mode 100644 ui/components/status-bar.tsx create mode 100644 ui/components/theme-toggle.tsx create mode 100644 ui/components/tool-call-display.tsx create mode 100644 ui/components/ui/button.tsx create mode 100644 ui/components/ui/command.tsx create mode 100644 ui/components/ui/dialog.tsx create mode 100644 ui/components/ui/popover.tsx create mode 100644 ui/components/ui/select.tsx create mode 100644 ui/eslint.config.mjs create mode 100644 ui/lib/agui-types.ts create mode 100644 ui/lib/api/client.ts create mode 100644 ui/lib/chat/websocket-client.ts create mode 100644 ui/lib/storage/factory.ts create mode 100644 ui/lib/storage/hooks.ts create mode 100644 ui/lib/storage/http/http-storage.ts create mode 100644 ui/lib/storage/interfaces.ts create mode 100644 ui/lib/theme-provider.tsx create mode 100644 ui/lib/utils.ts create mode 100644 ui/next-env.d.ts create mode 100644 ui/next.config.ts create mode 100644 ui/package-lock.json create mode 100644 ui/package.json create mode 100644 ui/postcss.config.mjs create mode 100644 ui/tailwind.config.ts create mode 100644 
ui/tsconfig.json diff --git a/.flox/env/manifest.lock b/.flox/env/manifest.lock index 2a61fe55..9c0b15a5 100644 --- a/.flox/env/manifest.lock +++ b/.flox/env/manifest.lock @@ -43,6 +43,10 @@ "pkg-path": "nodejs_24", "version": "^24.11.1" }, + "nodejs_24": { + "pkg-path": "nodejs_24", + "version": "24.11.1" + }, "pre-commit": { "pkg-path": "pre-commit", "version": "^4.5.0" @@ -1241,6 +1245,130 @@ "group": "toplevel", "priority": 5 }, + { + "attr_path": "nodejs_24", + "broken": false, + "derivation": "/nix/store/vf9p8gxd4dydpbghi45a369ba9gwxvg4-nodejs-24.11.1.drv", + "description": "Event-driven I/O framework for the V8 JavaScript engine", + "install_id": "nodejs_24", + "license": "MIT", + "locked_url": "https://github.com/flox/nixpkgs?rev=1306659b587dc277866c7b69eb97e5f07864d8c4", + "name": "nodejs-24.11.1", + "pname": "nodejs_24", + "rev": "1306659b587dc277866c7b69eb97e5f07864d8c4", + "rev_count": 912002, + "rev_date": "2025-12-15T06:20:37Z", + "scrape_date": "2025-12-16T03:01:53.867688Z", + "stabilities": [ + "unstable" + ], + "unfree": false, + "version": "24.11.1", + "outputs_to_install": [ + "out" + ], + "outputs": { + "dev": "/nix/store/akmrc34ry4wil7wq8ghxbh2dpvikschh-nodejs-24.11.1-dev", + "libv8": "/nix/store/h4ad12n85c04b2c5c1a4r8r5qr4i1afn-nodejs-24.11.1-libv8", + "out": "/nix/store/h57ln4r40qr349z1r8ri90n9i9xid3r5-nodejs-24.11.1" + }, + "system": "aarch64-darwin", + "group": "toplevel", + "priority": 5 + }, + { + "attr_path": "nodejs_24", + "broken": false, + "derivation": "/nix/store/8njzr7lnj4w452wh84bilva49bwgz9lw-nodejs-24.11.1.drv", + "description": "Event-driven I/O framework for the V8 JavaScript engine", + "install_id": "nodejs_24", + "license": "MIT", + "locked_url": "https://github.com/flox/nixpkgs?rev=1306659b587dc277866c7b69eb97e5f07864d8c4", + "name": "nodejs-24.11.1", + "pname": "nodejs_24", + "rev": "1306659b587dc277866c7b69eb97e5f07864d8c4", + "rev_count": 912002, + "rev_date": "2025-12-15T06:20:37Z", + "scrape_date": 
"2025-12-16T03:18:11.985214Z", + "stabilities": [ + "unstable" + ], + "unfree": false, + "version": "24.11.1", + "outputs_to_install": [ + "out" + ], + "outputs": { + "dev": "/nix/store/82gbs5sgm1gpggsalpdlnxjkvr7bnkcc-nodejs-24.11.1-dev", + "libv8": "/nix/store/7ihzshxrnl1xh935ff8va57aw1avwk5j-nodejs-24.11.1-libv8", + "out": "/nix/store/r46dnkahacvrnavv5945aizq8b61kbwc-nodejs-24.11.1" + }, + "system": "aarch64-linux", + "group": "toplevel", + "priority": 5 + }, + { + "attr_path": "nodejs_24", + "broken": false, + "derivation": "/nix/store/krwzsq16irazaldn5ygdkx7zfpbybxba-nodejs-24.11.1.drv", + "description": "Event-driven I/O framework for the V8 JavaScript engine", + "install_id": "nodejs_24", + "license": "MIT", + "locked_url": "https://github.com/flox/nixpkgs?rev=1306659b587dc277866c7b69eb97e5f07864d8c4", + "name": "nodejs-24.11.1", + "pname": "nodejs_24", + "rev": "1306659b587dc277866c7b69eb97e5f07864d8c4", + "rev_count": 912002, + "rev_date": "2025-12-15T06:20:37Z", + "scrape_date": "2025-12-16T03:34:11.125879Z", + "stabilities": [ + "unstable" + ], + "unfree": false, + "version": "24.11.1", + "outputs_to_install": [ + "out" + ], + "outputs": { + "dev": "/nix/store/xj8y0v2v5ssdxbpz1bz9yivn109g3i4x-nodejs-24.11.1-dev", + "libv8": "/nix/store/136306lidz9kx8dlka3l2sf1q9i6ilc3-nodejs-24.11.1-libv8", + "out": "/nix/store/68d27bbipwcv99inwbs6fb7vxbgd4q6j-nodejs-24.11.1" + }, + "system": "x86_64-darwin", + "group": "toplevel", + "priority": 5 + }, + { + "attr_path": "nodejs_24", + "broken": false, + "derivation": "/nix/store/n15i8fv8a00hy7ng11km2lr8wvgz552w-nodejs-24.11.1.drv", + "description": "Event-driven I/O framework for the V8 JavaScript engine", + "install_id": "nodejs_24", + "license": "MIT", + "locked_url": "https://github.com/flox/nixpkgs?rev=1306659b587dc277866c7b69eb97e5f07864d8c4", + "name": "nodejs-24.11.1", + "pname": "nodejs_24", + "rev": "1306659b587dc277866c7b69eb97e5f07864d8c4", + "rev_count": 912002, + "rev_date": "2025-12-15T06:20:37Z", + 
"scrape_date": "2025-12-16T03:50:36.781636Z", + "stabilities": [ + "unstable" + ], + "unfree": false, + "version": "24.11.1", + "outputs_to_install": [ + "out" + ], + "outputs": { + "dev": "/nix/store/qayyvn6vq1yx49lr30kkrpdl1jy99752-nodejs-24.11.1-dev", + "libv8": "/nix/store/i9s2w9xib1fbqrc2qmmyz17frhpvkmsk-nodejs-24.11.1-libv8", + "out": "/nix/store/lgggxsrdzisnbligi7irlh4qmqczs0xk-nodejs-24.11.1" + }, + "system": "x86_64-linux", + "group": "toplevel", + "priority": 5 + }, { "attr_path": "pre-commit", "broken": false, diff --git a/.flox/env/manifest.toml b/.flox/env/manifest.toml index 26a7303c..f634057b 100644 --- a/.flox/env/manifest.toml +++ b/.flox/env/manifest.toml @@ -34,6 +34,8 @@ docker.pkg-path = "docker" docker.version = "^29.1.2" docker-compose.pkg-path = "docker-compose" docker-compose.version = "^5.0.0" +nodejs_24.pkg-path = "nodejs_24" +nodejs_24.version = "24.11.1" [hook] on-activate = """ diff --git a/.infer/agents.yaml b/.infer/agents.yaml index a23ff1be..e2694bff 100644 --- a/.infer/agents.yaml +++ b/.infer/agents.yaml @@ -1,6 +1,6 @@ agents: - name: mock-agent - url: http://localhost:8081 + url: http://localhost:8082 oci: ghcr.io/inference-gateway/mock-agent:latest run: true model: deepseek/deepseek-chat diff --git a/.infer/config.yaml b/.infer/config.yaml index 787ff16b..dfaf3c32 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -596,6 +596,32 @@ mcp: liveness_probe_interval: 10 max_retries: 10 servers: [] +api: + host: 127.0.0.1 + port: 8081 + read_timeout: 30 + write_timeout: 30 + idle_timeout: 120 + cors: + enabled: true + allowed_origins: + - http://localhost:3000 + - http://localhost:3001 + allowed_methods: + - GET + - POST + - PUT + - PATCH + - DELETE + - OPTIONS + allowed_headers: + - Content-Type + - Authorization + ui: + port: 3000 + auto_open: true + mode: npm + working_dir: ./ui pricing: enabled: true currency: USD diff --git a/Taskfile.yml b/Taskfile.yml index 5f4efa76..9121fe8b 100644 --- a/Taskfile.yml +++ 
b/Taskfile.yml @@ -63,7 +63,7 @@ tasks: desc: Run linter (requires golangci-lint) cmds: - golangci-lint run - - markdownlint . --ignore CHANGELOG.md --fix + - markdownlint . --ignore CHANGELOG.md --ignore ui/node_modules --fix fmt: desc: Format Go code diff --git a/cmd/agents.go b/cmd/agents.go index ecf2e0d3..61b2ef9e 100644 --- a/cmd/agents.go +++ b/cmd/agents.go @@ -425,7 +425,7 @@ func listAgents(cmd *cobra.Command, args []string) error { format, _ := cmd.Flags().GetString("format") if format == "json" { - combinedOutput := map[string]interface{}{ + combinedOutput := map[string]any{ "local": localAgents, "external": externalAgents, "total": totalAgents, diff --git a/cmd/chat.go b/cmd/chat.go index 93cc7745..f0a7acc4 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -15,6 +15,8 @@ import ( clipboard "github.com/inference-gateway/cli/internal/clipboard" container "github.com/inference-gateway/cli/internal/container" domain "github.com/inference-gateway/cli/internal/domain" + handlers "github.com/inference-gateway/cli/internal/handlers" + logger "github.com/inference-gateway/cli/internal/logger" sdk "github.com/inference-gateway/sdk" cobra "github.com/spf13/cobra" viper "github.com/spf13/viper" @@ -25,12 +27,20 @@ var chatCmd = &cobra.Command{ Short: "Start an interactive chat session with model selection", Long: `Start an interactive chat session where you can select a model from a dropdown and have a conversational interface with the inference gateway.`, - RunE: func(_ *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) error { cfg, err := getConfigFromViper() if err != nil { return fmt.Errorf("failed to load config: %w", err) } + headless, _ := cmd.Flags().GetBool("headless") + sessionID, _ := cmd.Flags().GetString("session-id") + conversationID, _ := cmd.Flags().GetString("conversation-id") + + if headless { + return runHeadlessChat(cfg, V, sessionID, conversationID) + } + if !isInteractiveTerminal() { return 
runNonInteractiveChat(cfg, V) } @@ -56,6 +66,23 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { fmt.Printf(" Make sure the inference gateway is running at: %s\n\n", cfg.Gateway.URL) } + agentManager := services.GetAgentManager() + if agentManager != nil { + agentCtx := context.Background() + if err := agentManager.StartAgents(agentCtx); err != nil { + logger.Error("Failed to start agents", "error", err) + } + } + + mcpManager := services.GetMCPManager() + if mcpManager != nil { + mcpCtx, mcpCancel := context.WithTimeout(context.Background(), 120*time.Second) + defer mcpCancel() + if err := mcpManager.StartServers(mcpCtx); err != nil { + logger.Error("Some MCP servers failed to start", "error", err) + } + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(cfg.Gateway.Timeout)*time.Second) defer cancel() @@ -86,10 +113,8 @@ func StartChatSession(cfg *config.Config, v *viper.Viper) error { messageQueue := services.GetMessageQueue() themeService := services.GetThemeService() toolRegistry := services.GetToolRegistry() - mcpManager := services.GetMCPManager() taskRetentionService := services.GetTaskRetentionService() backgroundTaskService := services.GetBackgroundTaskService() - agentManager := services.GetAgentManager() conversationOptimizer := services.GetConversationOptimizer() versionInfo := GetVersionInfo() @@ -300,6 +325,22 @@ func processStreamingOutput(events <-chan domain.ChatEvent) error { return nil } +// runHeadlessChat runs the chat in headless mode (JSON I/O via stdin/stdout) +func runHeadlessChat(cfg *config.Config, v *viper.Viper, sessionID string, conversationID string) error { + services := container.NewServiceContainer(cfg, v) + + handler := handlers.NewHeadlessHandler(sessionID, conversationID, services, cfg) + defer func() { + _ = handler.Shutdown() + }() + + return handler.Start() +} + func init() { rootCmd.AddCommand(chatCmd) + + chatCmd.Flags().Bool("headless", false, "Run in headless mode (JSON I/O via 
stdin/stdout)") + chatCmd.Flags().String("session-id", "", "Session identifier for headless mode") + chatCmd.Flags().String("conversation-id", "", "Conversation ID to continue (for headless mode)") } diff --git a/cmd/root.go b/cmd/root.go index 050b3130..4d16e685 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -120,6 +120,16 @@ func initConfig() { // nolint:funlen v.SetDefault("tools.web_fetch.safety.timeout", defaults.Tools.WebFetch.Safety.Timeout) v.SetDefault("tools.web_fetch.safety.allow_redirect", defaults.Tools.WebFetch.Safety.AllowRedirect) v.SetDefault("tools.web_fetch.require_approval", defaults.Tools.WebFetch.RequireApproval) + v.SetDefault("api", defaults.API) + v.SetDefault("api.host", defaults.API.Host) + v.SetDefault("api.port", defaults.API.Port) + v.SetDefault("api.read_timeout", defaults.API.ReadTimeout) + v.SetDefault("api.write_timeout", defaults.API.WriteTimeout) + v.SetDefault("api.idle_timeout", defaults.API.IdleTimeout) + v.SetDefault("api.cors.enabled", defaults.API.CORS.Enabled) + v.SetDefault("api.cors.allowed_origins", defaults.API.CORS.AllowedOrigins) + v.SetDefault("api.cors.allowed_methods", defaults.API.CORS.AllowedMethods) + v.SetDefault("api.cors.allowed_headers", defaults.API.CORS.AllowedHeaders) v.SetDefault("pricing", defaults.Pricing) v.SetDefault("pricing.enabled", defaults.Pricing.Enabled) v.SetDefault("pricing.currency", defaults.Pricing.Currency) diff --git a/cmd/serve.go b/cmd/serve.go new file mode 100644 index 00000000..34c98b74 --- /dev/null +++ b/cmd/serve.go @@ -0,0 +1,331 @@ +package cmd + +import ( + "context" + "fmt" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + config "github.com/inference-gateway/cli/config" + container "github.com/inference-gateway/cli/internal/container" + handlers "github.com/inference-gateway/cli/internal/handlers" + storage "github.com/inference-gateway/cli/internal/infra/storage" + logger "github.com/inference-gateway/cli/internal/logger" + svc 
"github.com/inference-gateway/cli/internal/services" + history "github.com/inference-gateway/cli/internal/ui/history" + utils "github.com/inference-gateway/cli/internal/utils" + cobra "github.com/spf13/cobra" +) + +var serveCmd = &cobra.Command{ + Use: "serve", + Short: "Start the API server for conversation storage queries", + Long: `Start an HTTP API server that exposes REST endpoints for accessing conversation storage. +This allows the UI and other clients to query conversations, statistics, and manage stored data +without requiring direct database access. + +The API server provides: + - Conversation listing and querying + - Conversation statistics and analytics + - Metadata management + - Health checks + +All headless chat sessions continue to save their conversations directly to the configured +storage backend (JSONL, SQLite, PostgreSQL, or Redis).`, + RunE: func(cmd *cobra.Command, args []string) error { + cfg, err := getConfigFromViper() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + port, _ := cmd.Flags().GetInt("port") + host, _ := cmd.Flags().GetString("host") + enableUI, _ := cmd.Flags().GetBool("ui") + + if port != 0 { + cfg.API.Port = port + } + if host != "" { + cfg.API.Host = host + } + + return startAPIServer(cfg, V, enableUI) + }, +} + +func init() { + rootCmd.AddCommand(serveCmd) + + serveCmd.Flags().Int("port", 0, "API server port (default: 8080)") + serveCmd.Flags().String("host", "", "API server host (default: 127.0.0.1)") + serveCmd.Flags().Bool("ui", false, "Start the web UI and automatically open it in browser") +} + +func startAPIServer(cfg *config.Config, v any, enableUI bool) error { + services, uiManager := initializeServices(cfg, enableUI) + defer cleanupServices(services, uiManager) + + if err := checkStorageHealth(services.GetStorage()); err != nil { + logger.Warn("Storage health check failed", "error", err) + fmt.Printf("Warning: Storage backend may not be available: %v\n", err) + } + + server, 
serverReady, serverErrors := setupAndStartServer(cfg, services) + handleUIStartup(cfg, uiManager, serverReady, enableUI) + + return waitForShutdown(server, serverErrors) +} + +func initializeServices(cfg *config.Config, enableUI bool) (*container.ServiceContainer, *svc.UIManager) { + services := container.NewServiceContainer(cfg, V) + + var uiManager *svc.UIManager + if enableUI { + uiManager = svc.NewUIManager(cfg) + } + + if agentManager := services.GetAgentManager(); agentManager != nil { + ctx := context.Background() + if err := agentManager.StartAgents(ctx); err != nil { + logger.Error("Failed to start agents", "error", err) + } + } + + if mcpManager := services.GetMCPManager(); mcpManager != nil { + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + if err := mcpManager.StartServers(ctx); err != nil { + logger.Error("Some MCP servers failed to start", "error", err) + } + } + + return services, uiManager +} + +func cleanupServices(services *container.ServiceContainer, uiManager *svc.UIManager) { + if uiManager != nil && uiManager.IsRunning() { + _ = uiManager.Stop() + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = services.Shutdown(ctx) +} + +func checkStorageHealth(strg storage.ConversationStorage) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return strg.Health(ctx) +} + +func setupAndStartServer(cfg *config.Config, services *container.ServiceContainer) (*http.Server, chan struct{}, chan error) { + historyManager, err := history.NewHistoryManager(1000) + if err != nil { + logger.Warn("Failed to initialize history manager", "error", err) + historyManager = nil + } + + apiHandler := handlers.NewAPIHandler( + services.GetStorage(), + services.GetModelService(), + services.GetStateManager(), + services.GetMCPManager(), + historyManager, + ) + + sessionManager := handlers.NewSessionManager() + wsHandler := 
handlers.NewWebSocketHandler(sessionManager) + + handler := setupHTTPHandlers(cfg, apiHandler, wsHandler) + addr := fmt.Sprintf("%s:%d", cfg.API.Host, cfg.API.Port) + server := &http.Server{ + Addr: addr, + Handler: handler, + ReadTimeout: time.Duration(cfg.API.ReadTimeout) * time.Second, + WriteTimeout: time.Duration(cfg.API.WriteTimeout) * time.Second, + IdleTimeout: time.Duration(cfg.API.IdleTimeout) * time.Second, + } + + serverErrors := make(chan error, 1) + serverReady := make(chan struct{}) + + go startServerAsync(server, addr, cfg, serverErrors, serverReady) + + return server, serverReady, serverErrors +} + +func setupHTTPHandlers(cfg *config.Config, apiHandler *handlers.APIHandler, wsHandler *handlers.WebSocketHandler) http.Handler { + mux := http.NewServeMux() + + mux.HandleFunc("/health", apiHandler.HandleHealth) + mux.HandleFunc("/ws", wsHandler.HandleWebSocket) + mux.HandleFunc("/api/v1/models", apiHandler.HandleListModels) + mux.HandleFunc("/api/v1/conversations", apiHandler.HandleListConversations) + mux.HandleFunc("/api/v1/conversations/", apiHandler.HandleConversationByID) + mux.HandleFunc("/api/v1/conversations/needs-titles", apiHandler.HandleConversationsNeedingTitles) + mux.HandleFunc("/api/v1/agents/status", apiHandler.HandleAgentsStatus) + mux.HandleFunc("/api/v1/mcp/status", apiHandler.HandleMCPStatus) + mux.HandleFunc("/api/history", apiHandler.HandleHistory) + + handler := http.Handler(mux) + if cfg.API.CORS.Enabled { + handler = enableCORS(handler, cfg.API.CORS) + } + + return handler +} + +func startServerAsync(server *http.Server, addr string, cfg *config.Config, serverErrors chan error, serverReady chan struct{}) { + logger.Info("Starting API server", "address", addr) + + go func() { + serverErrors <- server.ListenAndServe() + }() + + if waitForServerReady(addr) { + printServerInfo(addr, cfg) + close(serverReady) + } +} + +func waitForServerReady(addr string) bool { + apiURL := fmt.Sprintf("http://%s/health", addr) + for i := 0; i < 20; 
i++ { + time.Sleep(100 * time.Millisecond) + resp, err := http.Get(apiURL) + if err == nil { + if closeErr := resp.Body.Close(); closeErr != nil { + logger.Warn("Failed to close response body", "error", closeErr) + } + if resp.StatusCode == http.StatusOK { + return true + } + } + } + return false +} + +func printServerInfo(addr string, cfg *config.Config) { + fmt.Printf("API server listening on http://%s\n", addr) + fmt.Printf(" Storage: %s\n", cfg.Storage.Type) + fmt.Printf("\nAvailable endpoints:\n") + fmt.Printf(" GET /health - Health check\n") + fmt.Printf(" WS /ws - WebSocket for live chat\n") + fmt.Printf(" GET /api/v1/models - List available models\n") + fmt.Printf(" GET /api/v1/conversations - List conversations\n") + fmt.Printf(" GET /api/v1/conversations/:id - Get conversation\n") + fmt.Printf(" DELETE /api/v1/conversations/:id - Delete conversation\n") + fmt.Printf(" PATCH /api/v1/conversations/:id/metadata - Update metadata\n") + fmt.Printf(" GET /api/v1/conversations/needs-titles - List conversations needing titles\n") + fmt.Printf(" GET /api/v1/agents/status - Get A2A agents status\n") + fmt.Printf(" GET /api/v1/mcp/status - Get MCP servers status\n") + fmt.Printf(" GET /api/history - Get command history\n") + fmt.Printf(" POST /api/history - Save command to history\n\n") +} + +func handleUIStartup(cfg *config.Config, uiManager *svc.UIManager, serverReady chan struct{}, enableUI bool) { + if !enableUI || uiManager == nil { + return + } + + go func() { + <-serverReady + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + if err := uiManager.Start(ctx); err != nil { + logger.Error("Failed to start UI server", "error", err) + fmt.Printf("Warning: UI server failed to start: %v\n", err) + return + } + + if cfg.API.UI.AutoOpen { + openBrowserForUI(uiManager) + } + }() +} + +func openBrowserForUI(uiManager *svc.UIManager) { + uiURL := uiManager.GetURL() + logger.Info("Opening browser", "url", uiURL) + if err := 
utils.OpenBrowser(uiURL); err != nil { + logger.Warn("Failed to open browser", "error", err) + fmt.Printf("Could not open browser automatically. Please visit %s\n", uiURL) + } else { + fmt.Printf("Browser opened at %s\n", uiURL) + } +} + +func waitForShutdown(server *http.Server, serverErrors chan error) error { + shutdown := make(chan os.Signal, 1) + signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) + + select { + case err := <-serverErrors: + return fmt.Errorf("server error: %w", err) + case sig := <-shutdown: + logger.Info("Shutting down API server", "signal", sig) + fmt.Printf("\nShutting down gracefully...\n") + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := server.Shutdown(ctx); err != nil { + if closeErr := server.Close(); closeErr != nil { + logger.Warn("Failed to force close server", "error", closeErr) + } + return fmt.Errorf("could not stop server gracefully: %w", err) + } + } + + return nil +} + +// enableCORS wraps the handler with CORS middleware +func enableCORS(next http.Handler, corsConfig config.CORSConfig) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + + allowed := false + for _, allowedOrigin := range corsConfig.AllowedOrigins { + if allowedOrigin == "*" || allowedOrigin == origin { + allowed = true + break + } + } + + if allowed { + if origin != "" { + w.Header().Set("Access-Control-Allow-Origin", origin) + } else if len(corsConfig.AllowedOrigins) > 0 { + w.Header().Set("Access-Control-Allow-Origin", corsConfig.AllowedOrigins[0]) + } + w.Header().Set("Access-Control-Allow-Methods", joinStrings(corsConfig.AllowedMethods, ", ")) + w.Header().Set("Access-Control-Allow-Headers", joinStrings(corsConfig.AllowedHeaders, ", ")) + w.Header().Set("Access-Control-Max-Age", "86400") // 24 hours + } + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +} + +// 
joinStrings joins a slice of strings with a separator +func joinStrings(strs []string, sep string) string { + if len(strs) == 0 { + return "" + } + result := strs[0] + for _, s := range strs[1:] { + result += sep + s + } + return result +} diff --git a/config/api.go b/config/api.go new file mode 100644 index 00000000..afec83a4 --- /dev/null +++ b/config/api.go @@ -0,0 +1,28 @@ +package config + +// APIConfig contains HTTP API server settings +type APIConfig struct { + Host string `yaml:"host" mapstructure:"host"` + Port int `yaml:"port" mapstructure:"port"` + ReadTimeout int `yaml:"read_timeout" mapstructure:"read_timeout"` + WriteTimeout int `yaml:"write_timeout" mapstructure:"write_timeout"` + IdleTimeout int `yaml:"idle_timeout" mapstructure:"idle_timeout"` + CORS CORSConfig `yaml:"cors" mapstructure:"cors"` + UI UIConfig `yaml:"ui" mapstructure:"ui"` +} + +// CORSConfig contains CORS (Cross-Origin Resource Sharing) settings +type CORSConfig struct { + Enabled bool `yaml:"enabled" mapstructure:"enabled"` + AllowedOrigins []string `yaml:"allowed_origins" mapstructure:"allowed_origins"` + AllowedMethods []string `yaml:"allowed_methods" mapstructure:"allowed_methods"` + AllowedHeaders []string `yaml:"allowed_headers" mapstructure:"allowed_headers"` +} + +// UIConfig contains web UI settings +type UIConfig struct { + Port int `yaml:"port" mapstructure:"port"` + AutoOpen bool `yaml:"auto_open" mapstructure:"auto_open"` + Mode string `yaml:"mode" mapstructure:"mode"` + WorkingDir string `yaml:"working_dir" mapstructure:"working_dir"` +} diff --git a/config/config.go b/config/config.go index cbda45b6..bbaaf5d1 100644 --- a/config/config.go +++ b/config/config.go @@ -37,6 +37,7 @@ type Config struct { Chat ChatConfig `yaml:"chat" mapstructure:"chat"` A2A A2AConfig `yaml:"a2a" mapstructure:"a2a"` MCP MCPConfig `yaml:"mcp" mapstructure:"mcp"` + API APIConfig `yaml:"api" mapstructure:"api"` Pricing PricingConfig `yaml:"pricing" mapstructure:"pricing"` Init InitConfig 
`yaml:"init" mapstructure:"init"` Compact CompactConfig `yaml:"compact" mapstructure:"compact"` @@ -887,7 +888,26 @@ Respond with ONLY the title, no quotes or explanation.`, }, }, }, - MCP: *DefaultMCPConfig(), + MCP: *DefaultMCPConfig(), + API: APIConfig{ + Host: "127.0.0.1", + Port: 8081, + ReadTimeout: 30, + WriteTimeout: 30, + IdleTimeout: 120, + CORS: CORSConfig{ + Enabled: true, + AllowedOrigins: []string{"http://localhost:3000", "http://localhost:3001"}, + AllowedMethods: []string{"GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"}, + AllowedHeaders: []string{"Content-Type", "Authorization"}, + }, + UI: UIConfig{ + Port: 3000, + AutoOpen: true, + Mode: "npm", + WorkingDir: "./ui", + }, + }, Pricing: GetDefaultPricingConfig(), Init: InitConfig{ Prompt: `Please analyze this project and generate a comprehensive AGENTS.md file. Start by using the Tree tool to understand the project structure. diff --git a/go.mod b/go.mod index 6fbeeed1..7173dafd 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/charmbracelet/lipgloss v1.1.1-0.20250404203927-76690c660834 github.com/go-redis/redis/v8 v8.11.5 github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 github.com/inference-gateway/adk v0.16.2 github.com/inference-gateway/sdk v1.14.1 github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 diff --git a/go.sum b/go.sum index 81f4d104..32f0ca5a 100644 --- a/go.sum +++ b/go.sum @@ -92,6 +92,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.1 h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hexops/gotextdiff v1.0.3 
h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/internal/container/container.go b/internal/container/container.go index 05437f38..c9081233 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -125,7 +125,8 @@ func (c *ServiceContainer) initializeGatewayManager() { c.gatewayManager = services.NewGatewayManager(c.sessionID, c.config, c.containerRuntime) } -// initializeAgentManager creates and starts the agent manager if A2A is enabled +// initializeAgentManager creates the agent manager if A2A is enabled +// Note: This does NOT start agents. Caller must explicitly call agentManager.StartAgents(ctx). func (c *ServiceContainer) initializeAgentManager() { agentsPath := filepath.Join(config.ConfigDirName, config.AgentsFileName) c.agentsConfigService = services.NewAgentsConfigService(agentsPath) @@ -160,11 +161,6 @@ func (c *ServiceContainer) initializeAgentManager() { c.agentManager.SetStatusCallback(func(agentName string, state domain.AgentState, message string, url string, image string) { c.stateManager.UpdateAgentStatus(agentName, state, message, url, image) }) - - ctx := context.Background() - if err := c.agentManager.StartAgents(ctx); err != nil { - logger.Warn("Failed to start agents in background", "error", err) - } } // initializeFileWriterServices creates the new file writer architecture services @@ -176,38 +172,14 @@ func (c *ServiceContainer) initializeFileWriterServices() { c.paramExtractor = tools.NewParameterExtractor() } -// initializeMCPManager creates and starts MCP manager if enabled +// initializeMCPManager creates the MCP manager if enabled +// Note: This does NOT start MCP servers. Caller must explicitly call mcpManager.StartServers(ctx). 
func (c *ServiceContainer) initializeMCPManager() { if !c.config.MCP.Enabled { return } c.mcpManager = services.NewMCPManager(c.sessionID, &c.config.MCP, c.containerRuntime) - - hasServersToStart := c.hasAutoStartMCPServers() - if !hasServersToStart { - return - } - - logger.Info("Starting MCP servers in background...") - go func() { - ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) - defer cancel() - - if err := c.mcpManager.StartServers(ctx); err != nil { - logger.Warn("Some MCP servers failed to start", "error", err) - } - }() -} - -// hasAutoStartMCPServers checks if any MCP servers are configured for auto-start -func (c *ServiceContainer) hasAutoStartMCPServers() bool { - for _, server := range c.config.MCP.Servers { - if server.Run && server.Enabled { - return true - } - } - return false } // initializeDomainServices creates and wires domain service implementations diff --git a/internal/domain/chat_events.go b/internal/domain/chat_events.go index cca9effe..5d5e18bc 100644 --- a/internal/domain/chat_events.go +++ b/internal/domain/chat_events.go @@ -4,9 +4,10 @@ import "time" // ToolInfo represents basic tool information for UI display type ToolInfo struct { - CallID string - Name string - Status string + CallID string + Name string + Status string + Arguments string } // BaseChatEvent provides common implementation for ChatEvent interface @@ -31,6 +32,7 @@ type ToolExecutionProgressEvent struct { ToolName string Status string Message string + Result string } // BashOutputChunkEvent indicates a new chunk of bash output is available diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 86c162fe..80cd5087 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -102,6 +102,7 @@ type ConversationRepository interface { FormatToolResultExpanded(result *ToolExecutionResult, terminalWidth int) string RemovePendingToolCallByID(toolCallID string) StartNewConversation(title string) error + 
GetCurrentConversationID() string DeleteMessagesAfterIndex(index int) error } diff --git a/internal/handlers/api_handlers.go b/internal/handlers/api_handlers.go new file mode 100644 index 00000000..98fb7e4a --- /dev/null +++ b/internal/handlers/api_handlers.go @@ -0,0 +1,464 @@ +package handlers + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + domain "github.com/inference-gateway/cli/internal/domain" + storage "github.com/inference-gateway/cli/internal/infra/storage" + logger "github.com/inference-gateway/cli/internal/logger" + history "github.com/inference-gateway/cli/internal/ui/history" +) + +// APIHandler handles HTTP API requests for conversation storage +type APIHandler struct { + storage storage.ConversationStorage + modelService domain.ModelService + stateManager domain.StateManager + mcpManager domain.MCPManager + historyManager *history.HistoryManager +} + +// NewAPIHandler creates a new API handler +func NewAPIHandler( + storage storage.ConversationStorage, + modelService domain.ModelService, + stateManager domain.StateManager, + mcpManager domain.MCPManager, + historyManager *history.HistoryManager, +) *APIHandler { + return &APIHandler{ + storage: storage, + modelService: modelService, + stateManager: stateManager, + mcpManager: mcpManager, + historyManager: historyManager, + } +} + +// writeJSON writes a JSON response and logs errors +func (h *APIHandler) writeJSON(w http.ResponseWriter, statusCode int, data any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + if err := json.NewEncoder(w).Encode(data); err != nil { + logger.Error("Failed to encode JSON response", "error", err) + } +} + +// HandleHealth handles health check requests +func (h *APIHandler) HandleHealth(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + ctx, cancel := 
context.WithTimeout(r.Context(), 5*time.Second) + defer cancel() + + err := h.storage.Health(ctx) + if err != nil { + logger.Error("Storage health check failed", "error", err) + h.writeJSON(w, http.StatusServiceUnavailable, map[string]any{ + "status": "unhealthy", + "error": err.Error(), + "time": time.Now().UTC().Format(time.RFC3339), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "status": "healthy", + "time": time.Now().UTC().Format(time.RFC3339), + }) +} + +// HandleListConversations handles GET /api/v1/conversations +func (h *APIHandler) HandleListConversations(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + limit := 50 + offset := 0 + + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + limit = l + } + } + + if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" { + if o, err := strconv.Atoi(offsetStr); err == nil && o >= 0 { + offset = o + } + } + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + conversations, err := h.storage.ListConversations(ctx, limit, offset) + if err != nil { + logger.Error("Failed to list conversations", "error", err) + h.writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": fmt.Sprintf("Failed to list conversations: %v", err), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "conversations": conversations, + "count": len(conversations), + "limit": limit, + "offset": offset, + }) +} + +// HandleConversationByID handles conversation-specific operations +// GET /api/v1/conversations/:id - Get conversation +// DELETE /api/v1/conversations/:id - Delete conversation +// PATCH /api/v1/conversations/:id/metadata - Update metadata +func (h *APIHandler) HandleConversationByID(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, 
"/api/v1/conversations/") + if path == "" { + http.Error(w, "Conversation ID required", http.StatusBadRequest) + return + } + + parts := strings.Split(path, "/") + conversationID := parts[0] + + if len(parts) > 1 && parts[1] == "metadata" { + if r.Method == http.MethodPatch { + h.handleUpdateMetadata(w, r, conversationID) + return + } + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + switch r.Method { + case http.MethodGet: + h.handleGetConversation(w, r, conversationID) + case http.MethodDelete: + h.handleDeleteConversation(w, r, conversationID) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleGetConversation retrieves a specific conversation +func (h *APIHandler) handleGetConversation(w http.ResponseWriter, r *http.Request, conversationID string) { + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + entries, metadata, err := h.storage.LoadConversation(ctx, conversationID) + if err != nil { + logger.Error("Failed to load conversation", "id", conversationID, "error", err) + h.writeJSON(w, http.StatusNotFound, map[string]any{ + "error": fmt.Sprintf("Conversation not found: %v", err), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "id": conversationID, + "entries": entries, + "metadata": metadata, + }) +} + +// handleDeleteConversation deletes a conversation +func (h *APIHandler) handleDeleteConversation(w http.ResponseWriter, r *http.Request, conversationID string) { + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + err := h.storage.DeleteConversation(ctx, conversationID) + if err != nil { + logger.Error("Failed to delete conversation", "id", conversationID, "error", err) + h.writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": fmt.Sprintf("Failed to delete conversation: %v", err), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + 
"message": "Conversation deleted successfully", + }) +} + +// handleUpdateMetadata updates conversation metadata +func (h *APIHandler) handleUpdateMetadata(w http.ResponseWriter, r *http.Request, conversationID string) { + var updates storage.ConversationMetadata + if err := json.NewDecoder(r.Body).Decode(&updates); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + updates.ID = conversationID + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + err := h.storage.UpdateConversationMetadata(ctx, conversationID, updates) + if err != nil { + logger.Error("Failed to update conversation metadata", "id", conversationID, "error", err) + h.writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": fmt.Sprintf("Failed to update metadata: %v", err), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "Metadata updated successfully", + }) +} + +// HandleConversationsNeedingTitles handles GET /api/v1/conversations/needs-titles +func (h *APIHandler) HandleConversationsNeedingTitles(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + limit := 10 + if limitStr := r.URL.Query().Get("limit"); limitStr != "" { + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 { + limit = l + } + } + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + conversations, err := h.storage.ListConversationsNeedingTitles(ctx, limit) + if err != nil { + logger.Error("Failed to list conversations needing titles", "error", err) + h.writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": fmt.Sprintf("Failed to list conversations: %v", err), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "conversations": conversations, + "count": len(conversations), + }) +} + +// 
HandleListModels handles GET /api/v1/models +func (h *APIHandler) HandleListModels(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) + defer cancel() + + models, err := h.modelService.ListModels(ctx) + if err != nil { + logger.Error("Failed to list models", "error", err) + h.writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": fmt.Sprintf("Failed to list models: %v", err), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "models": models, + "count": len(models), + }) +} + +// HandleAgentsStatus handles GET /api/v1/agents/status +func (h *APIHandler) HandleAgentsStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + agentReadiness := h.stateManager.GetAgentReadiness() + + var agents []map[string]any + totalAgents := 0 + readyAgents := 0 + + if agentReadiness != nil { + totalAgents = agentReadiness.TotalAgents + readyAgents = agentReadiness.ReadyAgents + + agents = make([]map[string]any, 0, len(agentReadiness.Agents)) + for _, agent := range agentReadiness.Agents { + errorMsg := agent.Message + if agent.Error != "" { + errorMsg = agent.Error + } + agents = append(agents, map[string]any{ + "name": agent.Name, + "state": agent.State.String(), + "url": agent.URL, + "image": agent.Image, + "error": errorMsg, + }) + } + } else { + agents = []map[string]any{} + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "total_agents": totalAgents, + "ready_agents": readyAgents, + "agents": agents, + }) +} + +// HandleMCPStatus handles GET /api/v1/mcp/status +func (h *APIHandler) HandleMCPStatus(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + if 
h.mcpManager == nil { + h.writeJSON(w, http.StatusOK, map[string]any{ + "total_servers": 0, + "connected_servers": 0, + "total_tools": 0, + "servers": []map[string]any{}, + }) + return + } + + clients := h.mcpManager.GetClients() + totalServers := h.mcpManager.GetTotalServers() + + servers := make([]map[string]any, 0, len(clients)) + connectedServers := 0 + totalTools := 0 + + for _, client := range clients { + ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second) + toolsMap, err := client.DiscoverTools(ctx) + cancel() + + toolCount := 0 + connected := err == nil + + if connected { + connectedServers++ + for _, tools := range toolsMap { + toolCount += len(tools) + } + totalTools += toolCount + } + + servers = append(servers, map[string]any{ + "name": fmt.Sprintf("server-%d", len(servers)+1), + "connected": connected, + "tools": toolCount, + }) + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "total_servers": totalServers, + "connected_servers": connectedServers, + "total_tools": totalTools, + "servers": servers, + }) +} + +// HandleHistory handles history requests +// GET /api/history - Get all history commands +// POST /api/history - Save a new command to history +func (h *APIHandler) HandleHistory(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + h.handleGetHistory(w, r) + case http.MethodPost: + h.handleSaveHistory(w, r) + default: + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleGetHistory retrieves all history commands +func (h *APIHandler) handleGetHistory(w http.ResponseWriter, _ *http.Request) { + if h.historyManager == nil { + h.writeJSON(w, http.StatusOK, map[string]any{ + "history": []string{}, + "count": 0, + }) + return + } + + shellHistory := h.historyManager.GetShellHistoryFile() + shellProvider, err := history.NewShellHistoryWithDir(".infer") + if err != nil { + logger.Error("Failed to initialize shell history", "error", err) + h.writeJSON(w, 
http.StatusInternalServerError, map[string]any{ + "error": fmt.Sprintf("Failed to load history: %v", err), + }) + return + } + + commands, err := shellProvider.LoadHistory() + if err != nil { + logger.Error("Failed to load history", "file", shellHistory, "error", err) + h.writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": fmt.Sprintf("Failed to load history: %v", err), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "history": commands, + "count": len(commands), + }) +} + +// handleSaveHistory saves a new command to history +func (h *APIHandler) handleSaveHistory(w http.ResponseWriter, r *http.Request) { + var request struct { + Command string `json:"command"` + } + + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid request body: %v", err), http.StatusBadRequest) + return + } + + if strings.TrimSpace(request.Command) == "" { + http.Error(w, "Command cannot be empty", http.StatusBadRequest) + return + } + + if h.historyManager == nil { + logger.Warn("History manager not available") + h.writeJSON(w, http.StatusOK, map[string]any{ + "success": false, + "message": "History not available", + }) + return + } + + if err := h.historyManager.AddToHistory(request.Command); err != nil { + logger.Error("Failed to save to history", "error", err) + h.writeJSON(w, http.StatusInternalServerError, map[string]any{ + "error": fmt.Sprintf("Failed to save history: %v", err), + }) + return + } + + h.writeJSON(w, http.StatusOK, map[string]any{ + "success": true, + "message": "Command saved to history", + }) +} diff --git a/internal/handlers/chat_command_handler.go b/internal/handlers/chat_command_handler.go index df26b669..52e09032 100644 --- a/internal/handlers/chat_command_handler.go +++ b/internal/handlers/chat_command_handler.go @@ -110,6 +110,8 @@ func (c *ChatCommandHandler) executeBashCommand(commandText, command string) tea } _ = c.handler.conversationRepo.AddMessage(userEntry) + 
argsJSON := fmt.Sprintf(`{"command":%q}`, command) + return tea.Batch( func() tea.Msg { return domain.UpdateHistoryEvent{ @@ -133,9 +135,10 @@ func (c *ChatCommandHandler) executeBashCommand(commandText, command string) tea }, Tools: []domain.ToolInfo{ { - CallID: toolCallID, - Name: "Bash", - Status: "starting", + CallID: toolCallID, + Name: "Bash", + Status: "starting", + Arguments: argsJSON, }, }, } @@ -492,9 +495,10 @@ func (c *ChatCommandHandler) executeToolCommand(commandText, toolName, argsJSON }, Tools: []domain.ToolInfo{ { - CallID: toolCallID, - Name: toolName, - Status: "starting", + CallID: toolCallID, + Name: toolName, + Status: "starting", + Arguments: argsJSON, }, }, } diff --git a/internal/handlers/headless_handler.go b/internal/handlers/headless_handler.go new file mode 100644 index 00000000..498b407d --- /dev/null +++ b/internal/handlers/headless_handler.go @@ -0,0 +1,824 @@ +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "sync" + "time" + + "github.com/google/uuid" + "github.com/inference-gateway/cli/config" + "github.com/inference-gateway/cli/internal/container" + "github.com/inference-gateway/cli/internal/domain" + "github.com/inference-gateway/cli/internal/logger" + "github.com/inference-gateway/sdk" +) + +// AG-UI Protocol Event Types +// Ref: https://docs.ag-ui.com/concepts/events + +// AGUIEvent is the base interface for all AG-UI events +type AGUIEvent struct { + Type string `json:"type"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// RunStarted signals the start of an agent run +type RunStarted struct { + Type string `json:"type"` // "RunStarted" + ThreadID string `json:"threadId"` + RunID string `json:"runId"` + ParentRunID string `json:"parentRunId,omitempty"` + Input any `json:"input,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// RunFinished signals the successful completion of an agent run +type RunFinished struct { + Type string `json:"type"` // "RunFinished" 
+ ThreadID string `json:"threadId"` + RunID string `json:"runId"` + Result any `json:"result,omitempty"` + Outcome string `json:"outcome,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// RunError signals an error during an agent run +type RunError struct { + Type string `json:"type"` // "RunError" + Message string `json:"message"` + Code string `json:"code,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// TextMessageStart initializes a new text message in the conversation +type TextMessageStart struct { + Type string `json:"type"` // "TextMessageStart" + MessageID string `json:"messageId"` + Role string `json:"role"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// TextMessageContent delivers incremental parts of message text as available +type TextMessageContent struct { + Type string `json:"type"` // "TextMessageContent" + MessageID string `json:"messageId"` + Delta string `json:"delta"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// TextMessageEnd marks the completion of a streaming text message +type TextMessageEnd struct { + Type string `json:"type"` // "TextMessageEnd" + MessageID string `json:"messageId"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// ToolCallStart indicates the agent is invoking a tool +type ToolCallStart struct { + Type string `json:"type"` + ToolCallID string `json:"toolCallId"` + ToolCallName string `json:"toolCallName"` + Arguments string `json:"arguments,omitempty"` + ParentMessageID string `json:"parentMessageId,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Status string `json:"status,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// ToolCallArgs delivers incremental parts of tool argument data +type ToolCallArgs struct { + Type string `json:"type"` + ToolCallID string `json:"toolCallId"` + Delta string `json:"delta"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// ToolCallEnd marks the completion of a tool call specification 
+type ToolCallEnd struct { + Type string `json:"type"` // "ToolCallEnd" + ToolCallID string `json:"toolCallId"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// ToolCallProgress delivers intermediate progress updates during tool execution +type ToolCallProgress struct { + Type string `json:"type"` // "ToolCallProgress" + ToolCallID string `json:"toolCallId"` + Status string `json:"status"` // Current execution status + Message string `json:"message"` // Human-readable status message + Output string `json:"output,omitempty"` // Streaming output (for Bash) + Metadata map[string]any `json:"metadata,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` +} + +// ToolCallResult provides the output/result from executed tool +type ToolCallResult struct { + Type string `json:"type"` // "ToolCallResult" + MessageID string `json:"messageId"` + ToolCallID string `json:"toolCallId"` + Content any `json:"content"` + Role string `json:"role,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Status string `json:"status,omitempty"` // Final status: "complete" or "failed" + Duration float64 `json:"duration,omitempty"` // Execution time in seconds + Metadata map[string]any `json:"metadata,omitempty"` // Additional metadata (e.g., exit code, error details) +} + +// ParallelToolsMetadata provides summary information for parallel tool execution +type ParallelToolsMetadata struct { + Type string `json:"type"` // "ParallelToolsMetadata" + TotalCount int `json:"totalCount"` + SuccessCount int `json:"successCount,omitempty"` + FailureCount int `json:"failureCount,omitempty"` + TotalDuration float64 `json:"totalDuration,omitempty"` // Total execution time in seconds + Timestamp int64 `json:"timestamp,omitempty"` +} + +// HeadlessInput represents JSON input from UI +type HeadlessInput struct { + Type string `json:"type"` // "message", "interrupt", "shutdown" + Content string `json:"content"` // User message content + Images []domain.ImageAttachment `json:"images"` // 
Image attachments + Model string `json:"model,omitempty"` // Optional model to use for this message +} + +// toolExecutionState tracks the execution state of a single tool call +type toolExecutionState struct { + CallID string + ToolName string + StartTime time.Time + Status string + OutputBuffer []string +} + +// HeadlessHandler handles headless mode communication via stdin/stdout +// Implements AG-UI protocol: https://docs.ag-ui.com +type HeadlessHandler struct { + sessionID string + conversationID string + services *container.ServiceContainer + config *config.Config + stdin io.Reader + stdout io.Writer + ctx context.Context + cancel context.CancelFunc + currentRunID string + toolStates map[string]*toolExecutionState + toolStatesMux sync.RWMutex +} + +// NewHeadlessHandler creates a new headless handler +func NewHeadlessHandler(sessionID string, conversationID string, services *container.ServiceContainer, cfg *config.Config) *HeadlessHandler { + if sessionID == "" { + sessionID = uuid.New().String() + } + + ctx, cancel := context.WithCancel(context.Background()) + + return &HeadlessHandler{ + sessionID: sessionID, + conversationID: conversationID, + services: services, + config: cfg, + stdin: os.Stdin, + stdout: os.Stdout, + ctx: ctx, + cancel: cancel, + } +} + +// Start begins the headless session +func (h *HeadlessHandler) Start() error { + if err := h.services.GetGatewayManager().EnsureStarted(); err != nil { + logger.Error("Failed to start gateway", "error", err) + h.emitRunError(fmt.Sprintf("failed to start gateway: %v", err), "GATEWAY_START_FAILED") + return err + } + + ctx, cancel := context.WithTimeout(h.ctx, time.Duration(h.config.Gateway.Timeout)*time.Second) + defer cancel() + + models, err := h.services.GetModelService().ListModels(ctx) + if err != nil { + h.emitRunError(fmt.Sprintf("inference gateway is not available: %v", err), "GATEWAY_UNAVAILABLE") + return fmt.Errorf("inference gateway is not available: %w", err) + } + + if len(models) == 0 { + 
h.emitRunError("no models available from inference gateway", "NO_MODELS") + return fmt.Errorf("no models available from inference gateway") + } + + defaultModel := h.config.Agent.Model + if defaultModel == "" || !contains(models, defaultModel) { + defaultModel = models[0] + } + + if err := h.services.GetModelService().SelectModel(defaultModel); err != nil { + h.emitRunError(fmt.Sprintf("failed to set model: %v", err), "MODEL_SELECT_FAILED") + return fmt.Errorf("failed to set model: %w", err) + } + + conversationRepo := h.services.GetConversationRepository() + h.loadExistingConversation(conversationRepo) + + h.emitEvent(RunStarted{ + Type: "RunStarted", + ThreadID: h.conversationID, + RunID: h.sessionID, + Timestamp: time.Now().UnixMilli(), + }) + + return h.readLoop() +} + +// readLoop continuously reads JSON input from stdin +func (h *HeadlessHandler) readLoop() error { + scanner := bufio.NewScanner(h.stdin) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) // 1MB initial, 10MB max for large images + + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + var input HeadlessInput + if err := json.Unmarshal(line, &input); err != nil { + logger.Error("Failed to parse JSON input", "error", err, "line", string(line)) + h.emitRunError(fmt.Sprintf("invalid JSON input: %v", err), "INVALID_INPUT") + continue + } + + if err := h.processInput(input); err != nil { + if err == io.EOF { + return nil + } + logger.Error("Failed to process input", "error", err, "input", input) + h.emitRunError(fmt.Sprintf("failed to process input: %v", err), "PROCESSING_FAILED") + } + } + + if err := scanner.Err(); err != nil { + logger.Error("Error reading from stdin", "error", err) + return fmt.Errorf("error reading from stdin: %w", err) + } + + return nil +} + +// processInput handles a single input message +func (h *HeadlessHandler) processInput(input HeadlessInput) error { + switch input.Type { + case "message": + return h.handleMessage(input.Content, 
input.Images, input.Model) + case "interrupt": + return h.handleInterrupt() + case "shutdown": + h.cancel() + return io.EOF // Signal clean shutdown + default: + return fmt.Errorf("unknown input type: %s", input.Type) + } +} + +// handleMessage processes a user message +func (h *HeadlessHandler) handleMessage(content string, images []domain.ImageAttachment, model string) error { + if content == "" { + return fmt.Errorf("empty message content") + } + + if model != "" { + modelService := h.services.GetModelService() + if err := modelService.SelectModel(model); err != nil { + h.emitRunError(fmt.Sprintf("Failed to select model '%s': %v", model, err), "MODEL_SELECT_FAILED") + return fmt.Errorf("failed to select model: %w", err) + } + logger.Info("Model selected for headless session", "model", model, "session", h.sessionID) + } + + h.currentRunID = uuid.New().String() + + userMessageID := uuid.New().String() + h.emitEvent(TextMessageStart{ + Type: "TextMessageStart", + MessageID: userMessageID, + Role: "user", + Timestamp: time.Now().UnixMilli(), + }) + h.emitEvent(TextMessageContent{ + Type: "TextMessageContent", + MessageID: userMessageID, + Delta: content, + Timestamp: time.Now().UnixMilli(), + }) + h.emitEvent(TextMessageEnd{ + Type: "TextMessageEnd", + MessageID: userMessageID, + Timestamp: time.Now().UnixMilli(), + }) + + h.emitEvent(RunStarted{ + Type: "RunStarted", + ThreadID: h.conversationID, + RunID: h.currentRunID, + Input: map[string]any{ + "content": content, + "images": len(images), + }, + Timestamp: time.Now().UnixMilli(), + }) + + var userMessage sdk.Message + + if len(images) > 0 { + contentParts := []sdk.ContentPart{} + + var textPart sdk.ContentPart + if err := textPart.FromTextContentPart(sdk.TextContentPart{ + Type: sdk.Text, + Text: content, + }); err != nil { + h.emitRunError(fmt.Sprintf("failed to create text content part: %v", err), "CONTENT_PART_FAILED") + return fmt.Errorf("failed to create text content part: %w", err) + } + contentParts = 
append(contentParts, textPart) + + for _, img := range images { + dataURL := fmt.Sprintf("data:%s;base64,%s", img.MimeType, img.Data) + imagePart, err := sdk.NewImageContentPart(dataURL, nil) + if err != nil { + h.emitRunError(fmt.Sprintf("failed to create image content part: %v", err), "IMAGE_PART_FAILED") + return fmt.Errorf("failed to create image content part: %w", err) + } + contentParts = append(contentParts, imagePart) + } + + userMessage = sdk.Message{ + Role: sdk.User, + Content: sdk.NewMessageContent(contentParts), + } + } else { + userMessage = sdk.Message{ + Role: sdk.User, + Content: sdk.NewMessageContent(content), + } + } + + conversationRepo := h.services.GetConversationRepository() + + wasNewConversation := h.conversationID == "" + + if err := conversationRepo.AddMessage(domain.ConversationEntry{ + Message: userMessage, + Time: time.Now(), + }); err != nil { + logger.Warn("Failed to add user message to conversation", "error", err) + } + + if wasNewConversation { + if persistentRepo, ok := conversationRepo.(interface { + GetCurrentConversationID() string + }); ok { + h.conversationID = persistentRepo.GetCurrentConversationID() + if h.conversationID != "" { + h.emitEvent(map[string]any{ + "type": "ConversationCreated", + "conversation_id": h.conversationID, + "timestamp": time.Now().UnixMilli(), + }) + } + } + } + + entries := conversationRepo.GetMessages() + messages := make([]sdk.Message, len(entries)) + for i, entry := range entries { + messages[i] = entry.Message + } + + req := &domain.AgentRequest{ + RequestID: fmt.Sprintf("req_%d", time.Now().UnixNano()), + Model: h.services.GetModelService().GetCurrentModel(), + Messages: messages, + } + + ctx, cancel := context.WithCancel(h.ctx) + defer cancel() + + agentService := h.services.GetAgentService() + events, err := agentService.RunWithStream(ctx, req) + if err != nil { + h.emitRunError(fmt.Sprintf("failed to start chat: %v", err), "CHAT_START_FAILED") + return fmt.Errorf("failed to start chat: %w", 
err) + } + + return h.processStreamingEvents(events) +} + +// processStreamingEvents handles streaming chat events +func (h *HeadlessHandler) processStreamingEvents(events <-chan domain.ChatEvent) error { + var messageID string + var tokenStats map[string]int + + for { + select { + case <-h.ctx.Done(): + return h.ctx.Err() + case event, ok := <-events: + if !ok { + return h.handleStreamComplete(messageID, tokenStats) + } + + var err error + messageID, tokenStats, err = h.handleEvent(event, messageID, tokenStats) + if err != nil { + return err + } + } + } +} + +func (h *HeadlessHandler) handleStreamComplete(messageID string, tokenStats map[string]int) error { + h.emitEvent(RunFinished{ + Type: "RunFinished", + ThreadID: h.conversationID, + RunID: h.currentRunID, + Result: map[string]any{ + "message_id": messageID, + "tokens": tokenStats, + }, + Outcome: "success", + Timestamp: time.Now().UnixMilli(), + }) + return nil +} + +func (h *HeadlessHandler) handleEvent(event domain.ChatEvent, messageID string, tokenStats map[string]int) (string, map[string]int, error) { + switch e := event.(type) { + case domain.ChatChunkEvent: + return h.handleChatChunk(e, messageID), tokenStats, nil + case domain.ChatCompleteEvent: + return h.handleChatComplete(e, messageID) + case domain.ChatErrorEvent: + return messageID, tokenStats, h.handleChatError(e) + case domain.ToolCallReadyEvent: + return h.handleToolCallReady(e, messageID), tokenStats, nil + case domain.ParallelToolsStartEvent: + return h.handleParallelToolsStart(e, messageID), tokenStats, nil + case domain.ToolExecutionProgressEvent: + h.handleToolExecutionProgress(e) + return messageID, tokenStats, nil + case domain.BashOutputChunkEvent: + h.handleBashOutputChunk(e) + return messageID, tokenStats, nil + case domain.ParallelToolsCompleteEvent: + h.handleParallelToolsComplete(e) + return messageID, tokenStats, nil + case domain.ToolApprovalRequestedEvent: + h.handleToolApprovalRequested(e, messageID) + return messageID, 
tokenStats, nil + default: + return messageID, tokenStats, nil + } +} + +func (h *HeadlessHandler) handleChatChunk(e domain.ChatChunkEvent, messageID string) string { + if e.Content == "" { + return messageID + } + + if messageID == "" { + messageID = uuid.New().String() + h.emitEvent(TextMessageStart{ + Type: "TextMessageStart", + MessageID: messageID, + Role: "assistant", + Timestamp: time.Now().UnixMilli(), + }) + } + + h.emitEvent(TextMessageContent{ + Type: "TextMessageContent", + MessageID: messageID, + Delta: e.Content, + Timestamp: time.Now().UnixMilli(), + }) + + return messageID +} + +func (h *HeadlessHandler) handleChatComplete(e domain.ChatCompleteEvent, messageID string) (string, map[string]int, error) { + if messageID != "" { + h.emitEvent(TextMessageEnd{ + Type: "TextMessageEnd", + MessageID: messageID, + Timestamp: time.Now().UnixMilli(), + }) + messageID = "" + } + + var tokenStats map[string]int + if e.Metrics != nil && e.Metrics.Usage != nil { + tokenStats = map[string]int{ + "input_tokens": int(e.Metrics.Usage.PromptTokens), + "output_tokens": int(e.Metrics.Usage.CompletionTokens), + "total_tokens": int(e.Metrics.Usage.PromptTokens + e.Metrics.Usage.CompletionTokens), + } + } + + return messageID, tokenStats, nil +} + +func (h *HeadlessHandler) handleChatError(e domain.ChatErrorEvent) error { + h.emitRunError(fmt.Sprintf("chat error: %v", e.Error), "CHAT_ERROR") + return fmt.Errorf("chat error: %v", e.Error) +} + +func (h *HeadlessHandler) handleToolCallReady(e domain.ToolCallReadyEvent, messageID string) string { + if messageID != "" { + h.emitEvent(TextMessageEnd{ + Type: "TextMessageEnd", + MessageID: messageID, + Timestamp: time.Now().UnixMilli(), + }) + messageID = "" + } + + for _, toolCall := range e.ToolCalls { + h.emitToolCallEvents(toolCall.Id, toolCall.Function.Name, toolCall.Function.Arguments, messageID) + } + + return messageID +} + +func (h *HeadlessHandler) handleParallelToolsStart(e domain.ParallelToolsStartEvent, messageID 
string) string { + if messageID != "" { + h.emitEvent(TextMessageEnd{ + Type: "TextMessageEnd", + MessageID: messageID, + Timestamp: time.Now().UnixMilli(), + }) + messageID = "" + } + + for _, tool := range e.Tools { + h.trackToolExecution(tool.CallID, tool.Name) + h.emitEvent(ToolCallStart{ + Type: "ToolCallStart", + ToolCallID: tool.CallID, + ToolCallName: tool.Name, + Arguments: tool.Arguments, + ParentMessageID: messageID, + Status: "queued", + Timestamp: time.Now().UnixMilli(), + }) + } + + return messageID +} + +func (h *HeadlessHandler) handleToolExecutionProgress(e domain.ToolExecutionProgressEvent) { + state := h.updateToolStatus(e.ToolCallID, e.Status) + aguiStatus := mapStatusToAGUI(e.Status) + + h.emitEvent(ToolCallProgress{ + Type: "ToolCallProgress", + ToolCallID: e.ToolCallID, + Status: aguiStatus, + Message: e.Message, + Timestamp: time.Now().UnixMilli(), + }) + + if e.Status == "complete" || e.Status == "failed" { + h.emitToolResult(e, state, aguiStatus) + } +} + +func (h *HeadlessHandler) emitToolResult(e domain.ToolExecutionProgressEvent, state *toolExecutionState, aguiStatus string) { + if state == nil { + return + } + + duration := time.Since(state.StartTime).Seconds() + content := e.Result + if content == "" { + content = e.Message + } + + h.emitEvent(ToolCallResult{ + Type: "ToolCallResult", + MessageID: uuid.New().String(), + ToolCallID: e.ToolCallID, + Content: content, + Role: "tool", + Status: aguiStatus, + Duration: duration, + Timestamp: time.Now().UnixMilli(), + }) + + h.removeToolState(e.ToolCallID) +} + +func (h *HeadlessHandler) handleBashOutputChunk(e domain.BashOutputChunkEvent) { + h.emitEvent(ToolCallProgress{ + Type: "ToolCallProgress", + ToolCallID: e.ToolCallID, + Status: "running", + Message: "Streaming output...", + Output: e.Output, + Metadata: map[string]any{ + "isComplete": e.IsComplete, + }, + Timestamp: time.Now().UnixMilli(), + }) +} + +func (h *HeadlessHandler) handleParallelToolsComplete(e 
domain.ParallelToolsCompleteEvent) { + h.emitEvent(ParallelToolsMetadata{ + Type: "ParallelToolsMetadata", + TotalCount: e.TotalExecuted, + SuccessCount: e.SuccessCount, + FailureCount: e.FailureCount, + TotalDuration: e.Duration.Seconds(), + Timestamp: time.Now().UnixMilli(), + }) +} + +func (h *HeadlessHandler) handleToolApprovalRequested(e domain.ToolApprovalRequestedEvent, messageID string) { + h.emitEvent(ToolCallStart{ + Type: "ToolCallStart", + ToolCallID: e.ToolCall.Id, + ToolCallName: e.ToolCall.Function.Name, + Arguments: e.ToolCall.Function.Arguments, + ParentMessageID: messageID, + Timestamp: time.Now().UnixMilli(), + }) +} + +func (h *HeadlessHandler) emitToolCallEvents(toolCallID, toolName, arguments, messageID string) { + h.trackToolExecution(toolCallID, toolName) + + h.emitEvent(ToolCallStart{ + Type: "ToolCallStart", + ToolCallID: toolCallID, + ToolCallName: toolName, + Arguments: arguments, + ParentMessageID: messageID, + Status: "queued", + Timestamp: time.Now().UnixMilli(), + }) + + h.emitEvent(ToolCallArgs{ + Type: "ToolCallArgs", + ToolCallID: toolCallID, + Delta: arguments, + Timestamp: time.Now().UnixMilli(), + }) + + h.emitEvent(ToolCallEnd{ + Type: "ToolCallEnd", + ToolCallID: toolCallID, + Timestamp: time.Now().UnixMilli(), + }) +} + +// handleInterrupt handles interruption signal +func (h *HeadlessHandler) handleInterrupt() error { + h.cancel() + h.emitRunError("Chat interrupted by user", "INTERRUPTED") + return nil +} + +// emitEvent sends an AG-UI protocol event to stdout +func (h *HeadlessHandler) emitEvent(event any) { + jsonData, err := json.Marshal(event) + if err != nil { + logger.Error("Failed to marshal event", "error", err, "event", event) + return + } + + if _, err := fmt.Fprintf(h.stdout, "%s\n", jsonData); err != nil { + logger.Error("Failed to write to stdout", "error", err) + } +} + +// emitRunError sends a RunError event +func (h *HeadlessHandler) emitRunError(message string, code string) { + h.emitEvent(RunError{ + Type: 
"RunError", + Message: message, + Code: code, + Timestamp: time.Now().UnixMilli(), + }) +} + +// Shutdown cleanly shuts down the headless handler +func (h *HeadlessHandler) Shutdown() error { + h.cancel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + return h.services.Shutdown(ctx) +} + +// contains checks if a slice contains a string +func contains(slice []string, item string) bool { + for _, s := range slice { + if s == item { + return true + } + } + return false +} + +// loadExistingConversation attempts to load a conversation from storage if conversationID is set +func (h *HeadlessHandler) loadExistingConversation(conversationRepo domain.ConversationRepository) { + if h.conversationID == "" { + return + } + + ctx, cancel := context.WithTimeout(h.ctx, 30*time.Second) + defer cancel() + + persistentRepo, ok := conversationRepo.(interface { + LoadConversation(ctx context.Context, conversationID string) error + }) + if !ok { + logger.Warn("ConversationRepository does not support LoadConversation, will create new conversation on first message") + h.conversationID = "" + return + } + + logger.Info("💾 Attempting to load conversation from storage...") + if err := persistentRepo.LoadConversation(ctx, h.conversationID); err != nil { + logger.Warn("Failed to load conversation, will create new one on first message", "conversation_id", h.conversationID, "error", err) + h.conversationID = "" + } +} + +// trackToolExecution initializes state tracking for a tool call +func (h *HeadlessHandler) trackToolExecution(callID, toolName string) { + h.toolStatesMux.Lock() + defer h.toolStatesMux.Unlock() + + if h.toolStates == nil { + h.toolStates = make(map[string]*toolExecutionState) + } + + h.toolStates[callID] = &toolExecutionState{ + CallID: callID, + ToolName: toolName, + StartTime: time.Now(), + Status: "queued", + OutputBuffer: []string{}, + } +} + +// getToolState retrieves tool execution state (thread-safe read) +func (h 
*HeadlessHandler) getToolState(callID string) *toolExecutionState { + h.toolStatesMux.RLock() + defer h.toolStatesMux.RUnlock() + return h.toolStates[callID] +} + +// updateToolStatus updates tool status and returns updated state +func (h *HeadlessHandler) updateToolStatus(callID, status string) *toolExecutionState { + h.toolStatesMux.Lock() + defer h.toolStatesMux.Unlock() + + if state, exists := h.toolStates[callID]; exists { + state.Status = status + return state + } + return nil +} + +// removeToolState cleans up tool state after completion +func (h *HeadlessHandler) removeToolState(callID string) { + h.toolStatesMux.Lock() + defer h.toolStatesMux.Unlock() + delete(h.toolStates, callID) +} + +// mapStatusToAGUI maps TUI status values to AG-UI protocol status +func mapStatusToAGUI(status string) string { + switch status { + case "queued", "ready": + return "queued" + case "running", "starting", "saving", "executing", "streaming": + return "running" + case "complete", "completed", "executed": + return "complete" + case "error", "failed": + return "failed" + default: + return status + } +} diff --git a/internal/handlers/headless_handler_progress_test.go b/internal/handlers/headless_handler_progress_test.go new file mode 100644 index 00000000..ca21e766 --- /dev/null +++ b/internal/handlers/headless_handler_progress_test.go @@ -0,0 +1,404 @@ +package handlers + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "testing" + "time" + + domain "github.com/inference-gateway/cli/internal/domain" +) + +// TestMapStatusToAGUI tests the status mapping function +func TestMapStatusToAGUI(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"queued maps to queued", "queued", "queued"}, + {"ready maps to queued", "ready", "queued"}, + {"running maps to running", "running", "running"}, + {"starting maps to running", "starting", "running"}, + {"saving maps to running", "saving", "running"}, + {"executing maps to running", 
"executing", "running"}, + {"streaming maps to running", "streaming", "running"}, + {"complete maps to complete", "complete", "complete"}, + {"completed maps to complete", "completed", "complete"}, + {"executed maps to complete", "executed", "complete"}, + {"error maps to failed", "error", "failed"}, + {"failed maps to failed", "failed", "failed"}, + {"unknown passes through", "unknown", "unknown"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := mapStatusToAGUI(tt.input) + if result != tt.expected { + t.Errorf("mapStatusToAGUI(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// TestTrackToolExecution tests tool execution state tracking +func TestTrackToolExecution(t *testing.T) { + h := &HeadlessHandler{} + + callID := "call_123" + toolName := "Bash" + + h.trackToolExecution(callID, toolName) + + state := h.getToolState(callID) + if state == nil { + t.Fatal("Expected state to be created, got nil") + } + + if state.CallID != callID { + t.Errorf("Expected CallID %q, got %q", callID, state.CallID) + } + + if state.ToolName != toolName { + t.Errorf("Expected ToolName %q, got %q", toolName, state.ToolName) + } + + if state.Status != "queued" { + t.Errorf("Expected initial status 'queued', got %q", state.Status) + } + + if state.OutputBuffer == nil { + t.Error("Expected OutputBuffer to be initialized, got nil") + } + + if state.StartTime.IsZero() { + t.Error("Expected StartTime to be set, got zero time") + } +} + +// TestUpdateToolStatus tests updating tool status +func TestUpdateToolStatus(t *testing.T) { + h := &HeadlessHandler{} + + callID := "call_123" + toolName := "Bash" + + h.trackToolExecution(callID, toolName) + + state := h.updateToolStatus(callID, "running") + if state == nil { + t.Fatal("Expected state to be returned, got nil") + } + + if state.Status != "running" { + t.Errorf("Expected status 'running', got %q", state.Status) + } + + nilState := h.updateToolStatus("non_existent", "complete") + if 
nilState != nil { + t.Error("Expected nil for non-existent tool, got state") + } +} + +// TestRemoveToolState tests removing tool state +func TestRemoveToolState(t *testing.T) { + h := &HeadlessHandler{} + + callID := "call_123" + toolName := "Bash" + + h.trackToolExecution(callID, toolName) + + if h.getToolState(callID) == nil { + t.Fatal("Expected state to exist before removal") + } + + h.removeToolState(callID) + + if h.getToolState(callID) != nil { + t.Error("Expected state to be removed, but it still exists") + } +} + +// TestConcurrentStateAccess tests thread-safe state access +func TestConcurrentStateAccess(t *testing.T) { + h := &HeadlessHandler{} + + done := make(chan bool) + numGoroutines := 10 + + for i := 0; i < numGoroutines; i++ { + go func(id int) { + callID := "call_" + string(rune('0'+id)) + toolName := "Tool" + string(rune('0'+id)) + + h.trackToolExecution(callID, toolName) + h.updateToolStatus(callID, "running") + _ = h.getToolState(callID) + h.updateToolStatus(callID, "complete") + h.removeToolState(callID) + + done <- true + }(i) + } + + for i := 0; i < numGoroutines; i++ { + <-done + } + + h.toolStatesMux.RLock() + remaining := len(h.toolStates) + h.toolStatesMux.RUnlock() + + if remaining != 0 { + t.Errorf("Expected 0 remaining states, got %d", remaining) + } +} + +// TestToolCallProgressEventEmission tests ToolCallProgress event emission +func TestToolCallProgressEventEmission(t *testing.T) { + var stdout bytes.Buffer + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := &HeadlessHandler{ + stdout: &stdout, + ctx: ctx, + } + + callID := "call_123" + h.trackToolExecution(callID, "Bash") + + progressEvent := domain.ToolExecutionProgressEvent{ + ToolCallID: callID, + ToolName: "Bash", + Status: "running", + Message: "Executing command...", + } + + events := make(chan domain.ChatEvent, 1) + events <- progressEvent + close(events) + + err := h.processStreamingEvents(events) + if err != nil { + t.Fatalf("Unexpected 
error: %v", err) + } + + scanner := bufio.NewScanner(&stdout) + scanner.Split(bufio.ScanLines) + + var emittedProgress bool + for scanner.Scan() { + line := scanner.Bytes() + var event map[string]any + if err := json.Unmarshal(line, &event); err != nil { + continue + } + + if event["type"] == "ToolCallProgress" { + emittedProgress = true + if event["toolCallId"] != callID { + t.Errorf("Expected toolCallId %q, got %v", callID, event["toolCallId"]) + } + if event["status"] != "running" { + t.Errorf("Expected status 'running', got %v", event["status"]) + } + if event["message"] != "Executing command..." { + t.Errorf("Expected message 'Executing command...', got %v", event["message"]) + } + } + } + + if !emittedProgress { + t.Error("Expected ToolCallProgress event to be emitted") + } +} + +// TestBashOutputChunkEventEmission tests BashOutputChunkEvent emission +func TestBashOutputChunkEventEmission(t *testing.T) { + var stdout bytes.Buffer + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := &HeadlessHandler{ + stdout: &stdout, + ctx: ctx, + } + + callID := "call_123" + h.trackToolExecution(callID, "Bash") + + bashEvent := domain.BashOutputChunkEvent{ + ToolCallID: callID, + Output: "total 48\n", + IsComplete: false, + } + + events := make(chan domain.ChatEvent, 1) + events <- bashEvent + close(events) + + err := h.processStreamingEvents(events) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + scanner := bufio.NewScanner(&stdout) + scanner.Split(bufio.ScanLines) + + var foundOutput bool + for scanner.Scan() { + line := scanner.Bytes() + var event map[string]any + if err := json.Unmarshal(line, &event); err != nil { + continue + } + + if event["type"] != "ToolCallProgress" { + continue + } + + foundOutput = true + if event["output"] != "total 48\n" { + t.Errorf("Expected output 'total 48\\n', got %v", event["output"]) + } + if metadata, ok := event["metadata"].(map[string]any); ok { + if metadata["isComplete"] != false { + 
t.Errorf("Expected isComplete false, got %v", metadata["isComplete"]) + } + } + } + + if !foundOutput { + t.Error("Expected ToolCallProgress event with output to be emitted") + } +} + +// TestParallelToolsMetadataEmission tests ParallelToolsMetadata event emission +func TestParallelToolsMetadataEmission(t *testing.T) { + var stdout bytes.Buffer + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := &HeadlessHandler{ + stdout: &stdout, + ctx: ctx, + } + + completeEvent := domain.ParallelToolsCompleteEvent{ + TotalExecuted: 3, + SuccessCount: 2, + FailureCount: 1, + Duration: 5 * time.Second, + } + + events := make(chan domain.ChatEvent, 1) + events <- completeEvent + close(events) + + err := h.processStreamingEvents(events) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + scanner := bufio.NewScanner(&stdout) + scanner.Split(bufio.ScanLines) + + var foundMetadata bool + for scanner.Scan() { + line := scanner.Bytes() + var event map[string]any + if err := json.Unmarshal(line, &event); err != nil { + continue + } + + if event["type"] != "ParallelToolsMetadata" { + continue + } + + foundMetadata = true + if int(event["totalCount"].(float64)) != 3 { + t.Errorf("Expected totalCount 3, got %v", event["totalCount"]) + } + if int(event["successCount"].(float64)) != 2 { + t.Errorf("Expected successCount 2, got %v", event["successCount"]) + } + if int(event["failureCount"].(float64)) != 1 { + t.Errorf("Expected failureCount 1, got %v", event["failureCount"]) + } + if event["totalDuration"].(float64) != 5.0 { + t.Errorf("Expected totalDuration 5.0, got %v", event["totalDuration"]) + } + } + + if !foundMetadata { + t.Error("Expected ParallelToolsMetadata event to be emitted") + } +} + +// TestToolCallResultWithDuration tests ToolCallResult emission with duration +func TestToolCallResultWithDuration(t *testing.T) { + var stdout bytes.Buffer + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := 
&HeadlessHandler{ + stdout: &stdout, + ctx: ctx, + } + + callID := "call_123" + h.trackToolExecution(callID, "Bash") + + completeEvent := domain.ToolExecutionProgressEvent{ + ToolCallID: callID, + ToolName: "Bash", + Status: "complete", + Message: "Command executed successfully", + } + + events := make(chan domain.ChatEvent, 1) + events <- completeEvent + close(events) + + err := h.processStreamingEvents(events) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + scanner := bufio.NewScanner(&stdout) + scanner.Split(bufio.ScanLines) + + var foundResult bool + for scanner.Scan() { + line := scanner.Bytes() + var event map[string]any + if err := json.Unmarshal(line, &event); err != nil { + continue + } + + if event["type"] == "ToolCallResult" { + foundResult = true + if event["status"] != "complete" { + t.Errorf("Expected status 'complete', got %v", event["status"]) + } + if _, ok := event["duration"]; !ok { + t.Error("Expected duration field to be present") + } + if event["duration"].(float64) < 0 { + t.Errorf("Expected non-negative duration, got %v", event["duration"]) + } + } + } + + if !foundResult { + t.Error("Expected ToolCallResult event to be emitted") + } + + if h.getToolState(callID) != nil { + t.Error("Expected tool state to be removed after completion") + } +} diff --git a/internal/handlers/session_manager.go b/internal/handlers/session_manager.go new file mode 100644 index 00000000..69a5effb --- /dev/null +++ b/internal/handlers/session_manager.go @@ -0,0 +1,314 @@ +package handlers + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "sync" + "time" + + "github.com/inference-gateway/cli/internal/logger" +) + +// SessionManager manages headless CLI sessions +type SessionManager struct { + sessions map[string]*Session + mu sync.RWMutex +} + +// Session represents a headless CLI session +type Session struct { + ID string + ConversationID string + cmd *exec.Cmd + stdin io.WriteCloser + stdout io.ReadCloser + 
stderr io.ReadCloser + cancel context.CancelFunc + clients map[string]chan []byte + mu sync.RWMutex + wg sync.WaitGroup +} + +// NewSessionManager creates a new session manager +func NewSessionManager() *SessionManager { + return &SessionManager{ + sessions: make(map[string]*Session), + } +} + +// CreateSession spawns a new headless CLI process +func (sm *SessionManager) CreateSession(ctx context.Context, sessionID string, conversationID string) (*Session, error) { + sm.mu.Lock() + defer sm.mu.Unlock() + + if _, exists := sm.sessions[sessionID]; exists { + return nil, fmt.Errorf("session %s already exists", sessionID) + } + + sessionCtx, cancel := context.WithCancel(ctx) + + execPath, err := os.Executable() + if err != nil { + cancel() + return nil, fmt.Errorf("failed to get executable path: %w", err) + } + + args := []string{"chat", "--headless", "--session-id", sessionID} + if conversationID != "" { + args = append(args, "--conversation-id", conversationID) + } + + cmd := exec.CommandContext(sessionCtx, execPath, args...) 
+ + stdin, err := cmd.StdinPipe() + if err != nil { + cancel() + return nil, fmt.Errorf("failed to create stdin pipe: %w", err) + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + cancel() + return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + } + + stderr, err := cmd.StderrPipe() + if err != nil { + cancel() + return nil, fmt.Errorf("failed to create stderr pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + cancel() + return nil, fmt.Errorf("failed to start headless process: %w", err) + } + + session := &Session{ + ID: sessionID, + ConversationID: conversationID, + cmd: cmd, + stdin: stdin, + stdout: stdout, + stderr: stderr, + cancel: cancel, + clients: make(map[string]chan []byte), + } + + sm.sessions[sessionID] = session + + session.wg.Add(2) + go session.readOutput() + go session.readStderr() + + return session, nil +} + +// GetSession retrieves an existing session +func (sm *SessionManager) GetSession(sessionID string) (*Session, bool) { + sm.mu.RLock() + defer sm.mu.RUnlock() + session, exists := sm.sessions[sessionID] + return session, exists +} + +// CloseSession terminates a session +func (sm *SessionManager) CloseSession(sessionID string) error { + sm.mu.Lock() + session, exists := sm.sessions[sessionID] + if !exists { + sm.mu.Unlock() + return fmt.Errorf("session %s not found", sessionID) + } + + delete(sm.sessions, sessionID) + sm.mu.Unlock() + + logger.Info("🛑 SESSION CLOSING", "session_id", sessionID, "conversation_id", session.ConversationID, "pid", session.cmd.Process.Pid) + + shutdownInput := map[string]any{ + "type": "shutdown", + } + if data, err := json.Marshal(shutdownInput); err == nil { + data = append(data, '\n') + + session.mu.Lock() + _, writeErr := session.stdin.Write(data) + session.mu.Unlock() + + if writeErr != nil { + logger.Warn("Failed to send shutdown signal", "session_id", sessionID, "error", writeErr) + } else { + logger.Info("Sent shutdown signal to session", "session_id", sessionID) + } + } + + 
done := make(chan error, 1) + go func() { + done <- session.cmd.Wait() + }() + + select { + case err := <-done: + if err != nil { + logger.Info("Headless process exited with error", "session_id", sessionID, "error", err) + } else { + logger.Info("Headless process exited cleanly", "session_id", sessionID) + } + case <-time.After(5 * time.Second): + logger.Warn("Headless process did not exit gracefully, forcing termination", "session_id", sessionID) + session.cancel() + <-done + } + + if err := session.stdout.Close(); err != nil { + logger.Warn("Failed to close stdout", "error", err) + } + if err := session.stderr.Close(); err != nil { + logger.Warn("Failed to close stderr", "error", err) + } + + session.wg.Wait() + + if err := session.stdin.Close(); err != nil { + logger.Warn("Failed to close stdin", "error", err) + } + session.mu.Lock() + for _, ch := range session.clients { + close(ch) + } + session.mu.Unlock() + + logger.Info("Session cleanup completed", "session_id", sessionID) + return nil +} + +// readOutput continuously reads from the CLI stdout and broadcasts to clients +func (s *Session) readOutput() { + defer s.wg.Done() + scanner := bufio.NewScanner(s.stdout) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) // 1MB initial, 10MB max + + for scanner.Scan() { + line := append([]byte(nil), scanner.Bytes()...) // copy: Scanner reuses its buffer on the next Scan, and the slice outlives this iteration in client channels + if len(line) == 0 { + continue + } + + s.mu.RLock() + for clientID, ch := range s.clients { + select { + case ch <- line: + default: + logger.Warn("Client channel full, skipping", "session_id", s.ID, "client_id", clientID) + } + } + s.mu.RUnlock() + } + + if err := scanner.Err(); err != nil { + logger.Error("Error reading session output", "session_id", s.ID, "error", err) + } +} + +// readStderr continuously reads from the CLI stderr (for debugging) +func (s *Session) readStderr() { + defer s.wg.Done() + scanner := bufio.NewScanner(s.stderr) + scanner.Buffer(make([]byte, 1024*1024), 10*1024*1024) // 1MB initial, 10MB max + + for scanner.Scan() { + } + + if err := 
scanner.Err(); err != nil { + logger.Error("Error reading session stderr", "session_id", s.ID, "error", err) + } +} + +// Subscribe adds a client to receive session output +func (s *Session) Subscribe(clientID string) chan []byte { + s.mu.Lock() + defer s.mu.Unlock() + + ch := make(chan []byte, 100) + s.clients[clientID] = ch + return ch +} + +// Unsubscribe removes a client +func (s *Session) Unsubscribe(clientID string) { + s.mu.Lock() + defer s.mu.Unlock() + + ch, exists := s.clients[clientID] + if !exists { + return + } + + delete(s.clients, clientID) + + func() { + defer func() { + if r := recover(); r != nil { + logger.Debug("Channel already closed during unsubscribe", "session_id", s.ID, "client_id", clientID) + } + }() + close(ch) + }() +} + +// SendMessage sends a message to the headless CLI +func (s *Session) SendMessage(content string, images []any, model string) error { + input := map[string]any{ + "type": "message", + "content": content, + "images": images, + } + + if model != "" { + input["model"] = model + } + + data, err := json.Marshal(input) + if err != nil { + return fmt.Errorf("failed to marshal input: %w", err) + } + + data = append(data, '\n') + + s.mu.Lock() + defer s.mu.Unlock() + + if _, err := s.stdin.Write(data); err != nil { + return fmt.Errorf("failed to write to stdin: %w", err) + } + + return nil +} + +// SendInterrupt sends an interrupt signal to the headless CLI +func (s *Session) SendInterrupt() error { + input := map[string]any{ + "type": "interrupt", + } + + data, err := json.Marshal(input) + if err != nil { + return fmt.Errorf("failed to marshal input: %w", err) + } + + data = append(data, '\n') + + s.mu.Lock() + defer s.mu.Unlock() + + if _, err := s.stdin.Write(data); err != nil { + return fmt.Errorf("failed to write to stdin: %w", err) + } + + return nil +} diff --git a/internal/handlers/websocket_handler.go b/internal/handlers/websocket_handler.go new file mode 100644 index 00000000..c320aa4f --- /dev/null +++ 
b/internal/handlers/websocket_handler.go @@ -0,0 +1,284 @@ +package handlers + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/inference-gateway/cli/internal/logger" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +// WebSocketHandler handles WebSocket connections for live chat +type WebSocketHandler struct { + sessionManager *SessionManager +} + +// NewWebSocketHandler creates a new WebSocket handler +func NewWebSocketHandler(sessionManager *SessionManager) *WebSocketHandler { + return &WebSocketHandler{ + sessionManager: sessionManager, + } +} + +// HandleWebSocket handles WebSocket upgrade and communication +func (h *WebSocketHandler) HandleWebSocket(w http.ResponseWriter, r *http.Request) { + logger.Info("WebSocket handler called", "method", r.Method, "path", r.URL.Path, "headers", r.Header) + + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.Error("Failed to upgrade to WebSocket", "error", err) + return + } + defer h.closeConnection(conn) + + clientID := uuid.New().String() + logger.Info("WebSocket client connected", "client_id", clientID) + + var sessionID string + var session *Session + + defer func() { h.cleanupSession(session, sessionID, clientID) }() // closure: session/sessionID are set later by messageLoop, so read them at exit rather than at defer time + + h.messageLoop(conn, r, &session, &sessionID, clientID) +} + +func (h *WebSocketHandler) closeConnection(conn *websocket.Conn) { + if err := conn.Close(); err != nil { + logger.Warn("Failed to close WebSocket connection", "error", err) + } +} + +func (h *WebSocketHandler) cleanupSession(session *Session, sessionID, clientID string) { + if session == nil { + return + } + + session.Unsubscribe(clientID) + logger.Info("WebSocket client disconnected, unsubscribed from session", "client_id", clientID, "session_id", sessionID) + + session.mu.RLock() + clientCount := len(session.clients) + session.mu.RUnlock() + + 
if clientCount == 0 { + logger.Info("Last client disconnected, closing session", "session_id", sessionID) + if err := h.sessionManager.CloseSession(sessionID); err != nil { + logger.Warn("Failed to close session", "session_id", sessionID, "error", err) + } + } +} + +func (h *WebSocketHandler) messageLoop(conn *websocket.Conn, r *http.Request, session **Session, sessionID *string, clientID string) { + for { + var msg WSMessage + if err := conn.ReadJSON(&msg); err != nil { + logger.Error("Failed to read WebSocket message", "error", err) + return + } + + shouldReturn := h.handleMessage(conn, r, msg, session, sessionID, clientID) + if shouldReturn { + return + } + } +} + +func (h *WebSocketHandler) handleMessage(conn *websocket.Conn, r *http.Request, msg WSMessage, session **Session, sessionID *string, clientID string) bool { + switch msg.Type { + case "create_session": + return h.handleCreateSession(conn, r, msg, session, sessionID, clientID) + case "join_session": + return h.handleJoinSession(conn, msg, session, sessionID, clientID) + case "message": + return h.handleChatMessage(conn, msg, *session) + case "interrupt": + return h.handleInterruptMessage(conn, *session) + case "close_session": + return h.handleCloseSession(*sessionID) + default: + h.sendError(conn, fmt.Sprintf("Unknown message type: %s", msg.Type)) + return false + } +} + +func (h *WebSocketHandler) handleCreateSession(conn *websocket.Conn, r *http.Request, msg WSMessage, session **Session, sessionID *string, clientID string) bool { + requestedSessionID := msg.SessionID + if requestedSessionID == "" { + requestedSessionID = uuid.New().String() + } + + conversationID := msg.ConversationID + + logger.Info("=== CREATE_SESSION REQUEST ===", + "requested_session_id", requestedSessionID, + "conversation_id", conversationID, + "client_id", clientID) + + existingSession, exists := h.sessionManager.GetSession(requestedSessionID) + canReuse := exists && (existingSession.ConversationID == conversationID || + 
(existingSession.ConversationID == "" && conversationID == "")) + + if canReuse { + *session = h.reuseExistingSession(existingSession, requestedSessionID, conversationID) + *sessionID = requestedSessionID + } else { + newSession, err := h.createNewSession(r, requestedSessionID, conversationID, exists) + if err != nil { + h.sendError(conn, fmt.Sprintf("Failed to create session: %v", err)) + return true + } + *session = newSession + *sessionID = requestedSessionID + } + + outputChan := (*session).Subscribe(clientID) + + h.sendMessage(conn, WSMessage{ + Type: "session_created", + SessionID: *sessionID, + ConversationID: (*session).ConversationID, + }) + + go h.forwardOutput(conn, outputChan) + return false +} + +func (h *WebSocketHandler) reuseExistingSession(existingSession *Session, requestedSessionID, conversationID string) *Session { + logger.Info("✅ Reusing existing session with same conversation", + "session_id", requestedSessionID, + "conversation_id", conversationID) + return existingSession +} + +func (h *WebSocketHandler) createNewSession(r *http.Request, requestedSessionID, conversationID string, exists bool) (*Session, error) { + if exists { + if err := h.sessionManager.CloseSession(requestedSessionID); err != nil { + logger.Warn("Failed to close existing session", "session_id", requestedSessionID, "error", err) + } + } + + return h.sessionManager.CreateSession(r.Context(), requestedSessionID, conversationID) +} + +func (h *WebSocketHandler) handleJoinSession(conn *websocket.Conn, msg WSMessage, session **Session, sessionID *string, clientID string) bool { + if msg.SessionID == "" { + h.sendError(conn, "Session ID required") + return false + } + + var exists bool + *session, exists = h.sessionManager.GetSession(msg.SessionID) + if !exists { + h.sendError(conn, fmt.Sprintf("Session %s not found", msg.SessionID)) + return false + } + + *sessionID = msg.SessionID + + outputChan := (*session).Subscribe(clientID) + + h.sendMessage(conn, WSMessage{ + Type: 
"session_joined", + SessionID: *sessionID, + }) + + go h.forwardOutput(conn, outputChan) + return false +} + +func (h *WebSocketHandler) handleChatMessage(conn *websocket.Conn, msg WSMessage, session *Session) bool { + if session == nil { + h.sendError(conn, "No active session") + return false + } + + if err := session.SendMessage(msg.Content, msg.Images, msg.Model); err != nil { + h.sendError(conn, fmt.Sprintf("Failed to send message: %v", err)) + } + return false +} + +func (h *WebSocketHandler) handleInterruptMessage(conn *websocket.Conn, session *Session) bool { + if session == nil { + h.sendError(conn, "No active session") + return false + } + + if err := session.SendInterrupt(); err != nil { + h.sendError(conn, fmt.Sprintf("Failed to send interrupt: %v", err)) + } + return false +} + +func (h *WebSocketHandler) handleCloseSession(sessionID string) bool { + if sessionID != "" { + if err := h.sessionManager.CloseSession(sessionID); err != nil { + logger.Warn("Failed to close session", "session_id", sessionID, "error", err) + } + } + return true +} + +// forwardOutput forwards session output to WebSocket client +func (h *WebSocketHandler) forwardOutput(conn *websocket.Conn, outputChan chan []byte) { + for data := range outputChan { + var jsonTest map[string]any + if err := json.Unmarshal(data, &jsonTest); err != nil { + wrapped := WSMessage{ + Type: "output", + Content: string(data), + Time: time.Now().UTC().Format(time.RFC3339), + } + if err := conn.WriteJSON(wrapped); err != nil { + logger.Error("Failed to write wrapped message to WebSocket", "error", err) + return + } + } else { + if err := conn.WriteMessage(websocket.TextMessage, data); err != nil { + logger.Error("Failed to write to WebSocket", "error", err) + return + } + } + } +} + +// sendMessage sends a message to the WebSocket client +func (h *WebSocketHandler) sendMessage(conn *websocket.Conn, msg WSMessage) { + if err := conn.WriteJSON(msg); err != nil { + logger.Error("Failed to send WebSocket 
message", "error", err) + } +} + +// sendError sends an error message to the client +func (h *WebSocketHandler) sendError(conn *websocket.Conn, errMsg string) { + msg := WSMessage{ + Type: "error", + Error: errMsg, + Time: time.Now().UTC().Format(time.RFC3339), + } + h.sendMessage(conn, msg) +} + +// WSMessage represents a WebSocket message +type WSMessage struct { + Type string `json:"type"` + SessionID string `json:"session_id,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + Content string `json:"content,omitempty"` + Images []any `json:"images,omitempty"` + Model string `json:"model,omitempty"` + Error string `json:"error,omitempty"` + Time string `json:"time,omitempty"` + Data any `json:"data,omitempty"` +} diff --git a/internal/infra/storage/jsonl.go b/internal/infra/storage/jsonl.go index 9e94c4c2..c8b95b37 100644 --- a/internal/infra/storage/jsonl.go +++ b/internal/infra/storage/jsonl.go @@ -53,12 +53,12 @@ func (s *JsonlStorage) conversationFilePath(conversationID string) string { // saveConversationUnlocked saves a conversation without acquiring the lock // Caller must hold the lock before calling this method func (s *JsonlStorage) saveConversationUnlocked(ctx context.Context, conversationID string, entries []domain.ConversationEntry, metadata ConversationMetadata) error { - metadataJSON, err := json.Marshal(map[string]interface{}{"metadata": metadata}) + metadataJSON, err := json.Marshal(map[string]any{"metadata": metadata}) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } - entriesJSON, err := json.Marshal(map[string]interface{}{"entries": entries}) + entriesJSON, err := json.Marshal(map[string]any{"entries": entries}) if err != nil { return fmt.Errorf("failed to marshal entries: %w", err) } @@ -309,12 +309,12 @@ func (s *JsonlStorage) UpdateConversationMetadata(ctx context.Context, conversat return fmt.Errorf("failed to unmarshal entries: %w", err) } - metadataJSON, err := 
json.Marshal(map[string]interface{}{"metadata": metadata}) + metadataJSON, err := json.Marshal(map[string]any{"metadata": metadata}) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } - entriesJSON, err := json.Marshal(map[string]interface{}{"entries": entriesWrapper.Entries}) + entriesJSON, err := json.Marshal(map[string]any{"entries": entriesWrapper.Entries}) if err != nil { return fmt.Errorf("failed to marshal entries: %w", err) } diff --git a/internal/services/agent.go b/internal/services/agent.go index 8d3104cb..8fc78972 100644 --- a/internal/services/agent.go +++ b/internal/services/agent.go @@ -106,9 +106,10 @@ func (p *eventPublisher) publishParallelToolsStart(toolCalls []sdk.ChatCompletio tools := make([]domain.ToolInfo, len(toolCalls)) for i, tc := range toolCalls { tools[i] = domain.ToolInfo{ - CallID: tc.Id, - Name: tc.Function.Name, - Status: "queued", + CallID: tc.Id, + Name: tc.Function.Name, + Status: "queued", + Arguments: tc.Function.Arguments, } } @@ -139,6 +140,23 @@ func (p *eventPublisher) publishToolStatusChange(callID string, toolName string, p.chatEvents <- event } +// publishToolStatusChangeWithResult publishes a ToolExecutionProgressEvent with formatted result +func (p *eventPublisher) publishToolStatusChangeWithResult(callID string, toolName string, status string, message string, result string) { + event := domain.ToolExecutionProgressEvent{ + BaseChatEvent: domain.BaseChatEvent{ + RequestID: p.requestID, + Timestamp: time.Now(), + }, + ToolCallID: callID, + ToolName: toolName, + Status: status, + Message: message, + Result: result, + } + + p.chatEvents <- event +} + // publishBashOutputChunk publishes a BashOutputChunkEvent for streaming bash output func (p *eventPublisher) publishBashOutputChunk(callID string, output string, isComplete bool) { event := domain.BashOutputChunkEvent{ @@ -847,19 +865,12 @@ func (s *AgentServiceImpl) executeToolCallsParallel( } for _, at := range approvalTools { - 
time.Sleep(constants.AgentToolExecutionDelay) result := s.executeTool(ctx, *at.tool, eventPublisher, isChatMode) + status, message, formattedResult := extractToolResultStatus(result) - status := "complete" - message := "Completed successfully" - if result.ToolExecution != nil && !result.ToolExecution.Success { - status = "failed" - message = "Execution failed" - } - - eventPublisher.publishToolStatusChange(at.tool.Id, at.tool.Function.Name, status, message) + eventPublisher.publishToolStatusChangeWithResult(at.tool.Id, at.tool.Function.Name, status, message, formattedResult) results[at.index] = result } @@ -889,15 +900,9 @@ func (s *AgentServiceImpl) executeToolCallsParallel( time.Sleep(constants.AgentToolExecutionDelay) result := s.executeTool(ctx, *toolCall, eventPublisher, isChatMode) + status, message, formattedResult := extractToolResultStatus(result) - status := "complete" - message := "Completed successfully" - if result.ToolExecution != nil && !result.ToolExecution.Success { - status = "failed" - message = "Execution failed" - } - - eventPublisher.publishToolStatusChange(toolCall.Id, toolCall.Function.Name, status, message) + eventPublisher.publishToolStatusChangeWithResult(toolCall.Id, toolCall.Function.Name, status, message, formattedResult) resultsChan <- IndexedToolResult{ Index: index, @@ -937,6 +942,29 @@ func (s *AgentServiceImpl) executeToolCallsParallel( return results } +func extractToolResultStatus(result domain.ConversationEntry) (status, message, formattedResult string) { + status = "complete" + message = "Completed successfully" + formattedResult = "" + + if result.ToolExecution == nil { + return status, message, formattedResult + } + + if !result.ToolExecution.Success { + status = "failed" + message = "Execution failed" + } + + if result.ToolExecution.Data != nil { + if jsonData, err := json.Marshal(result.ToolExecution.Data); err == nil { + formattedResult = string(jsonData) + } + } + + return status, message, formattedResult +} + 
//nolint:funlen,gocyclo,cyclop // Tool execution requires comprehensive error handling and status updates func (s *AgentServiceImpl) executeTool( ctx context.Context, diff --git a/internal/services/agent_manager.go b/internal/services/agent_manager.go index 81838055..e4ca6de7 100644 --- a/internal/services/agent_manager.go +++ b/internal/services/agent_manager.go @@ -72,6 +72,12 @@ func (am *AgentManager) StartAgents(ctx context.Context) error { } } + if len(agentsToStart) > 0 && am.containerRuntime != nil { + if err := am.containerRuntime.EnsureNetwork(ctx); err != nil { + logger.Warn("Failed to create container network", "session", am.sessionID, "error", err) + } + } + for _, agent := range agentsToStart { go am.startAgentAsync(ctx, agent) } diff --git a/internal/services/conversation.go b/internal/services/conversation.go index 0ab3bafd..5ba3902b 100644 --- a/internal/services/conversation.go +++ b/internal/services/conversation.go @@ -183,6 +183,11 @@ func (r *InMemoryConversationRepository) StartNewConversation(title string) erro return r.Clear() } +// GetCurrentConversationID returns empty string for in-memory repository +func (r *InMemoryConversationRepository) GetCurrentConversationID() string { + return "" +} + func (r *InMemoryConversationRepository) ClearExceptFirstUserMessage() error { r.mutex.Lock() defer r.mutex.Unlock() diff --git a/internal/services/ui_manager.go b/internal/services/ui_manager.go new file mode 100644 index 00000000..515f43fd --- /dev/null +++ b/internal/services/ui_manager.go @@ -0,0 +1,186 @@ +package services + +import ( + "context" + "fmt" + "net/http" + "os" + "os/exec" + "path/filepath" + "time" + + config "github.com/inference-gateway/cli/config" + logger "github.com/inference-gateway/cli/internal/logger" +) + +// UIManager manages the lifecycle of the web UI server +type UIManager struct { + config *config.Config + cmd *exec.Cmd + isRunning bool +} + +// NewUIManager creates a new UI manager +func NewUIManager(cfg 
*config.Config) *UIManager { + return &UIManager{ + config: cfg, + } +} + +// Start starts the UI server process +func (um *UIManager) Start(ctx context.Context) error { + if um.isRunning { + return nil + } + + switch um.config.API.UI.Mode { + case "npm": + return um.startNPM(ctx) + case "docker": + return fmt.Errorf("docker mode not yet implemented") + default: + return fmt.Errorf("unsupported UI mode: %s", um.config.API.UI.Mode) + } +} + +// startNPM starts the UI using npm dev server +func (um *UIManager) startNPM(ctx context.Context) error { + logger.Info("Starting UI development server") + + workingDir := um.config.API.UI.WorkingDir + if !filepath.IsAbs(workingDir) { + cwd, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get current directory: %w", err) + } + workingDir = filepath.Join(cwd, workingDir) + } + + if _, err := os.Stat(workingDir); os.IsNotExist(err) { + return fmt.Errorf("UI directory not found: %s", workingDir) + } + + packageJSON := filepath.Join(workingDir, "package.json") + if _, err := os.Stat(packageJSON); os.IsNotExist(err) { + return fmt.Errorf("package.json not found in UI directory: %s", workingDir) + } + + nodeModules := filepath.Join(workingDir, "node_modules") + if _, err := os.Stat(nodeModules); os.IsNotExist(err) { + logger.Info("Installing UI dependencies") + fmt.Println("Installing UI dependencies (this may take a moment)...") + installCmd := exec.Command("npm", "install") + installCmd.Dir = workingDir + installCmd.Stdout = os.Stdout + installCmd.Stderr = os.Stderr + if err := installCmd.Run(); err != nil { + return fmt.Errorf("failed to install UI dependencies: %w", err) + } + } + + fmt.Println("Starting UI development server...") + + um.cmd = exec.Command("npm", "run", "dev") + um.cmd.Dir = workingDir + + apiURL := fmt.Sprintf("http://localhost:%d", um.config.API.Port) + um.cmd.Env = append(os.Environ(), + fmt.Sprintf("PORT=%d", um.config.API.UI.Port), + fmt.Sprintf("NEXT_PUBLIC_API_URL=%s", apiURL), + ) + + 
logger.Info("UI environment configured", + "ui_port", um.config.API.UI.Port, + "api_url", apiURL, + ) + + if um.config.Gateway.Debug { + um.cmd.Stdout = os.Stdout + um.cmd.Stderr = os.Stderr + } + + if err := um.cmd.Start(); err != nil { + return fmt.Errorf("failed to start UI server: %w", err) + } + + fmt.Println("Waiting for UI server to become ready...") + + if err := um.waitForReady(ctx); err != nil { + if stopErr := um.Stop(); stopErr != nil { + logger.Warn("Failed to stop UI server during error cleanup", "error", stopErr) + } + return fmt.Errorf("UI server failed to become ready: %w", err) + } + + um.isRunning = true + fmt.Printf("UI server is ready at %s\n\n", um.GetURL()) + logger.Info("UI server started successfully", "port", um.config.API.UI.Port) + return nil +} + +// waitForReady waits for the UI server to become ready by polling the root URL +func (um *UIManager) waitForReady(ctx context.Context) error { + url := fmt.Sprintf("http://%s:%d", um.config.API.Host, um.config.API.UI.Port) + + timeout := 60 * time.Second + deadline := time.Now().Add(timeout) + ticker := time.NewTicker(500 * time.Millisecond) + defer ticker.Stop() + + client := &http.Client{ + Timeout: 2 * time.Second, + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + if time.Now().After(deadline) { + return fmt.Errorf("timeout waiting for UI server to become ready") + } + + resp, err := client.Get(url) + if err == nil { + if closeErr := resp.Body.Close(); closeErr != nil { + logger.Warn("Failed to close response body", "error", closeErr) + } + if resp.StatusCode < 500 { + return nil + } + } + } + } +} + +// Stop kills the UI server process if one was started. It is a no-op when no +// process handle exists, so it is safe to call repeatedly and before the server +// has become ready (startNPM uses it to clean up after a failed readiness wait). +// It returns the error from Process.Kill, if any. +func (um *UIManager) Stop() error { + // Gate on the process handle rather than isRunning: isRunning is only set + // after waitForReady succeeds, so checking it here would leave a started but + // not-yet-ready process running when startNPM calls Stop for cleanup. + if um.cmd == nil || um.cmd.Process == nil { + return nil + } + + logger.Info("Stopping UI server", "pid", um.cmd.Process.Pid) + + if err := um.cmd.Process.Kill(); err != nil { + logger.Warn("Failed to kill UI server process", "error", err) +
return err + } + + // Drop the handle so a later Stop does not try to kill the same process again. + um.cmd = nil + um.isRunning = false + logger.Info("UI server stopped successfully") + return nil +} + +// IsRunning returns whether the UI server is running +func (um *UIManager) IsRunning() bool { + return um.isRunning +} + +// GetURL returns the URL where the UI is accessible +func (um *UIManager) GetURL() string { + return fmt.Sprintf("http://localhost:%d", um.config.API.UI.Port) +} diff --git a/internal/utils/browser.go b/internal/utils/browser.go new file mode 100644 index 00000000..5c95ec27 --- /dev/null +++ b/internal/utils/browser.go @@ -0,0 +1,23 @@ +package utils + +import ( + "fmt" + "os/exec" + "runtime" +) + +// OpenBrowser opens a URL in the default browser (cross-platform) +func OpenBrowser(url string) error { + var cmd *exec.Cmd + + switch runtime.GOOS { + case "darwin": + cmd = exec.Command("open", url) + case "linux": + cmd = exec.Command("xdg-open", url) + default: + return fmt.Errorf("unsupported platform: %s", runtime.GOOS) + } + + return cmd.Start() +} diff --git a/tests/mocks/domain/fake_conversation_repository.go b/tests/mocks/domain/fake_conversation_repository.go index d2f89814..5d2de2a2 100644 --- a/tests/mocks/domain/fake_conversation_repository.go +++ b/tests/mocks/domain/fake_conversation_repository.go @@ -113,6 +113,16 @@ type FakeConversationRepository struct { formatToolResultForUIReturnsOnCall map[int]struct { result1 string } + GetCurrentConversationIDStub func() string + getCurrentConversationIDMutex sync.RWMutex + getCurrentConversationIDArgsForCall []struct { + } + getCurrentConversationIDReturns struct { + result1 string + } + getCurrentConversationIDReturnsOnCall map[int]struct { + result1 string + } GetMessageCountStub func() int getMessageCountMutex sync.RWMutex getMessageCountArgsForCall []struct { @@ -736,6 +746,59 @@ func (fake *FakeConversationRepository) FormatToolResultForUIReturnsOnCall(i int }{result1} } +func (fake *FakeConversationRepository) GetCurrentConversationID() string { +
fake.getCurrentConversationIDMutex.Lock() + ret, specificReturn := fake.getCurrentConversationIDReturnsOnCall[len(fake.getCurrentConversationIDArgsForCall)] + fake.getCurrentConversationIDArgsForCall = append(fake.getCurrentConversationIDArgsForCall, struct { + }{}) + stub := fake.GetCurrentConversationIDStub + fakeReturns := fake.getCurrentConversationIDReturns + fake.recordInvocation("GetCurrentConversationID", []interface{}{}) + fake.getCurrentConversationIDMutex.Unlock() + if stub != nil { + return stub() + } + if specificReturn { + return ret.result1 + } + return fakeReturns.result1 +} + +func (fake *FakeConversationRepository) GetCurrentConversationIDCallCount() int { + fake.getCurrentConversationIDMutex.RLock() + defer fake.getCurrentConversationIDMutex.RUnlock() + return len(fake.getCurrentConversationIDArgsForCall) +} + +func (fake *FakeConversationRepository) GetCurrentConversationIDCalls(stub func() string) { + fake.getCurrentConversationIDMutex.Lock() + defer fake.getCurrentConversationIDMutex.Unlock() + fake.GetCurrentConversationIDStub = stub +} + +func (fake *FakeConversationRepository) GetCurrentConversationIDReturns(result1 string) { + fake.getCurrentConversationIDMutex.Lock() + defer fake.getCurrentConversationIDMutex.Unlock() + fake.GetCurrentConversationIDStub = nil + fake.getCurrentConversationIDReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeConversationRepository) GetCurrentConversationIDReturnsOnCall(i int, result1 string) { + fake.getCurrentConversationIDMutex.Lock() + defer fake.getCurrentConversationIDMutex.Unlock() + fake.GetCurrentConversationIDStub = nil + if fake.getCurrentConversationIDReturnsOnCall == nil { + fake.getCurrentConversationIDReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.getCurrentConversationIDReturnsOnCall[i] = struct { + result1 string + }{result1} +} + func (fake *FakeConversationRepository) GetMessageCount() int { fake.getMessageCountMutex.Lock() ret, specificReturn := 
fake.getMessageCountReturnsOnCall[len(fake.getMessageCountArgsForCall)] diff --git a/ui/.gitignore b/ui/.gitignore new file mode 100644 index 00000000..b90a368f --- /dev/null +++ b/ui/.gitignore @@ -0,0 +1,2 @@ +node_modules +.next diff --git a/ui/README.md b/ui/README.md new file mode 100644 index 00000000..adc197b9 --- /dev/null +++ b/ui/README.md @@ -0,0 +1,289 @@ +# Inference Gateway UI + +Web interface for the Inference Gateway CLI, featuring direct database access and headless CLI session management using the AG-UI protocol. + +## Features + +- **AG-UI Protocol Compliant**: Headless CLI sessions via stdin/stdout using the [AG-UI Protocol](https://docs.ag-ui.com) +- **Direct Database Access**: TypeScript storage client mirrors CLI's Go storage layer +- **Multiple Storage Backends**: PostgreSQL, SQLite, Redis, JSONL, In-Memory +- **Concurrent Chat Sessions**: Spawn and manage multiple headless CLI processes +- **Real-time Streaming**: Live streaming responses from LLMs +- **Multimodal Support**: Text and image inputs +- **Conversation Management**: Browse, search, and export conversations +- **Analytics Dashboard**: Token usage and cost statistics + +## Architecture + +```text +┌─────────────────────────────────────────────────────────┐ +│ UI (Next.js) │ +│ ┌────────────────────┐ ┌─────────────────────┐ │ +│ │ TypeScript Storage │ │ CLI Process Mgr │ │ +│ │ Client (PG/SQLite/ │ │ (spawn headless │ │ +│ │ Redis/JSONL) │ │ CLI sessions) │ │ +│ └──────────┬─────────┘ └──────────┬──────────┘ │ +└─────────────┼────────────────────────────┼──────────────┘ + │ │ + │ Direct DB Access │ AG-UI Protocol + │ │ (stdin/stdout JSON) +┌─────────────▼────────────────────────────▼──────────────┐ +│ Shared Database │ +│ (PostgreSQL / SQLite / Redis / JSONL) │ +└──────────────────────────┬───────────────────────────────┘ + │ + ┌────────────┴─────────────┐ + │ │ +┌─────────────▼──────────┐ ┌───────────▼──────────┐ +│ CLI Session 1 │ │ CLI Session 2 │ +│ (headless, UI-spawned) │ │ 
(headless, UI-spawned)│ +└────────────────────────┘ └───────────────────────┘ +``` + +## Prerequisites + +- Node.js 20 or later +- Inference Gateway CLI installed (`infer` binary in PATH or specified location) +- Database setup (PostgreSQL, SQLite, Redis, or JSONL directory) + +## Installation + +```bash +cd ui +npm install +``` + +## Configuration + +Create a `.env.local` file: + +> **Note:** Next.js inlines every `NEXT_PUBLIC_*` variable into the browser bundle, so the values below (including the database password) are visible to anyone who can load the UI. Use real credentials only on a trusted local machine. + +```env +# Storage Configuration (choose one) +NEXT_PUBLIC_STORAGE_TYPE=postgres + +# PostgreSQL +NEXT_PUBLIC_STORAGE_POSTGRES_HOST=localhost +NEXT_PUBLIC_STORAGE_POSTGRES_PORT=5432 +NEXT_PUBLIC_STORAGE_POSTGRES_DB=infer +NEXT_PUBLIC_STORAGE_POSTGRES_USER=postgres +NEXT_PUBLIC_STORAGE_POSTGRES_PASSWORD=password + +# SQLite +NEXT_PUBLIC_STORAGE_SQLITE_PATH=/Users/username/.infer/conversations.db + +# Redis +NEXT_PUBLIC_STORAGE_REDIS_HOST=localhost +NEXT_PUBLIC_STORAGE_REDIS_PORT=6379 +NEXT_PUBLIC_STORAGE_REDIS_DB=0 + +# JSONL +NEXT_PUBLIC_STORAGE_JSONL_DIR=/Users/username/.infer/conversations + +# CLI Path +NEXT_PUBLIC_CLI_PATH=/usr/local/bin/infer +``` + +## Development + +```bash +npm run dev +``` + +Open [http://localhost:3000](http://localhost:3000) + +## Production Build + +```bash +npm run build +npm run start +``` + +## Project Structure + +```text +ui/ +├── app/ # Next.js 16 app directory +│ ├── layout.tsx # Root layout +│ ├── page.tsx # Home page +│ ├── providers.tsx # React Query provider +│ ├── chat/ # Live chat page +│ ├── conversations/ # Conversation browser page +│ └── dashboard/ # Analytics dashboard page +├── components/ # React components +│ ├── chat/ # Chat UI components +│ ├── conversations/ # Conversation list/detail components +│ └── stats/ # Dashboard components +├── lib/ # Core libraries +│ ├── storage/ # TypeScript storage client +│ │ ├── interfaces.ts # Type definitions +│ │ ├── factory.ts # Storage factory +│ │ ├── hooks.ts # React Query hooks +│ │ ├── postgres/ # PostgreSQL implementation +│ │ ├── sqlite/ # SQLite implementation +│ │ ├── redis/ # Redis
implementation +│ │ ├── jsonl/ # JSONL implementation +│ │ └── memory/ # In-memory implementation +│ └── cli/ # CLI process manager +│ ├── process-manager.ts # Session management +│ ├── hooks.ts # React hooks +│ └── index.ts # Exports +└── package.json +``` + +## AG-UI Protocol + +The UI communicates with headless CLI sessions using the [AG-UI Protocol](https://docs.ag-ui.com): + +### Event Types (CLI → UI) + +**Run Lifecycle:** + +- `RunStarted`: Agent run begins +- `RunFinished`: Agent run completes +- `RunError`: Error during run + +**Text Messages:** + +- `TextMessageStart`: Start of streaming message +- `TextMessageContent`: Delta content chunk +- `TextMessageEnd`: End of streaming message + +**Tool Execution:** + +- `ToolCallStart`: Tool invocation begins (includes `status` and `metadata` fields) +- `ToolCallArgs`: Tool arguments +- `ToolCallEnd`: Tool call ready +- `ToolCallProgress`: Intermediate progress updates during execution (NEW) +- `ToolCallResult`: Tool execution result (includes `status`, `duration`, and `metadata` fields) +- `ParallelToolsMetadata`: Summary metadata for parallel tool execution (NEW) + +### Input Types (UI → CLI) + +- `message`: User message with optional images +- `interrupt`: Stop current processing +- `shutdown`: Graceful shutdown + +### Tool Execution Progress Events (NEW) + +The AG-UI protocol now includes real-time progress tracking for tool execution: + +**ToolCallStart** (Enhanced): + +```json +{ + "type": "ToolCallStart", + "toolCallId": "call_abc123", + "toolCallName": "Bash", + "parentMessageId": "msg_xyz789", + "status": "queued", + "timestamp": 1703001234567 +} +``` + +**ToolCallProgress** (New): + +```json +{ + "type": "ToolCallProgress", + "toolCallId": "call_abc123", + "status": "running", + "message": "Executing command...", + "output": "total 48\ndrwxr-xr-x 12 user staff 384 Dec 21 10:30 .\n", + "metadata": { + "isComplete": false + }, + "timestamp": 1703001235123 +} +``` + +**ToolCallResult** (Enhanced): + 
+```json +{ + "type": "ToolCallResult", + "messageId": "msg_result_456", + "toolCallId": "call_abc123", + "content": "Command executed successfully", + "role": "tool", + "status": "complete", + "duration": 2.34, + "timestamp": 1703001236567 +} +``` + +**ParallelToolsMetadata** (New): + +```json +{ + "type": "ParallelToolsMetadata", + "totalCount": 3, + "successCount": 2, + "failureCount": 1, + "totalDuration": 5.67, + "timestamp": 1703001240000 +} +``` + +**Status Values:** + +- `queued`: Tool is queued for execution +- `running`: Tool is actively executing +- `complete`: Tool execution completed successfully +- `failed`: Tool execution failed with error + +## Usage Examples + +### Starting a Chat Session + +```typescript +import { useCLISession } from "@/lib/cli/hooks"; + +function ChatPage() { + const { start, sendMessage, messages } = useCLISession(); + + useEffect(() => { + start(); + }, []); + + const handleSend = () => { + sendMessage({ content: "Hello!" }); + }; + + // Render chat UI +} +``` + +### Accessing Storage + +```typescript +import { useConversations } from "@/lib/storage/hooks"; + +function ConversationsList() { + const { data: conversations } = useConversations(50, 0); + + // Render conversations +} +``` + +## Extracting to Separate Repository + +This UI is currently embedded in the CLI repository for convenience. To extract: + +```bash +# From CLI repo root +cp -r ui /path/to/new/ui-repo +cd /path/to/new/ui-repo +npm install +``` + +Update paths if needed and ensure `.env.local` points to correct CLI binary and database. 
+ +## License + +Same as Inference Gateway CLI + +## Links + +- [AG-UI Protocol Documentation](https://docs.ag-ui.com) +- [Inference Gateway CLI](https://github.com/inference-gateway/cli) +- [Next.js 16 Documentation](https://nextjs.org/docs) diff --git a/ui/app/globals.css b/ui/app/globals.css new file mode 100644 index 00000000..9c74d1fc --- /dev/null +++ b/ui/app/globals.css @@ -0,0 +1,124 @@ +@import "tailwindcss"; +@plugin "tailwindcss-animate"; + +@custom-variant dark (&:is(.dark *)); + +@theme inline { + --radius-sm: calc(var(--radius) - 4px); + --radius-md: calc(var(--radius) - 2px); + --radius-lg: var(--radius); + --radius-xl: calc(var(--radius) + 4px); + --radius-2xl: calc(var(--radius) + 8px); + --radius-3xl: calc(var(--radius) + 12px); + --radius-4xl: calc(var(--radius) + 16px); + --color-background: var(--background); + --color-foreground: var(--foreground); + --color-card: var(--card); + --color-card-foreground: var(--card-foreground); + --color-popover: var(--popover); + --color-popover-foreground: var(--popover-foreground); + --color-primary: var(--primary); + --color-primary-foreground: var(--primary-foreground); + --color-secondary: var(--secondary); + --color-secondary-foreground: var(--secondary-foreground); + --color-muted: var(--muted); + --color-muted-foreground: var(--muted-foreground); + --color-accent: var(--accent); + --color-accent-foreground: var(--accent-foreground); + --color-destructive: var(--destructive); + --color-border: var(--border); + --color-input: var(--input); + --color-ring: var(--ring); + --color-chart-1: var(--chart-1); + --color-chart-2: var(--chart-2); + --color-chart-3: var(--chart-3); + --color-chart-4: var(--chart-4); + --color-chart-5: var(--chart-5); + --color-sidebar: var(--sidebar); + --color-sidebar-foreground: var(--sidebar-foreground); + --color-sidebar-primary: var(--sidebar-primary); + --color-sidebar-primary-foreground: var(--sidebar-primary-foreground); + --color-sidebar-accent: var(--sidebar-accent); + 
--color-sidebar-accent-foreground: var(--sidebar-accent-foreground); + --color-sidebar-border: var(--sidebar-border); + --color-sidebar-ring: var(--sidebar-ring); +} + +:root { + --radius: 0.625rem; + --background: oklch(1 0 0); + --foreground: oklch(0.145 0 0); + --card: oklch(1 0 0); + --card-foreground: oklch(0.145 0 0); + --popover: oklch(1 0 0); + --popover-foreground: oklch(0.145 0 0); + --primary: oklch(0.205 0 0); + --primary-foreground: oklch(0.985 0 0); + --secondary: oklch(0.97 0 0); + --secondary-foreground: oklch(0.205 0 0); + --muted: oklch(0.97 0 0); + --muted-foreground: oklch(0.556 0 0); + --accent: oklch(0.97 0 0); + --accent-foreground: oklch(0.205 0 0); + --destructive: oklch(0.577 0.245 27.325); + --border: oklch(0.922 0 0); + --input: oklch(0.922 0 0); + --ring: oklch(0.708 0 0); + --chart-1: oklch(0.646 0.222 41.116); + --chart-2: oklch(0.6 0.118 184.704); + --chart-3: oklch(0.398 0.07 227.392); + --chart-4: oklch(0.828 0.189 84.429); + --chart-5: oklch(0.769 0.188 70.08); + --sidebar: oklch(0.985 0 0); + --sidebar-foreground: oklch(0.145 0 0); + --sidebar-primary: oklch(0.205 0 0); + --sidebar-primary-foreground: oklch(0.985 0 0); + --sidebar-accent: oklch(0.97 0 0); + --sidebar-accent-foreground: oklch(0.205 0 0); + --sidebar-border: oklch(0.922 0 0); + --sidebar-ring: oklch(0.708 0 0); +} + +.dark { + --background: oklch(0.145 0 0); + --foreground: oklch(0.985 0 0); + --card: oklch(0.205 0 0); + --card-foreground: oklch(0.985 0 0); + --popover: oklch(0.205 0 0); + --popover-foreground: oklch(0.985 0 0); + --primary: oklch(0.922 0 0); + --primary-foreground: oklch(0.205 0 0); + --secondary: oklch(0.269 0 0); + --secondary-foreground: oklch(0.985 0 0); + --muted: oklch(0.269 0 0); + --muted-foreground: oklch(0.708 0 0); + --accent: oklch(0.269 0 0); + --accent-foreground: oklch(0.985 0 0); + --destructive: oklch(0.704 0.191 22.216); + --border: oklch(1 0 0 / 10%); + --input: oklch(1 0 0 / 15%); + --ring: oklch(0.556 0 0); + --chart-1: 
oklch(0.488 0.243 264.376); + --chart-2: oklch(0.696 0.17 162.48); + --chart-3: oklch(0.769 0.188 70.08); + --chart-4: oklch(0.627 0.265 303.9); + --chart-5: oklch(0.645 0.246 16.439); + --sidebar: oklch(0.205 0 0); + --sidebar-foreground: oklch(0.985 0 0); + --sidebar-primary: oklch(0.488 0.243 264.376); + --sidebar-primary-foreground: oklch(0.985 0 0); + --sidebar-accent: oklch(0.269 0 0); + --sidebar-accent-foreground: oklch(0.985 0 0); + --sidebar-border: oklch(1 0 0 / 10%); + --sidebar-ring: oklch(0.556 0 0); +} + +* { + border-color: var(--color-border); + outline-color: color-mix(in oklch, var(--color-ring) 50%, transparent); +} + +body { + background-color: var(--color-background); + color: var(--color-foreground); +} diff --git a/ui/app/layout.tsx b/ui/app/layout.tsx new file mode 100644 index 00000000..1f702ffe --- /dev/null +++ b/ui/app/layout.tsx @@ -0,0 +1,22 @@ +import type { Metadata } from "next"; +import "./globals.css"; +import { Providers } from "./providers"; + +export const metadata: Metadata = { + title: "Inference Gateway UI", + description: "Web UI for Inference Gateway CLI", +}; + +export default function RootLayout({ + children, +}: Readonly<{ + children: React.ReactNode; +}>) { + return ( + + + {children} + + + ); +} diff --git a/ui/app/page.tsx b/ui/app/page.tsx new file mode 100644 index 00000000..ea403ddf --- /dev/null +++ b/ui/app/page.tsx @@ -0,0 +1,968 @@ +"use client"; + +import { useState, useEffect, useRef, useCallback } from "react"; +import { apiClient } from "../lib/api/client"; +import type { ConversationSummary } from "../lib/storage/interfaces"; +import { WebSocketChatClient } from "../lib/chat/websocket-client"; +import { Button } from "@/components/ui/button"; +import { + Command, + CommandEmpty, + CommandGroup, + CommandInput, + CommandItem, + CommandList, +} from "@/components/ui/command"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; +import { Check, ChevronsUpDown, Send as 
SendIcon, Square } from "lucide-react"; +import { cn } from "@/lib/utils"; +import { ThemeToggle } from "@/components/theme-toggle"; +import StatusBar from "@/components/status-bar"; +import { ToolCallDisplay } from "@/components/tool-call-display"; +import type { ToolCallState } from "@/lib/agui-types"; + +// Utility to strip ANSI color codes from terminal output +function stripAnsiCodes(text: string): string { + return text.replace(/\x1B\[[0-9;]*[a-zA-Z]/g, '').replace(/\[[\d;]+m/g, ''); +} + +export default function Home() { + const [conversations, setConversations] = useState([]); + const [selectedConversation, setSelectedConversation] = useState(null); + const [loading, setLoading] = useState(true); + const [view, setView] = useState<"chat" | "dashboard">("chat"); + const [liveSessionId, setLiveSessionId] = useState(null); + const [sidebarOpen, setSidebarOpen] = useState(false); + const [hasActiveNewChat, setHasActiveNewChat] = useState(false); + const [sessionRestored, setSessionRestored] = useState(false); + const [chatKey, setChatKey] = useState(0); + + useEffect(() => { + if (loading || sessionRestored) return; + + const savedSessionId = localStorage.getItem("currentSessionId"); + const savedSessionType = localStorage.getItem("currentSessionType"); + + if (savedSessionId && savedSessionType === "new") { + setLiveSessionId("new"); + setHasActiveNewChat(true); + setSessionRestored(true); + } else if (savedSessionId && savedSessionType === "conversation") { + if (conversations.some(c => c.id === savedSessionId)) { + setSelectedConversation(savedSessionId); + setLiveSessionId(savedSessionId); + setHasActiveNewChat(false); + setSessionRestored(true); + } else if (conversations.length > 0 || !loading) { + console.warn("Saved conversation not found, clearing localStorage"); + localStorage.removeItem("currentSessionId"); + localStorage.removeItem("currentSessionType"); + setSessionRestored(true); + } + } else { + setSessionRestored(true); + } + }, [conversations, 
loading, sessionRestored]); + + useEffect(() => { + loadConversations(); + }, []); + + const loadConversations = async () => { + try { + setLoading(true); + const response = await apiClient.listConversations(50, 0); + setConversations(response.conversations); + } catch (error) { + console.error("Failed to load conversations:", error); + } finally { + setLoading(false); + } + }; + + const handleNewChat = () => { + loadConversations(); + + setLiveSessionId("new"); + setSelectedConversation(null); + setHasActiveNewChat(true); + setSidebarOpen(false); + setChatKey(prev => prev + 1); + localStorage.setItem("currentSessionId", "new"); + localStorage.setItem("currentSessionType", "new"); + + localStorage.removeItem("wsSession_new-chat"); + localStorage.removeItem("conversationId_new-chat"); + }; + + const handleContinueConversation = (conversationId: string) => { + setSelectedConversation(conversationId); + setLiveSessionId(conversationId); + setHasActiveNewChat(false); + setSidebarOpen(false); + localStorage.setItem("currentSessionId", conversationId); + localStorage.setItem("currentSessionType", "conversation"); + }; + + const handleDeleteConversation = async (conversationId: string, event: React.MouseEvent) => { + event.stopPropagation(); + + if (!confirm("Are you sure you want to delete this conversation? This cannot be undone.")) { + return; + } + + try { + await apiClient.deleteConversation(conversationId); + + setConversations(prev => prev.filter(c => c.id !== conversationId)); + + if (selectedConversation === conversationId) { + setSelectedConversation(null); + setLiveSessionId(null); + } + } catch (error) { + console.error("Failed to delete conversation:", error); + alert("Failed to delete conversation. Please try again."); + } + }; + + return ( +
+ {!sidebarOpen && ( + + )} + + {sidebarOpen && ( +
setSidebarOpen(false)} + /> + )} + +
+
+
+
+

Inference Gateway

+

Chat UI

+
+
+ + +
+
+
+
+ +
+
+ + +
+ + {/* Conversations List */} +
+ {view === "chat" && ( + <> + {loading ? ( +
+
+

Loading conversations...

+
+ ) : conversations.length === 0 && !hasActiveNewChat ? ( +
+

No conversations yet

+

Start a headless CLI session to create one

+
+ ) : ( + <> + {hasActiveNewChat && ( +
+
+
+ New Chat +
+
+ Active session +
+
+ {new Date().toLocaleDateString()} +
+
+
+ )} + {conversations.map((conv) => ( +
+ + +
+ ))} + + )} + + )} + + {view === "dashboard" && ( +
+
+
Total Conversations
+
{conversations.length}
+
+ +
+
Total Messages
+
+ {conversations.reduce((sum, c) => sum + c.message_count, 0)} +
+
+ +
+
Total Tokens
+
+ {conversations.reduce((sum, c) => sum + c.token_stats.total_input_tokens + c.token_stats.total_output_tokens, 0).toLocaleString()} +
+
+ +
+
Total Cost
+
+ ${conversations.reduce((sum, c) => sum + (c.cost_stats?.total_cost || 0), 0).toFixed(4)} +
+
+
+ )} +
+ + {/* Footer */} +
+
+ API Server + :8081 +
+
+
+ + {/* Main Chat Area */} +
+ {liveSessionId ? ( + { + setLiveSessionId(null); + setSelectedConversation(null); + setHasActiveNewChat(false); + localStorage.removeItem("currentSessionId"); + localStorage.removeItem("currentSessionType"); + loadConversations(); + }} + /> + ) : ( + + )} +
+
/**
 * A text message shown in the chat timeline.
 *
 * Built either from a stored conversation entry or from streamed
 * `TextMessageStart` / `TextMessageContent` events (and from `RunError`, as an
 * assistant-role error line).
 */
type TimelineMessageItem = {
  type: 'message';
  id: string;
  role: string;
  content: string;
};

/**
 * A tool invocation shown in the chat timeline.
 *
 * Created on `ToolCallStart` and then updated in place, matched by `id`, as
 * `ToolCallArgs`, `ToolCallProgress` and `ToolCallResult` events arrive.
 */
type TimelineToolItem = {
  type: 'tool';
  /** Tool call ID; later events are matched against it. */
  id: string;
  name: string;
  status: ToolCallState['status'];
  /** Event timestamp, or `Date.now()` when the event carries none (presumably epoch ms — confirm). */
  startTime?: number;
  /** Reported with the tool result; unit is not visible here. */
  duration?: number;
  /** Latest progress message. */
  message?: string;
  /** Accumulated progress output, replaced by the final result content. */
  output?: string;
  /** Argument text, concatenated from streamed deltas. */
  arguments?: string;
  result?: any;
};

/** One entry of the chat timeline, discriminated by `type`. */
type TimelineItem = TimelineMessageItem | TimelineToolItem;
to load history:", err); + }); + }, []); + + useEffect(() => { + isMountedRef.current = true; + + if (sendingTimeoutRef.current) { + clearTimeout(sendingTimeoutRef.current); + sendingTimeoutRef.current = null; + } + + const loadExistingConversation = async () => { + if (conversationId) { + try { + const data = await apiClient.getConversation(conversationId); + const loadedTimeline: TimelineItem[] = data.entries + .filter((entry: any) => { + return entry.message.hidden !== true; + }) + .map((entry: any, index: number) => ({ + type: 'message' as const, + id: `msg-${index}`, + role: entry.message.role, + content: typeof entry.message.content === "string" + ? stripAnsiCodes(entry.message.content) + : JSON.stringify(entry.message.content) + })); + setTimeline(loadedTimeline); + } catch (error: any) { + console.error("[LiveChatView] Failed to load conversation:", error); + setError(error.message || "Failed to load conversation"); + setIsConnecting(false); + } + } else { + setTimeline([]); + } + }; + + loadExistingConversation(); + + // Create session ID based on conversation to ensure proper reuse + // Same conversation = same session (reuse containers) + // Different conversation = different session (fresh start) + const sessionKey = conversationId || "new-chat"; + let wsSessionId = localStorage.getItem(`wsSession_${sessionKey}`); + + if (!wsSessionId) { + wsSessionId = `ws-${Date.now()}-${Math.random().toString(36).substring(2, 11)}`; + localStorage.setItem(`wsSession_${sessionKey}`, wsSessionId); + } + + const storedConversationId = localStorage.getItem(`conversationId_${sessionKey}`); + const effectiveConversationId = storedConversationId || conversationId; + + if (storedConversationId) { + actualConversationIdRef.current = storedConversationId; + } + + const client = new WebSocketChatClient(); + wsClientRef.current = client; + + const unsubscribe = client.onMessage((event) => { + if (event.type === "error") { + setError(event.data.error); + if 
(sendingTimeoutRef.current) { + clearTimeout(sendingTimeoutRef.current); + sendingTimeoutRef.current = null; + } + setIsSending(false); + return; + } + + try { + const data = typeof event.data === "string" ? JSON.parse(event.data) : event.data; + + console.log('[WS Event]', data.type); + + switch (data.type) { + case "TextMessageStart": + setTimeline((prev) => [ + ...prev, + { + type: 'message', + id: data.messageId || `msg-${Date.now()}`, + role: data.role || 'assistant', + content: '' + } + ]); + break; + + case "TextMessageContent": + setTimeline((prev) => { + const last = prev[prev.length - 1]; + if (last && last.type === 'message') { + return [ + ...prev.slice(0, -1), + { ...last, content: last.content + (data.delta || '') } + ]; + } + return prev; + }); + break; + + case "TextMessageEnd": + // Message complete (no action needed) + break; + + case "RunStarted": + if (data.input) { + setIsSending(true); + if (sendingTimeoutRef.current) { + clearTimeout(sendingTimeoutRef.current); + } + sendingTimeoutRef.current = setTimeout(() => { + console.warn("Run timeout - no RunFinished/RunError received within 5 minutes"); + setIsSending(false); + sendingTimeoutRef.current = null; + }, 5 * 60 * 1000); + } + break; + + case "RunFinished": + if (sendingTimeoutRef.current) { + clearTimeout(sendingTimeoutRef.current); + sendingTimeoutRef.current = null; + } + setIsSending(false); + break; + + case "RunError": + console.error("Run error:", data); + if (sendingTimeoutRef.current) { + clearTimeout(sendingTimeoutRef.current); + sendingTimeoutRef.current = null; + } + setIsSending(false); + if (data.message) { + setTimeline((prev) => [ + ...prev, + { type: 'message', id: `err-${Date.now()}`, role: "assistant", content: `Error: ${data.message}` } + ]); + } + break; + + case "ToolCallStart": + setTimeline((prev) => [ + ...prev, + { + type: 'tool', + id: data.toolCallId, + name: data.toolCallName, + status: data.status || 'queued', + startTime: data.timestamp || Date.now(), + 
arguments: data.arguments || '', + } + ]); + break; + + case "ToolCallArgs": + console.log('[ToolCallArgs]', data.toolCallId, 'delta:', data.delta); + setTimeline((prev) => + prev.map(item => + item.type === 'tool' && item.id === data.toolCallId + ? { + ...item, + arguments: (item.arguments || '') + data.delta, + } + : item + ) + ); + break; + + case "ToolCallEnd": + console.log('[ToolCallEnd]', data.toolCallId); + // Tool call specification complete (no UI action needed) + break; + + case "ToolCallProgress": + setTimeline((prev) => + prev.map(item => + item.type === 'tool' && item.id === data.toolCallId + ? { + ...item, + status: data.status, + message: data.message, + output: data.output ? (item.output || '') + data.output : item.output, + } + : item + ) + ); + break; + + case "ToolCallResult": + setTimeline((prev) => + prev.map(item => + item.type === 'tool' && item.id === data.toolCallId + ? { + ...item, + status: data.status || 'complete', + duration: data.duration, + output: typeof data.content === 'string' ? 
data.content : JSON.stringify(data.content, null, 2), + } + : item + ) + ); + break; + + case "ParallelToolsMetadata": + console.log("Parallel tools completed:", { + total: data.totalCount, + success: data.successCount, + failed: data.failureCount, + duration: data.totalDuration, + }); + break; + + case "session_created": + if (data.conversation_id) { + actualConversationIdRef.current = data.conversation_id; + + const sessionKey = conversationId || "new-chat"; + localStorage.setItem(`conversationId_${sessionKey}`, data.conversation_id); + } + break; + + case "ConversationCreated": + if (data.conversation_id) { + actualConversationIdRef.current = data.conversation_id; + + const sessionKey = conversationId || "new-chat"; + localStorage.setItem(`conversationId_${sessionKey}`, data.conversation_id); + } + break; + + default: + console.log("Unknown event type:", data.type || event.type); + console.log("Event data:", JSON.stringify(data, null, 2)); + break; + } + } catch (error) { + console.error("Failed to parse message:", error); + } + }); + + client.createSession(effectiveConversationId, wsSessionId) + .then((returnedSessionId) => { + setIsConnecting(false); + }) + .catch((err) => { + console.error("[LiveChatView] Failed to create session:", err); + setError(`Failed to create session: ${err.message}`); + setIsConnecting(false); + }); + + return () => { + isMountedRef.current = false; + + unsubscribe(); + + if (sendingTimeoutRef.current) { + clearTimeout(sendingTimeoutRef.current); + sendingTimeoutRef.current = null; + } + + if (cleanupTimeoutRef.current) { + clearTimeout(cleanupTimeoutRef.current); + } + + cleanupTimeoutRef.current = setTimeout(() => { + if (!isMountedRef.current && wsClientRef.current) { + wsClientRef.current.close(); + wsClientRef.current = null; + } + }, 150); + }; + }, [conversationId]); + + useEffect(() => { + messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); + }, [timeline]); + + const handleStop = useCallback(() => { + if 
(!wsClientRef.current) return; + + try { + wsClientRef.current.interrupt(); + if (sendingTimeoutRef.current) { + clearTimeout(sendingTimeoutRef.current); + sendingTimeoutRef.current = null; + } + setIsSending(false); + } catch (err: any) { + console.error("[LiveChatView] Failed to stop response:", err); + } + }, []); + + const handleSend = () => { + if (!inputMessage.trim() || !wsClientRef.current) return; + + const messageContent = inputMessage; + + try { + wsClientRef.current.sendMessage(messageContent, [], selectedModel); + + apiClient.saveToHistory(messageContent).catch((err) => { + console.error("Failed to save to history:", err); + }); + setHistory((prev) => [...prev, messageContent]); + + setInputMessage(""); + setHistoryIndex(-1); + setCurrentDraft(""); + } catch (err: any) { + setError(`Failed to send message: ${err.message}`); + } + }; + + useEffect(() => { + const handleKeyDown = (e: KeyboardEvent) => { + if (e.key === "Escape" && isSending && wsClientRef.current) { + handleStop(); + } + }; + + window.addEventListener("keydown", handleKeyDown); + return () => window.removeEventListener("keydown", handleKeyDown); + }, [isSending, handleStop]); + + useEffect(() => { + if (!isSending && !isConnecting && inputRef.current) { + inputRef.current.focus(); + } + }, [isSending, isConnecting]); + + if (isConnecting) { + return ( +
+
+
+

Starting new chat session...

+
+
+ ); + } + + if (error) { + return ( +
+
+
⚠️
+

Connection Error

+

{error}

+ +
+
+ ); + } + + return ( + <> +
+
+
+

{conversationId ? "Continue Conversation" : "New Chat Session"}

+

Live chat via WebSocket

+
+ +
+ + {loadingModels ? ( +
Loading...
+ ) : ( +
+ + + + + + + + + No model found. + + {models.map((model) => { + const searchTerms = model.toLowerCase().split(/[\/\-\s]+/); + return ( + { + setSelectedModel(model); + localStorage.setItem("selectedModel", model); + setModelSelectorOpen(false); + }} + > + {model} + + + ); + })} + + + + + +
+ )} + +
+ + +
+
+
+ {timeline.length === 0 ? ( +
+

How can I help you today?

+
+ ) : ( +
+ {timeline.map((item, index) => + item.type === 'message' ? ( +
+
+
+ {stripAnsiCodes(item.content)} +
+
+
+ ) : ( + + ) + )} +
+
+ )} +
+
+
+
+