diff --git a/.infer/config.yaml b/.infer/config.yaml index 94da6715..439a0793 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -6,7 +6,7 @@ output: format: text quiet: false tools: - enabled: false + enabled: true whitelist: commands: - ls diff --git a/cmd/chat.go b/cmd/chat.go index 7c9cb1bf..61759b8e 100644 --- a/cmd/chat.go +++ b/cmd/chat.go @@ -8,13 +8,12 @@ import ( "path/filepath" "regexp" "strings" - "sync" "time" + "github.com/charmbracelet/bubbletea" "github.com/inference-gateway/cli/config" "github.com/inference-gateway/cli/internal" sdk "github.com/inference-gateway/sdk" - "github.com/manifoldco/promptui" "github.com/spf13/cobra" ) @@ -75,38 +74,56 @@ func startChatSession() error { return fmt.Errorf("model selection failed: %w", err) } - fmt.Printf("\nšŸ¤– Starting chat session with %s\n", selectedModel) - fmt.Println("Commands: '/exit' to quit, '/clear' for history, '/compact' to export, '/help' for all") - fmt.Println("Commands are processed immediately and won't be sent to the model") - fmt.Println("šŸ“ File references: Use @filename to include file contents in your message") + var conversation []sdk.Message + + inputModel := internal.NewChatInputModel() + program := tea.NewProgram(inputModel, tea.WithAltScreen()) + + var toolsManager *internal.LLMToolsManager + if cfg.Tools.Enabled { + toolsManager = internal.NewLLMToolsManagerWithUI(cfg, program, inputModel) + } + + welcomeHistory := []string{ + fmt.Sprintf("šŸ¤– Chat session started with %s", selectedModel), + "šŸ’” Type '/help' or '?' 
for commands • Use @filename for file references", + } if cfg.Tools.Enabled { toolCount := len(createSDKTools(cfg)) if toolCount > 0 { - fmt.Printf("šŸ”§ Tools enabled: %d tool(s) available for the model to use\n", toolCount) + welcomeHistory = append(welcomeHistory, fmt.Sprintf("šŸ”§ %d tool(s) available for the model to use", toolCount)) } } - var conversation []sdk.Message + welcomeHistory = append(welcomeHistory, "") - for { - prompt := promptui.Prompt{ - Label: "You", - HideEntered: false, - Templates: &promptui.PromptTemplates{ - Prompt: "{{ . | bold }}{{ \":\" | faint }} ", - Valid: "{{ . | bold }}{{ \":\" | faint }} ", - Invalid: "{{ . | bold }}{{ \":\" | faint }} ", - }, + updateHistory := func(conversation []sdk.Message) { + if len(conversation) > 0 { + chatHistory := formatConversationForDisplay(conversation, selectedModel) + program.Send(internal.UpdateHistoryMsg{History: append(welcomeHistory, chatHistory...)}) + } else { + program.Send(internal.UpdateHistoryMsg{History: welcomeHistory}) } + } - userInput, err := prompt.Run() + go func() { + _, err := program.Run() if err != nil { - if err == promptui.ErrInterrupt { - break - } - fmt.Printf("Error reading input: %v\n", err) - continue + fmt.Printf("Error running chat interface: %v\n", err) + } + }() + + updateHistory(conversation) + + for { + updateHistory(conversation) + + userInput := waitForInput(program, inputModel) + if userInput == "" { + program.Quit() + fmt.Println("\nšŸ‘‹ Chat session ended!") + os.Exit(0) } userInput = strings.TrimSpace(userInput) @@ -120,7 +137,7 @@ func startChatSession() error { processedInput, err := processFileReferences(userInput) if err != nil { - fmt.Printf("āŒ Error processing file references: %v\n", err) + program.Send(internal.SetStatusMsg{Message: fmt.Sprintf("āŒ Error processing file references: %v", err), Spinner: false}) continue } @@ -129,28 +146,34 @@ func startChatSession() error { Content: processedInput, }) - fmt.Printf("\n%s: ", selectedModel) + 
userChatHistory := formatConversationForDisplay(conversation, selectedModel) + program.Send(internal.UpdateHistoryMsg{History: userChatHistory}) + + program.Send(internal.SetStatusMsg{Message: "Generating response...", Spinner: true}) + + inputModel.ResetCancellation() var totalMetrics *ChatMetrics maxIterations := 10 for iteration := 0; iteration < maxIterations; iteration++ { - var wg sync.WaitGroup - var spinnerActive = true - var mu sync.Mutex - - wg.Add(1) - go func() { - defer wg.Done() - showSpinner(&spinnerActive, &mu) - }() - - assistantMessage, assistantToolCalls, metrics, err := sendStreamingChatCompletion(cfg, selectedModel, conversation, &spinnerActive, &mu) + if inputModel.IsCancelled() { + conversation = conversation[:len(conversation)-1] + chatHistory := formatConversationForDisplay(conversation, selectedModel) + program.Send(internal.UpdateHistoryMsg{History: chatHistory}) + program.Send(internal.SetStatusMsg{Message: "āŒ Generation cancelled by user", Spinner: false}) + break + } - wg.Wait() + _, assistantToolCalls, metrics, err := sendStreamingChatCompletionToUI(cfg, selectedModel, conversation, program, &conversation, inputModel) if err != nil { - fmt.Printf("āŒ Error: %v\n", err) + if strings.Contains(err.Error(), "cancelled by user") { + break + } conversation = conversation[:len(conversation)-1] + chatHistory := formatConversationForDisplay(conversation, selectedModel) + program.Send(internal.UpdateHistoryMsg{History: chatHistory}) + program.Send(internal.SetStatusMsg{Message: fmt.Sprintf("āŒ Error: %v", err), Spinner: false}) break } @@ -165,14 +188,14 @@ func startChatSession() error { } } - assistantMsg := sdk.Message{ - Role: sdk.Assistant, - Content: assistantMessage, - } - if len(assistantToolCalls) > 0 { - assistantMsg.ToolCalls = &assistantToolCalls + if len(assistantToolCalls) > 0 && len(conversation) > 0 { + lastIdx := len(conversation) - 1 + if conversation[lastIdx].Role == sdk.Assistant { + conversation[lastIdx].ToolCalls = 
&assistantToolCalls + chatHistory := formatConversationForDisplay(conversation, selectedModel) + program.Send(internal.UpdateHistoryMsg{History: chatHistory}) + } } - conversation = append(conversation, assistantMsg) if len(assistantToolCalls) == 0 { break @@ -180,61 +203,80 @@ func startChatSession() error { toolExecutionFailed := false for _, toolCall := range assistantToolCalls { - toolResult, err := executeToolCall(cfg, toolCall.Function.Name, toolCall.Function.Arguments) + if inputModel.IsCancelled() { + conversation = conversation[:len(conversation)-1] + chatHistory := formatConversationForDisplay(conversation, selectedModel) + program.Send(internal.UpdateHistoryMsg{History: chatHistory}) + program.Send(internal.SetStatusMsg{Message: "āŒ Generation cancelled by user", Spinner: false}) + toolExecutionFailed = true + break + } + + toolResult, err := executeToolCall(toolsManager, toolCall.Function.Name, toolCall.Function.Arguments) if err != nil { - fmt.Printf("āŒ Tool execution failed: %v\n", err) + program.Send(internal.SetStatusMsg{Message: fmt.Sprintf("āŒ Tool execution failed: %v", err), Spinner: false}) toolExecutionFailed = true break } else { - fmt.Printf("āœ… Tool result:\n%s\n", toolResult) conversation = append(conversation, sdk.Message{ Role: sdk.Tool, Content: toolResult, ToolCallId: &toolCall.Id, }) + program.Send(internal.SetStatusMsg{Message: "āœ… Tool executed successfully", Spinner: false}) } } if toolExecutionFailed { conversation = conversation[:len(conversation)-1] - fmt.Printf("\nāŒ Tool execution was cancelled. Please try a different request.\n") + chatHistory := formatConversationForDisplay(conversation, selectedModel) + program.Send(internal.UpdateHistoryMsg{History: chatHistory}) + program.Send(internal.SetStatusMsg{Message: "āŒ Tool execution was cancelled. 
Please try a different request.", Spinner: false}) break } - fmt.Printf("\n%s: ", selectedModel) } - displayChatMetrics(totalMetrics) - fmt.Print("\n\n") + if totalMetrics != nil { + metricsMsg := formatMetricsString(totalMetrics) + program.Send(internal.SetStatusMsg{Message: fmt.Sprintf("āœ… Complete - %s", metricsMsg), Spinner: false}) + } else { + program.Send(internal.SetStatusMsg{Message: "āœ… Response complete", Spinner: false}) + } } +} - fmt.Println("\nšŸ‘‹ Chat session ended!") - return nil +// waitForInput waits for user input from the chat interface +func waitForInput(program *tea.Program, inputModel *internal.ChatInputModel) string { + for { + time.Sleep(100 * time.Millisecond) + if inputModel.HasInput() { + return inputModel.GetInput() + } + if inputModel.IsQuitRequested() { + return "" + } + } } func selectModel(models []string) (string, error) { - searcher := func(input string, index int) bool { - model := models[index] - name := strings.ReplaceAll(strings.ToLower(model), " ", "") - input = strings.ReplaceAll(strings.ToLower(input), " ", "") - return strings.Contains(name, input) - } - - prompt := promptui.Select{ - Label: "Search and select a model for the chat session (type / to search)", - Items: models, - Size: 10, - Searcher: searcher, - Templates: &promptui.SelectTemplates{ - Label: "{{ . }}?", - Active: "ā–¶ {{ . | cyan | bold }}", - Inactive: " {{ . }}", - Selected: "āœ“ Selected model: {{ . 
| green | bold }}", - }, - } - - _, result, err := prompt.Run() - return result, err + modelSelector := internal.NewModelSelectorModel(models) + program := tea.NewProgram(modelSelector) + + _, err := program.Run() + if err != nil { + return "", fmt.Errorf("model selection failed: %w", err) + } + + if modelSelector.IsCancelled() { + return "", fmt.Errorf("model selection was cancelled") + } + + if !modelSelector.IsSelected() { + return "", fmt.Errorf("no model was selected") + } + + return modelSelector.GetSelected(), nil } func getAvailableModelsList(cfg *config.Config) ([]string, error) { @@ -281,8 +323,8 @@ func handleChatCommands(input string, conversation *[]sdk.Message, selectedModel fmt.Printf("āŒ Error creating compact file: %v\n", err) } return true - case "/help": - showChatHelp() + case "/help", "?": + showHelpScreen() return true } @@ -320,35 +362,14 @@ func showConversationHistory(conversation []sdk.Message) { fmt.Println() } -func showChatHelp() { - fmt.Println("šŸ’¬ Chat Session Commands:") - fmt.Println() - fmt.Println("Chat Commands:") - fmt.Println(" /exit, /quit - Exit the chat session") - fmt.Println(" /clear - Clear conversation history") - fmt.Println(" /history - Show conversation history") - fmt.Println(" /models - Show current and available models") - fmt.Println(" /switch - Switch to a different model") - fmt.Println(" /compact - Export conversation to markdown file") - fmt.Println(" /help - Show this help") - fmt.Println() - fmt.Println("File References:") - fmt.Println(" @filename.txt - Include contents of filename.txt in your message") - fmt.Println(" @./config.yaml - Include contents of config.yaml from current directory") - fmt.Println(" @../README.md - Include contents of README.md from parent directory") - fmt.Println(" Maximum file size: 100KB") - fmt.Println() - fmt.Println("Tool Usage:") - fmt.Println(" Models can invoke available tools automatically during conversation") - fmt.Println(" Use 'infer tools list' to see whitelisted 
commands") - fmt.Println(" Use 'infer tools enable/disable' to control tool access") - fmt.Println(" Tools execute securely with command whitelisting") - fmt.Println() - fmt.Println("Input Tips:") - fmt.Println(" End line with '\\' for multi-line input") - fmt.Println(" Press Ctrl+C to interrupt") - fmt.Println(" Press Ctrl+D to exit") - fmt.Println() +func showHelpScreen() { + helpViewer := internal.NewHelpViewerModel() + program := tea.NewProgram(helpViewer, tea.WithAltScreen()) + + _, err := program.Run() + if err != nil { + fmt.Printf("Error displaying help: %v\n", err) + } } // ChatMetrics holds timing and token usage information @@ -357,8 +378,10 @@ type ChatMetrics struct { Usage *sdk.CompletionUsage } -func executeToolCall(cfg *config.Config, toolName, arguments string) (string, error) { - manager := internal.NewLLMToolsManager(cfg) +func executeToolCall(manager *internal.LLMToolsManager, toolName, arguments string) (string, error) { + if manager == nil { + return "", fmt.Errorf("tools are not enabled") + } var params map[string]interface{} if arguments != "" { @@ -396,7 +419,121 @@ func createSDKTools(cfg *config.Config) []sdk.ChatCompletionTool { return sdkTools } -func sendStreamingChatCompletion(cfg *config.Config, model string, messages []ChatMessage, spinnerActive *bool, mu *sync.Mutex) (string, []sdk.ChatCompletionMessageToolCall, *ChatMetrics, error) { +type uiStreamingResult struct { + fullMessage *strings.Builder + firstContent bool + usage *sdk.CompletionUsage + toolCalls []sdk.ChatCompletionMessageToolCall + activeToolCalls map[int]*sdk.ChatCompletionMessageToolCall + program *tea.Program + conversation *[]sdk.Message + cfg *config.Config + inputModel *internal.ChatInputModel + selectedModel string +} + +func createStreamingClient(cfg *config.Config) (sdk.Client, error) { + baseURL := strings.TrimSuffix(cfg.Gateway.URL, "/") + if !strings.HasSuffix(baseURL, "/v1") { + baseURL += "/v1" + } + + client := sdk.NewClient(&sdk.ClientOptions{ + 
BaseURL: baseURL, + APIKey: cfg.Gateway.APIKey, + }) + + tools := createSDKTools(cfg) + if len(tools) > 0 { + client = client.WithTools(&tools) + } + + return client, nil +} + +func processFileReferences(input string) (string, error) { + fileRefRegex := regexp.MustCompile(`@([\w\./\-_]+(?:\.[\w]+)?)`) + matches := fileRefRegex.FindAllStringSubmatch(input, -1) + + if len(matches) == 0 { + return input, nil + } + + processedInput := input + + for _, match := range matches { + fullMatch := match[0] + filePath := match[1] + + content, err := readFileForChat(filePath) + if err != nil { + return "", fmt.Errorf("failed to read file '%s': %w", filePath, err) + } + + replacement := fmt.Sprintf("\n\n--- File: %s ---\n%s\n--- End of %s ---\n", filePath, content, filePath) + processedInput = strings.Replace(processedInput, fullMatch, replacement, 1) + } + + return processedInput, nil +} + +func readFileForChat(filePath string) (string, error) { + absPath, err := filepath.Abs(filePath) + if err != nil { + return "", fmt.Errorf("failed to resolve file path: %w", err) + } + + info, err := os.Stat(absPath) + if os.IsNotExist(err) { + return "", fmt.Errorf("file does not exist") + } + if err != nil { + return "", fmt.Errorf("failed to access file: %w", err) + } + + if info.IsDir() { + return "", fmt.Errorf("path is a directory, not a file") + } + + const maxFileSize = 100 * 1024 + if info.Size() > maxFileSize { + return "", fmt.Errorf("file too large (max %d bytes)", maxFileSize) + } + + content, err := os.ReadFile(absPath) + if err != nil { + return "", fmt.Errorf("failed to read file: %w", err) + } + + return string(content), nil +} + +func formatMetricsString(metrics *ChatMetrics) string { + if metrics == nil { + return "" + } + + var parts []string + + duration := metrics.Duration.Round(time.Millisecond) + parts = append(parts, fmt.Sprintf("Time: %v", duration)) + + if metrics.Usage != nil { + if metrics.Usage.PromptTokens > 0 { + parts = append(parts, fmt.Sprintf("Input: %d 
tokens", metrics.Usage.PromptTokens)) + } + if metrics.Usage.CompletionTokens > 0 { + parts = append(parts, fmt.Sprintf("Output: %d tokens", metrics.Usage.CompletionTokens)) + } + if metrics.Usage.TotalTokens > 0 { + parts = append(parts, fmt.Sprintf("Total: %d tokens", metrics.Usage.TotalTokens)) + } + } + + return strings.Join(parts, " | ") +} + +func sendStreamingChatCompletionToUI(cfg *config.Config, model string, messages []ChatMessage, program *tea.Program, conversation *[]sdk.Message, inputModel *internal.ChatInputModel) (string, []sdk.ChatCompletionMessageToolCall, *ChatMetrics, error) { client, err := createStreamingClient(cfg) if err != nil { return "", nil, nil, err @@ -415,16 +552,18 @@ func sendStreamingChatCompletion(cfg *config.Config, model string, messages []Ch return "", nil, nil, fmt.Errorf("failed to generate content stream: %w", err) } - result := &streamingResult{ + result := &uiStreamingResult{ fullMessage: &strings.Builder{}, firstContent: true, activeToolCalls: make(map[int]*sdk.ChatCompletionMessageToolCall), - spinnerActive: spinnerActive, - mu: mu, + program: program, + conversation: conversation, cfg: cfg, + inputModel: inputModel, + selectedModel: model, } - err = processStreamingEvents(events, result) + err = processStreamingEventsToUI(events, result) if err != nil { return "", nil, nil, err } @@ -440,58 +579,33 @@ func sendStreamingChatCompletion(cfg *config.Config, model string, messages []Ch return result.fullMessage.String(), finalToolCalls, metrics, nil } -type streamingResult struct { - fullMessage *strings.Builder - firstContent bool - usage *sdk.CompletionUsage - toolCalls []sdk.ChatCompletionMessageToolCall - activeToolCalls map[int]*sdk.ChatCompletionMessageToolCall - spinnerActive *bool - mu *sync.Mutex - cfg *config.Config -} - -func createStreamingClient(cfg *config.Config) (sdk.Client, error) { - baseURL := strings.TrimSuffix(cfg.Gateway.URL, "/") - if !strings.HasSuffix(baseURL, "/v1") { - baseURL += "/v1" - } - - 
client := sdk.NewClient(&sdk.ClientOptions{ - BaseURL: baseURL, - APIKey: cfg.Gateway.APIKey, - }) - - tools := createSDKTools(cfg) - if len(tools) > 0 { - client = client.WithTools(&tools) - } - - return client, nil -} - -func processStreamingEvents(events <-chan sdk.SSEvent, result *streamingResult) error { +func processStreamingEventsToUI(events <-chan sdk.SSEvent, result *uiStreamingResult) error { for event := range events { + // Check for cancellation + if result.inputModel.IsCancelled() { + return fmt.Errorf("generation cancelled by user") + } + if event.Event == nil { continue } switch *event.Event { case sdk.ContentDelta: - if err := handleContentDelta(event, result); err != nil { + if err := handleContentDeltaToUI(event, result); err != nil { return err } case sdk.StreamEnd: - handleStreamEnd(result) + handleStreamEndToUI(result) return nil case "error": - return handleStreamError(event, result) + return handleStreamErrorToUI(event, result) } } return nil } -func handleContentDelta(event sdk.SSEvent, result *streamingResult) error { +func handleContentDeltaToUI(event sdk.SSEvent, result *uiStreamingResult) error { if event.Data == nil { return nil } @@ -502,50 +616,80 @@ func handleContentDelta(event sdk.SSEvent, result *streamingResult) error { } for _, choice := range streamResponse.Choices { - handleContentChoice(choice, result) - handleToolCallsChoice(choice, result) + handleContentChoiceToUI(choice, result) + handleToolCallsChoiceToUI(choice, result) } if streamResponse.Usage != nil { result.usage = streamResponse.Usage + // Update status with token count + result.program.Send(internal.SetStatusMsg{ + Message: fmt.Sprintf("Generating response... 
(%d tokens)", streamResponse.Usage.TotalTokens), + Spinner: true, + }) } return nil } -func handleContentChoice(choice sdk.ChatCompletionStreamChoice, result *streamingResult) { +func handleContentChoiceToUI(choice sdk.ChatCompletionStreamChoice, result *uiStreamingResult) { if choice.Delta.Content == "" { return } if result.firstContent { - stopSpinner(result) + // Add assistant message to conversation and update UI + assistantMsg := sdk.Message{ + Role: sdk.Assistant, + Content: choice.Delta.Content, + } + *result.conversation = append(*result.conversation, assistantMsg) + + // Update UI with new history using model name + chatHistory := formatConversationForDisplay(*result.conversation, result.selectedModel) + result.program.Send(internal.UpdateHistoryMsg{History: chatHistory}) + result.firstContent = false + } else { + // Update the last message in conversation + if len(*result.conversation) > 0 { + lastIdx := len(*result.conversation) - 1 + (*result.conversation)[lastIdx].Content += choice.Delta.Content + + // Update UI with updated history using model name + chatHistory := formatConversationForDisplay(*result.conversation, result.selectedModel) + result.program.Send(internal.UpdateHistoryMsg{History: chatHistory}) + } } - fmt.Print(choice.Delta.Content) result.fullMessage.WriteString(choice.Delta.Content) } -func handleToolCallsChoice(choice sdk.ChatCompletionStreamChoice, result *streamingResult) { +func handleToolCallsChoiceToUI(choice sdk.ChatCompletionStreamChoice, result *uiStreamingResult) { if len(choice.Delta.ToolCalls) == 0 { return } - if result.firstContent { - stopSpinner(result) - result.firstContent = false - } - for _, deltaToolCall := range choice.Delta.ToolCalls { - handleToolCallDelta(deltaToolCall, result) + handleToolCallDeltaToUI(deltaToolCall, result) } } -func handleToolCallDelta(deltaToolCall sdk.ChatCompletionMessageToolCallChunk, result *streamingResult) { +func handleToolCallDeltaToUI(deltaToolCall 
sdk.ChatCompletionMessageToolCallChunk, result *uiStreamingResult) { index := deltaToolCall.Index if result.activeToolCalls[index] == nil { + // Ensure we have an assistant message in the conversation when tool calls start + if result.firstContent { + // Add empty assistant message to conversation since tool calls are starting + assistantMsg := sdk.Message{ + Role: sdk.Assistant, + Content: "", + } + *result.conversation = append(*result.conversation, assistantMsg) + result.firstContent = false + } + result.activeToolCalls[index] = &sdk.ChatCompletionMessageToolCall{ Id: deltaToolCall.ID, Type: sdk.ChatCompletionToolType(deltaToolCall.Type), @@ -555,7 +699,10 @@ func handleToolCallDelta(deltaToolCall sdk.ChatCompletionMessageToolCallChunk, r }, } if deltaToolCall.Function.Name != "" { - fmt.Printf("\nšŸ”§ Calling tool: %s", deltaToolCall.Function.Name) + result.program.Send(internal.SetStatusMsg{ + Message: fmt.Sprintf("šŸ”§ Calling tool: %s", deltaToolCall.Function.Name), + Spinner: true, + }) } } @@ -564,21 +711,21 @@ func handleToolCallDelta(deltaToolCall sdk.ChatCompletionMessageToolCallChunk, r } } -func handleStreamEnd(result *streamingResult) { - stopSpinner(result) +func handleStreamEndToUI(result *uiStreamingResult) { var toolCalls []sdk.ChatCompletionMessageToolCall for _, toolCall := range result.activeToolCalls { if toolCall.Function.Name != "" { - fmt.Printf(" with arguments: %s\n", toolCall.Function.Arguments) + result.program.Send(internal.SetStatusMsg{ + Message: fmt.Sprintf("šŸ”§ Tool: %s with arguments: %s", toolCall.Function.Name, toolCall.Function.Arguments), + Spinner: false, + }) toolCalls = append(toolCalls, *toolCall) } } result.toolCalls = toolCalls } -func handleStreamError(event sdk.SSEvent, result *streamingResult) error { - stopSpinner(result) - +func handleStreamErrorToUI(event sdk.SSEvent, result *uiStreamingResult) error { if event.Data == nil { return fmt.Errorf("stream error: unknown error") } @@ -592,116 +739,6 @@ func 
handleStreamError(event sdk.SSEvent, result *streamingResult) error { return fmt.Errorf("stream error: %s", errResp.Error) } -func stopSpinner(result *streamingResult) { - result.mu.Lock() - *result.spinnerActive = false - result.mu.Unlock() -} - - - - -func showSpinner(active *bool, mu *sync.Mutex) { - spinner := []string{"ā ‹", "ā ™", "ā ¹", "ā ø", "ā ¼", "ā “", "ā ¦", "ā §", "ā ‡", "ā "} - i := 0 - - for { - mu.Lock() - if !*active { - mu.Unlock() - break - } - mu.Unlock() - - fmt.Printf("%s", spinner[i%len(spinner)]) - fmt.Printf("\b") - i++ - time.Sleep(100 * time.Millisecond) - } - - fmt.Print(" \b") -} - -func processFileReferences(input string) (string, error) { - fileRefRegex := regexp.MustCompile(`@([\w\./\-_]+(?:\.[\w]+)?)`) - matches := fileRefRegex.FindAllStringSubmatch(input, -1) - - if len(matches) == 0 { - return input, nil - } - - processedInput := input - - for _, match := range matches { - fullMatch := match[0] - filePath := match[1] - - content, err := readFileForChat(filePath) - if err != nil { - return "", fmt.Errorf("failed to read file '%s': %w", filePath, err) - } - - replacement := fmt.Sprintf("\n\n--- File: %s ---\n%s\n--- End of %s ---\n", filePath, content, filePath) - processedInput = strings.Replace(processedInput, fullMatch, replacement, 1) - } - - return processedInput, nil -} - -func readFileForChat(filePath string) (string, error) { - absPath, err := filepath.Abs(filePath) - if err != nil { - return "", fmt.Errorf("failed to resolve file path: %w", err) - } - - info, err := os.Stat(absPath) - if os.IsNotExist(err) { - return "", fmt.Errorf("file does not exist") - } - if err != nil { - return "", fmt.Errorf("failed to access file: %w", err) - } - - if info.IsDir() { - return "", fmt.Errorf("path is a directory, not a file") - } - - const maxFileSize = 100 * 1024 - if info.Size() > maxFileSize { - return "", fmt.Errorf("file too large (max %d bytes)", maxFileSize) - } - - content, err := os.ReadFile(absPath) - if err != nil { - 
return "", fmt.Errorf("failed to read file: %w", err) - } - - return string(content), nil -} - -func displayChatMetrics(metrics *ChatMetrics) { - if metrics == nil { - return - } - - fmt.Printf("\nšŸ“Š ") - - duration := metrics.Duration.Round(time.Millisecond) - fmt.Printf("Time: %v", duration) - - if metrics.Usage != nil { - if metrics.Usage.PromptTokens > 0 { - fmt.Printf(" | Input: %d tokens", metrics.Usage.PromptTokens) - } - if metrics.Usage.CompletionTokens > 0 { - fmt.Printf(" | Output: %d tokens", metrics.Usage.CompletionTokens) - } - if metrics.Usage.TotalTokens > 0 { - fmt.Printf(" | Total: %d tokens", metrics.Usage.TotalTokens) - } - } -} - func compactConversation(conversation []sdk.Message, selectedModel string) error { cfg, err := config.LoadConfig("") if err != nil { @@ -791,6 +828,45 @@ func compactConversation(conversation []sdk.Message, selectedModel string) error return nil } +func formatConversationForDisplay(conversation []sdk.Message, selectedModel string) []string { + var history []string + + for _, msg := range conversation { + var role string + var content string + + switch msg.Role { + case sdk.User: + role = "šŸ‘¤ You" + case sdk.Assistant: + role = fmt.Sprintf("šŸ¤– %s", selectedModel) + case sdk.System: + role = "āš™ļø System" + case sdk.Tool: + role = "šŸ”§ Tool" + default: + role = string(msg.Role) + } + + content = msg.Content + + content = strings.ReplaceAll(content, "\r\n", "\n") + content = strings.ReplaceAll(content, "\r", "\n") + + if content != "" { + history = append(history, fmt.Sprintf("%s: %s", role, content)) + } + + if msg.ToolCalls != nil && len(*msg.ToolCalls) > 0 { + for _, toolCall := range *msg.ToolCalls { + history = append(history, fmt.Sprintf("šŸ”§ Tool Call: %s", toolCall.Function.Name)) + } + } + } + + return history +} + func init() { rootCmd.AddCommand(chatCmd) } diff --git a/cmd/chat_test.go b/cmd/chat_test.go new file mode 100644 index 00000000..963c902d --- /dev/null +++ b/cmd/chat_test.go @@ -0,0 +1,610 
@@ +package cmd + +import ( + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + sdk "github.com/inference-gateway/sdk" +) + +func TestHandleChatCommands(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + initialConv []sdk.Message + selectedModel string + availableModels []string + expectedHandled bool + expectedConvLen int + }{ + { + name: "clear command should clear conversation", + input: "/clear", + initialConv: []sdk.Message{{Role: sdk.User, Content: "test"}}, + selectedModel: "gpt-3.5-turbo", + availableModels: []string{"gpt-3.5-turbo"}, + expectedHandled: true, + expectedConvLen: 0, + }, + { + name: "history command should be handled", + input: "/history", + initialConv: []sdk.Message{{Role: sdk.User, Content: "test"}}, + selectedModel: "gpt-3.5-turbo", + availableModels: []string{"gpt-3.5-turbo"}, + expectedHandled: true, + expectedConvLen: 1, + }, + { + name: "models command should be handled", + input: "/models", + initialConv: []sdk.Message{}, + selectedModel: "gpt-3.5-turbo", + availableModels: []string{"gpt-3.5-turbo", "gpt-4"}, + expectedHandled: true, + expectedConvLen: 0, + }, + { + name: "help command should be handled", + input: "/help", + initialConv: []sdk.Message{}, + selectedModel: "gpt-3.5-turbo", + availableModels: []string{"gpt-3.5-turbo"}, + expectedHandled: true, + expectedConvLen: 0, + }, + { + name: "question mark help should be handled", + input: "?", + initialConv: []sdk.Message{}, + selectedModel: "gpt-3.5-turbo", + availableModels: []string{"gpt-3.5-turbo"}, + expectedHandled: true, + expectedConvLen: 0, + }, + { + name: "compact command should be handled", + input: "/compact", + initialConv: []sdk.Message{}, + selectedModel: "gpt-3.5-turbo", + availableModels: []string{"gpt-3.5-turbo"}, + expectedHandled: true, + expectedConvLen: 0, + }, + { + name: "unknown command should be handled with error", + input: "/unknown", + initialConv: []sdk.Message{}, + selectedModel: "gpt-3.5-turbo", + 
availableModels: []string{"gpt-3.5-turbo"}, + expectedHandled: true, + expectedConvLen: 0, + }, + { + name: "regular text should not be handled as command", + input: "Hello world", + initialConv: []sdk.Message{}, + selectedModel: "gpt-3.5-turbo", + availableModels: []string{"gpt-3.5-turbo"}, + expectedHandled: false, + expectedConvLen: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conversation := make([]sdk.Message, len(tt.initialConv)) + copy(conversation, tt.initialConv) + selectedModel := tt.selectedModel + + handled := handleChatCommands(tt.input, &conversation, &selectedModel, tt.availableModels) + + if handled != tt.expectedHandled { + t.Errorf("handleChatCommands() handled = %v, want %v", handled, tt.expectedHandled) + } + + if len(conversation) != tt.expectedConvLen { + t.Errorf("conversation length = %v, want %v", len(conversation), tt.expectedConvLen) + } + }) + } +} + +func TestProcessFileReferences(t *testing.T) { + t.Parallel() + + tempDir, err := os.MkdirTemp("", "chat_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer func() { _ = os.RemoveAll(tempDir) }() + + testFile := filepath.Join(tempDir, "test.txt") + testContent := "This is test content" + err = os.WriteFile(testFile, []byte(testContent), 0644) + if err != nil { + t.Fatalf("Failed to create test file: %v", err) + } + + tests := []struct { + name string + input string + wantErr bool + expectsFile bool + description string + }{ + { + name: "no file references", + input: "Hello world", + wantErr: false, + expectsFile: false, + description: "Input without file references should pass through unchanged", + }, + { + name: "single file reference - existing file", + input: "Look at @" + testFile + " please", + wantErr: false, + expectsFile: true, + description: "Input with @filename should read existing file", + }, + { + name: "single file reference - nonexistent file", + input: "Look at @nonexistent.yaml please", + wantErr: true, 
+ expectsFile: true, + description: "Input with @filename should error for missing file", + }, + { + name: "file reference with relative path", + input: "Check @./nonexistent.go", + wantErr: true, + expectsFile: true, + description: "File references with paths should work", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := processFileReferences(tt.input) + + if tt.wantErr && err == nil { + t.Errorf("processFileReferences() expected error but got none") + } + + if !tt.wantErr && err != nil { + t.Errorf("processFileReferences() unexpected error: %v", err) + } + + if !tt.expectsFile && result != tt.input { + t.Errorf("processFileReferences() = %v, want %v", result, tt.input) + } + + if !tt.wantErr && tt.expectsFile { + if !containsText(result, testContent) { + t.Errorf("processFileReferences() result should contain file content") + } + } + }) + } +} + +func TestReadFileForChat(t *testing.T) { + t.Parallel() + + tempDir, err := os.MkdirTemp("", "chat_file_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer func() { _ = os.RemoveAll(tempDir) }() + + validFile := filepath.Join(tempDir, "valid.txt") + testContent := "Hello, world!\nThis is a test file." 
+ err = os.WriteFile(validFile, []byte(testContent), 0644) + if err != nil { + t.Fatalf("Failed to create valid test file: %v", err) + } + + largeFile := filepath.Join(tempDir, "large.txt") + largeContent := make([]byte, 101*1024) // 101KB + for i := range largeContent { + largeContent[i] = 'A' + } + err = os.WriteFile(largeFile, largeContent, 0644) + if err != nil { + t.Fatalf("Failed to create large test file: %v", err) + } + + tests := []struct { + name string + filePath string + wantErr bool + expectEmpty bool + }{ + { + name: "valid file", + filePath: validFile, + wantErr: false, + }, + { + name: "nonexistent file", + filePath: filepath.Join(tempDir, "nonexistent.txt"), + wantErr: true, + }, + { + name: "directory instead of file", + filePath: tempDir, + wantErr: true, + }, + { + name: "file too large", + filePath: largeFile, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + content, err := readFileForChat(tt.filePath) + + if tt.wantErr && err == nil { + t.Errorf("readFileForChat() expected error but got none") + } + + if !tt.wantErr && err != nil { + t.Errorf("readFileForChat() unexpected error: %v", err) + } + + if !tt.wantErr && !tt.expectEmpty && content == "" { + t.Errorf("readFileForChat() expected content but got empty string") + } + + if !tt.wantErr && tt.filePath == validFile && content != testContent { + t.Errorf("readFileForChat() = %q, want %q", content, testContent) + } + }) + } +} + +func TestFormatConversationForDisplay(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + conversation []sdk.Message + selectedModel string + expectedLen int + expectRoles []string + }{ + { + name: "empty conversation", + conversation: []sdk.Message{}, + selectedModel: "gpt-3.5-turbo", + expectedLen: 0, + expectRoles: []string{}, + }, + { + name: "single user message", + conversation: []sdk.Message{ + {Role: sdk.User, Content: "Hello"}, + }, + selectedModel: "gpt-3.5-turbo", + expectedLen: 1, + 
expectRoles: []string{"šŸ‘¤ You"}, + }, + { + name: "user and assistant messages", + conversation: []sdk.Message{ + {Role: sdk.User, Content: "Hello"}, + {Role: sdk.Assistant, Content: "Hi there!"}, + }, + selectedModel: "gpt-3.5-turbo", + expectedLen: 2, + expectRoles: []string{"šŸ‘¤ You", "šŸ¤– gpt-3.5-turbo"}, + }, + { + name: "system message", + conversation: []sdk.Message{ + {Role: sdk.System, Content: "You are a helpful assistant"}, + }, + selectedModel: "gpt-4", + expectedLen: 1, + expectRoles: []string{"āš™ļø System"}, + }, + { + name: "tool message", + conversation: []sdk.Message{ + {Role: sdk.Tool, Content: "Tool result"}, + }, + selectedModel: "gpt-4", + expectedLen: 1, + expectRoles: []string{"šŸ”§ Tool"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatConversationForDisplay(tt.conversation, tt.selectedModel) + + if len(result) != tt.expectedLen { + t.Errorf("formatConversationForDisplay() length = %v, want %v", len(result), tt.expectedLen) + } + + for i, expectedRole := range tt.expectRoles { + if i < len(result) { + if !containsRole(result[i], expectedRole) { + t.Errorf("formatConversationForDisplay()[%d] = %v, expected to contain %v", i, result[i], expectedRole) + } + } + } + }) + } +} + +func TestFormatMetricsString(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metrics *ChatMetrics + expected string + contains []string + }{ + { + name: "nil metrics", + metrics: nil, + expected: "", + }, + { + name: "metrics with duration only", + metrics: &ChatMetrics{ + Duration: 1000000000, // 1 second in nanoseconds + }, + contains: []string{"Time: 1s"}, + }, + { + name: "metrics with usage", + metrics: &ChatMetrics{ + Duration: 500000000, // 500ms + Usage: &sdk.CompletionUsage{ + PromptTokens: 100, + CompletionTokens: 50, + TotalTokens: 150, + }, + }, + contains: []string{"Time:", "Input: 100 tokens", "Output: 50 tokens", "Total: 150 tokens"}, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + result := formatMetricsString(tt.metrics) + + if tt.expected != "" && result != tt.expected { + t.Errorf("formatMetricsString() = %v, want %v", result, tt.expected) + } + + for _, contains := range tt.contains { + if result == "" || !containsText(result, contains) { + t.Errorf("formatMetricsString() = %v, expected to contain %v", result, contains) + } + } + }) + } +} + +func TestCompactConversation(t *testing.T) { + tempDir, err := os.MkdirTemp("", "compact_test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer func() { _ = os.RemoveAll(tempDir) }() + + originalDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get current directory: %v", err) + } + + tests := []struct { + name string + conversation []sdk.Message + model string + expectErr bool + expectFile bool + }{ + { + name: "empty conversation", + conversation: []sdk.Message{}, + model: "gpt-3.5-turbo", + expectErr: false, + expectFile: false, + }, + { + name: "single message conversation", + conversation: []sdk.Message{ + {Role: sdk.User, Content: "Hello world"}, + }, + model: "gpt-4", + expectErr: false, + expectFile: true, + }, + { + name: "multi-message conversation with tool calls", + conversation: []sdk.Message{ + {Role: sdk.User, Content: "What's the weather?"}, + { + Role: sdk.Assistant, + Content: "I'll check the weather for you.", + ToolCalls: &[]sdk.ChatCompletionMessageToolCall{ + { + Id: "call_123", + Type: "function", + Function: sdk.ChatCompletionMessageToolCallFunction{ + Name: "get_weather", + Arguments: `{"location": "New York"}`, + }, + }, + }, + }, + {Role: sdk.Tool, Content: "Sunny, 75°F", ToolCallId: &[]string{"call_123"}[0]}, + {Role: sdk.Assistant, Content: "It's sunny and 75°F in New York!"}, + }, + model: "gpt-4", + expectErr: false, + expectFile: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + testTempDir, err := os.MkdirTemp(tempDir, "test_*") + if err != nil { + 
t.Fatalf("Failed to create test temp dir: %v", err) + } + + originalDir, _ := os.Getwd() + configFile := filepath.Join(testTempDir, ".infer", "config.yaml") + err = os.MkdirAll(filepath.Dir(configFile), 0755) + if err != nil { + t.Fatalf("Failed to create config dir: %v", err) + } + + configData := `compact: + output_dir: "` + testTempDir + `"` + + err = os.WriteFile(configFile, []byte(configData), 0644) + if err != nil { + t.Fatalf("Failed to create config file: %v", err) + } + + err = os.Chdir(testTempDir) + if err != nil { + t.Fatalf("Failed to change directory: %v", err) + } + defer func() { + if chdirErr := os.Chdir(originalDir); chdirErr != nil { + t.Errorf("Failed to restore original directory: %v", chdirErr) + } + }() + + err = compactConversation(tt.conversation, tt.model) + + if tt.expectErr && err == nil { + t.Errorf("compactConversation() expected error but got none") + } + + if !tt.expectErr && err != nil { + t.Errorf("compactConversation() unexpected error: %v", err) + } + + if tt.expectFile { + files, err := filepath.Glob(filepath.Join(testTempDir, "chat-export-*.md")) + if err != nil { + t.Errorf("Failed to list exported files: %v", err) + } + if len(files) == 0 { + t.Errorf("Expected export file to be created but none found") + } else { + content, err := os.ReadFile(files[0]) + if err != nil { + t.Errorf("Failed to read export file: %v", err) + } else { + contentStr := string(content) + if !strings.Contains(contentStr, "# Chat Session Export") { + t.Errorf("Export file should contain header") + } + if !strings.Contains(contentStr, tt.model) { + t.Errorf("Export file should contain model name") + } + if len(tt.conversation) > 0 && !strings.Contains(contentStr, tt.conversation[0].Content) { + t.Errorf("Export file should contain conversation content") + } + } + } + } + }) + } + + if err := os.Chdir(originalDir); err != nil { + t.Errorf("Failed to restore original directory after all tests: %v", err) + } +} + +// Helper functions +func 
containsRole(text, role string) bool { + return len(text) > 0 && text[:len(role)] == role +} + +func containsText(text, substring string) bool { + return len(text) >= len(substring) && + text != "" && substring != "" && + contains(text, substring) +} + +func contains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func TestChatRequest(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + request ChatRequest + expected ChatRequest + }{ + { + name: "basic chat request", + request: ChatRequest{ + Model: "gpt-3.5-turbo", + Messages: []ChatMessage{ + {Role: sdk.User, Content: "Hello"}, + }, + Stream: false, + }, + expected: ChatRequest{ + Model: "gpt-3.5-turbo", + Messages: []ChatMessage{ + {Role: sdk.User, Content: "Hello"}, + }, + Stream: false, + }, + }, + { + name: "streaming chat request", + request: ChatRequest{ + Model: "gpt-4", + Messages: []ChatMessage{ + {Role: sdk.System, Content: "You are helpful"}, + {Role: sdk.User, Content: "Hi"}, + }, + Stream: true, + }, + expected: ChatRequest{ + Model: "gpt-4", + Messages: []ChatMessage{ + {Role: sdk.System, Content: "You are helpful"}, + {Role: sdk.User, Content: "Hi"}, + }, + Stream: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !reflect.DeepEqual(tt.request, tt.expected) { + t.Errorf("ChatRequest = %v, want %v", tt.request, tt.expected) + } + }) + } +} diff --git a/go.mod b/go.mod index b00eb77e..13217905 100644 --- a/go.mod +++ b/go.mod @@ -3,20 +3,38 @@ module github.com/inference-gateway/cli go 1.24.5 require ( + github.com/charmbracelet/bubbletea v1.3.6 github.com/chzyer/readline v1.5.1 github.com/inference-gateway/sdk v1.11.0 - github.com/manifoldco/promptui v0.9.0 github.com/spf13/cobra v1.9.1 gopkg.in/yaml.v3 v3.0.1 ) require ( + github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect + github.com/charmbracelet/colorprofile 
v0.2.3-0.20250311203215-f60798e515dc // indirect + github.com/charmbracelet/lipgloss v1.1.0 // indirect + github.com/charmbracelet/x/ansi v0.9.3 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect + github.com/charmbracelet/x/term v0.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect github.com/go-resty/resty/v2 v2.16.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/lucasb-eyer/go-colorful v1.2.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-localereader v0.0.1 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect + github.com/muesli/cancelreader v0.2.2 // indirect + github.com/muesli/termenv v0.16.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/spf13/pflag v1.0.6 // indirect github.com/stretchr/testify v1.10.0 // indirect + github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/net v0.33.0 // indirect - golang.org/x/sys v0.28.0 // indirect + golang.org/x/sync v0.15.0 // indirect + golang.org/x/sys v0.33.0 // indirect + golang.org/x/text v0.21.0 // indirect ) diff --git a/go.sum b/go.sum index abafaeb3..54bb2983 100644 --- a/go.sum +++ b/go.sum @@ -1,25 +1,53 @@ -github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= +github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= +github.com/charmbracelet/bubbletea v1.3.6 h1:VkHIxPJQeDt0aFJIsVxw8BQdh/F/L2KKZGsK6et5taU= +github.com/charmbracelet/bubbletea v1.3.6/go.mod h1:oQD9VCRQFF8KplacJLo28/jofOI2ToOfGYeFgBBxHOc= +github.com/charmbracelet/colorprofile 
v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs= +github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk= +github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY= +github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30= +github.com/charmbracelet/x/ansi v0.9.3 h1:BXt5DHS/MKF+LjuK4huWrC6NCvHtexww7dMayh6GXd0= +github.com/charmbracelet/x/ansi v0.9.3/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8= +github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs= +github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ= +github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg= github.com/chzyer/logex v1.2.1 h1:XHDu3E6q+gdHgsdTPH6ImJMIp436vR6MPtH8gP05QzM= github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= -github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/readline v1.5.1 h1:upd/6fQk4src78LMRzh5vItIt361/o4uq553V8B5sGI= github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= -github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/chzyer/test v1.0.0 h1:p3BQDXSxOhOG0P9z6/hGnII4LGiEPOYBhs8asl/fC04= github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod 
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= +github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= github.com/go-resty/resty/v2 v2.16.3 h1:zacNT7lt4b8M/io2Ahj6yPypL7bqx9n1iprfQuodV+E= github.com/go-resty/resty/v2 v2.16.3/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ07xAwp/fiA= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/inference-gateway/sdk v1.11.0 h1:eeq/VE8/2m+kFajwXGOFnDNvskkyfAwFZDxOLiIEv2A= github.com/inference-gateway/sdk v1.11.0/go.mod h1:3TTD7Kbr7FRt+9ZbCPAm3u0tXUIWx7flZuwrRgZgrdk= -github.com/manifoldco/promptui v0.9.0 h1:3V4HzJk1TtXW1MTZMP7mdlwbBpIinw3HztaIlYthEiA= -github.com/manifoldco/promptui v0.9.0/go.mod h1:ka04sppxSGFAtxX0qhlYQjISsg9mR4GWtQEhdbn6Pgg= +github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= +github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4= +github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI= +github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo= 
+github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= +github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= +github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc= +github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= @@ -27,12 +55,21 @@ github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= +github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= +golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561 h1:MDc5xs78ZrZr3HMQugiXOAkSZtfTpbJLDr/lwfgO53E= +golang.org/x/exp v0.0.0-20220909182711-5c715a9e8561/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= 
-golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8= +golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/internal/chatinput.go b/internal/chatinput.go new file mode 100644 index 00000000..9f816ea2 --- /dev/null +++ b/internal/chatinput.go @@ -0,0 +1,474 @@ +package internal + +import ( + "fmt" + "strings" + "time" + + "github.com/charmbracelet/bubbletea" +) + +// UpdateHistoryMsg is used to update chat history +type UpdateHistoryMsg struct { + History []string +} + +// SetStatusMsg is used to set status +type SetStatusMsg struct { + Message string + Spinner bool +} + +// ApprovalRequestMsg is used to request approval for a command +type ApprovalRequestMsg struct { + Command string +} + +// ChatInputModel represents a persistent chat input interface +type ChatInputModel struct { + textarea []string + 
cursor int + lineIndex int + width int + height int + chatHistory []string + historyScroll int + focusOnHistory bool + statusMessage string + showSpinner bool + spinnerFrame int + inputSubmitted bool + lastInput string + startTime time.Time + showTimer bool + cancelled bool + quit bool + approvalPending bool + approvalCommand string + approvalResponse int // 0=deny, 1=allow, 2=allow all + approvalSelected int // Currently selected option in dropdown +} + +// SpinnerTick represents a spinner animation tick +type SpinnerTick struct{} + +// NewChatInputModel creates a new chat input model +func NewChatInputModel() *ChatInputModel { + return &ChatInputModel{ + textarea: []string{""}, + cursor: 0, + lineIndex: 0, + width: 80, + height: 20, + chatHistory: []string{}, + historyScroll: 0, + focusOnHistory: false, + statusMessage: "", + showSpinner: false, + spinnerFrame: 0, + inputSubmitted: false, + lastInput: "", + startTime: time.Now(), + showTimer: false, + cancelled: false, + quit: false, + approvalPending: false, + approvalCommand: "", + approvalResponse: -1, + approvalSelected: 0, + } +} + +func (m *ChatInputModel) Init() tea.Cmd { + return nil +} + +func (m *ChatInputModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case UpdateHistoryMsg: + m.chatHistory = msg.History + return m, nil + + case ApprovalRequestMsg: + m.approvalPending = true + m.approvalCommand = msg.Command + m.approvalResponse = -1 + m.approvalSelected = 0 // Start with first option selected + return m, nil + + case SetStatusMsg: + m.statusMessage = msg.Message + if msg.Spinner { + m.showSpinner = true + m.showTimer = true + m.startTime = time.Now() + m.spinnerFrame = 0 + return m, tea.Tick(time.Millisecond*100, func(t time.Time) tea.Msg { + return SpinnerTick{} + }) + } else { + m.showSpinner = false + m.showTimer = false + } + return m, nil + + case SpinnerTick: + if m.showSpinner { + m.spinnerFrame++ + return m, tea.Tick(time.Millisecond*100, func(t time.Time) 
tea.Msg { + return SpinnerTick{} + }) + } + return m, nil + + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + return m, nil + + case tea.KeyMsg: + if m.approvalPending { + switch msg.String() { + case "up": + if m.approvalSelected > 0 { + m.approvalSelected-- + } + return m, nil + case "down": + if m.approvalSelected < 2 { + m.approvalSelected++ + } + return m, nil + case "enter": + switch m.approvalSelected { + case 0: + m.approvalResponse = 1 // Allow + case 1: + m.approvalResponse = 2 // Allow all + case 2: + m.approvalResponse = 0 // Deny + } + m.approvalPending = false + return m, nil + case "esc", "ctrl+c": + m.approvalResponse = 0 + m.approvalPending = false + return m, nil + } + return m, nil + } + + switch msg.String() { + case "ctrl+c": + m.quit = true + return m, tea.Quit + + case "esc": + if m.showSpinner { + m.cancelled = true + m.showSpinner = false + m.showTimer = false + m.statusMessage = "āŒ Generation cancelled by user" + } + return m, nil + + case "ctrl+d": + if len(m.textarea) == 1 && m.textarea[0] == "" { + return m, nil + } + m.lastInput = strings.Join(m.textarea, "\n") + m.inputSubmitted = true + return m, nil + + case "tab": + m.focusOnHistory = !m.focusOnHistory + return m, nil + + case "enter": + if !m.focusOnHistory { + currentLine := m.textarea[m.lineIndex] + before := currentLine[:m.cursor] + after := currentLine[m.cursor:] + m.textarea[m.lineIndex] = before + " " + after + m.cursor++ + } + return m, nil + + case "backspace": + if !m.focusOnHistory { + if m.cursor > 0 { + currentLine := m.textarea[m.lineIndex] + before := currentLine[:m.cursor-1] + after := currentLine[m.cursor:] + m.textarea[m.lineIndex] = before + after + m.cursor-- + } + } + return m, nil + + case "up": + if m.focusOnHistory && len(m.chatHistory) > 0 { + if m.historyScroll > 0 { + m.historyScroll-- + } + } + return m, nil + + case "down": + if m.focusOnHistory && len(m.chatHistory) > 0 { + maxVisibleLines := m.getHistoryVisibleLines() + 
maxScroll := max(0, len(m.chatHistory)-maxVisibleLines) + if m.historyScroll < maxScroll { + m.historyScroll++ + } + } + return m, nil + + case "left": + if !m.focusOnHistory { + if m.cursor > 0 { + m.cursor-- + } + } + return m, nil + + case "right": + if !m.focusOnHistory { + currentText := strings.Join(m.textarea, " ") + if m.cursor < len(currentText) { + m.cursor++ + } + } + return m, nil + + default: + if !m.focusOnHistory && len(msg.String()) == 1 && msg.String()[0] >= 32 { + char := msg.String() + currentLine := m.textarea[m.lineIndex] + before := currentLine[:m.cursor] + after := currentLine[m.cursor:] + m.textarea[m.lineIndex] = before + char + after + m.cursor++ + } + return m, nil + } + } + + return m, nil +} + +func (m *ChatInputModel) getHistoryVisibleLines() int { + inputAreaHeight := 3 + statusAreaHeight := 3 + messagesHeight := m.height - inputAreaHeight - statusAreaHeight + return max(0, messagesHeight) +} + +func (m *ChatInputModel) View() string { + var b strings.Builder + + inputAreaHeight := 3 + statusAreaHeight := 3 + + messagesHeight := m.height - inputAreaHeight - statusAreaHeight + + if messagesHeight > 0 { + maxVisibleLines := messagesHeight + + var startIdx, endIdx int + + if !m.focusOnHistory { + if len(m.chatHistory) <= maxVisibleLines { + startIdx = 0 + endIdx = len(m.chatHistory) + } else { + startIdx = len(m.chatHistory) - maxVisibleLines + endIdx = len(m.chatHistory) + } + } else { + startIdx = m.historyScroll + endIdx = min(len(m.chatHistory), startIdx+maxVisibleLines) + } + + displayedLines := 0 + + linesShown := endIdx - startIdx + emptyLinesAtTop := maxVisibleLines - linesShown + for i := 0; i < emptyLinesAtTop; i++ { + b.WriteString("\n") + displayedLines++ + } + + for i := startIdx; i < endIdx && displayedLines < maxVisibleLines; i++ { + line := m.chatHistory[i] + b.WriteString(line + "\n") + displayedLines++ + } + + for displayedLines < maxVisibleLines { + b.WriteString("\n") + displayedLines++ + } + } + + b.WriteString("\n") 
+ + if m.approvalPending { + b.WriteString("āš ļø Command execution approval required:\n") + b.WriteString(fmt.Sprintf("Command: %s\n\n", m.approvalCommand)) + + options := []string{ + "Yes - Execute this command", + "Yes, and don't ask again - Execute this and all future commands", + "No - Cancel command execution", + } + + for i, option := range options { + if i == m.approvalSelected { + b.WriteString(fmt.Sprintf("ā–¶ \033[36;1m%s\033[0m\n", option)) + } else { + b.WriteString(fmt.Sprintf(" %s\n", option)) + } + } + + b.WriteString("\nUse ↑↓ arrows to navigate, Enter to select, Esc to cancel\n") + b.WriteString(strings.Repeat("─", m.width) + "\n") + } else { + statusLine := "" + if m.showSpinner { + spinner := []string{"ā ‹", "ā ™", "ā ¹", "ā ø", "ā ¼", "ā “", "ā ¦", "ā §", "ā ‡", "ā "} + spinnerChar := spinner[m.spinnerFrame%len(spinner)] + + if m.showTimer { + elapsed := time.Since(m.startTime) + seconds := elapsed.Seconds() + statusLine = fmt.Sprintf("%s %s (%.1fs) - Press Esc to cancel", spinnerChar, m.statusMessage, seconds) + } else { + statusLine = fmt.Sprintf("%s %s - Press Esc to cancel", spinnerChar, m.statusMessage) + } + } else if m.statusMessage != "" { + statusLine = m.statusMessage + } + + if statusLine != "" { + b.WriteString(fmt.Sprintf("ā„¹ļø %s\n", statusLine)) + } else { + b.WriteString("\n") + } + b.WriteString(strings.Repeat("─", m.width) + "\n") + } + + currentText := strings.Join(m.textarea, " ") + if len(currentText) > 0 { + if m.cursor <= len(currentText) { + before := currentText[:min(m.cursor, len(currentText))] + after := currentText[min(m.cursor, len(currentText)):] + b.WriteString(fmt.Sprintf("> %s│%s", before, after)) + } else { + b.WriteString(fmt.Sprintf("> %s│", currentText)) + } + } else { + b.WriteString("> │") + } + + b.WriteString("\n\n\033[90mPress Ctrl+D to send message, Ctrl+C to exit\033[0m") + + return b.String() +} + +// UpdateHistory updates the chat history +func (m *ChatInputModel) UpdateHistory(history []string) 
{ + m.chatHistory = history +} + +// SetStatusMessage sets a status message +func (m *ChatInputModel) SetStatusMessage(message string) { + m.statusMessage = message + m.showSpinner = false + m.showTimer = false +} + +// SetSpinnerMessage sets a status message with spinner +func (m *ChatInputModel) SetSpinnerMessage(message string) tea.Cmd { + m.statusMessage = message + m.showSpinner = true + m.showTimer = true + m.startTime = time.Now() + m.spinnerFrame = 0 + return tea.Tick(time.Millisecond*100, func(t time.Time) tea.Msg { + return SpinnerTick{} + }) +} + +// ClearStatus clears the status message +func (m *ChatInputModel) ClearStatus() { + m.statusMessage = "" + m.showSpinner = false + m.showTimer = false +} + +// HasInput returns true if there's input ready to be processed +func (m *ChatInputModel) HasInput() bool { + return m.inputSubmitted +} + +// GetInput returns the submitted input and clears the flag +func (m *ChatInputModel) GetInput() string { + if m.inputSubmitted { + input := m.lastInput + m.inputSubmitted = false + // Clear the textarea for next input + m.textarea = []string{""} + m.cursor = 0 + m.lineIndex = 0 + return input + } + return "" +} + +// IsCancelled returns true if generation was cancelled +func (m *ChatInputModel) IsCancelled() bool { + return m.cancelled +} + +// ResetCancellation resets the cancellation flag +func (m *ChatInputModel) ResetCancellation() { + m.cancelled = false +} + +// IsQuitRequested returns true if the user requested to quit +func (m *ChatInputModel) IsQuitRequested() bool { + return m.quit +} + +// IsApprovalPending returns true if approval is pending +func (m *ChatInputModel) IsApprovalPending() bool { + return m.approvalPending +} + +// GetApprovalResponse returns the approval response (-1=none, 0=deny, 1=allow, 2=allow all) +func (m *ChatInputModel) GetApprovalResponse() int { + return m.approvalResponse +} + +// ResetApproval resets the approval state +func (m *ChatInputModel) ResetApproval() { + 
m.approvalPending = false + m.approvalCommand = "" + m.approvalResponse = -1 + m.approvalSelected = 0 +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/internal/chatinput_test.go b/internal/chatinput_test.go new file mode 100644 index 00000000..420c3d60 --- /dev/null +++ b/internal/chatinput_test.go @@ -0,0 +1,437 @@ +package internal + +import ( + "strings" + "testing" + + "github.com/charmbracelet/bubbletea" +) + +func TestNewChatInputModel(t *testing.T) { + model := NewChatInputModel() + + if model == nil { + t.Fatal("NewChatInputModel() returned nil") + } + + if len(model.textarea) != 1 || model.textarea[0] != "" { + t.Errorf("Expected textarea to be initialized with empty string, got %v", model.textarea) + } + + if model.cursor != 0 { + t.Errorf("Expected cursor to be 0, got %d", model.cursor) + } + + if model.lineIndex != 0 { + t.Errorf("Expected lineIndex to be 0, got %d", model.lineIndex) + } + + if model.inputSubmitted { + t.Errorf("Expected inputSubmitted to be false initially") + } + + if model.cancelled { + t.Errorf("Expected cancelled to be false initially") + } + + if model.quit { + t.Errorf("Expected quit to be false initially") + } + + if model.approvalPending { + t.Errorf("Expected approvalPending to be false initially") + } +} + +func TestChatInputModel_UpdateHistory(t *testing.T) { + model := NewChatInputModel() + testHistory := []string{"User: Hello", "Assistant: Hi there!"} + + _, cmd := model.Update(UpdateHistoryMsg{History: testHistory}) + + if cmd != nil { + t.Errorf("Expected no command from UpdateHistoryMsg, got %v", cmd) + } + + if !equalStringSlices(model.chatHistory, testHistory) { + t.Errorf("Expected chatHistory to be %v, got %v", testHistory, model.chatHistory) + } +} + +func TestChatInputModel_SetStatus(t *testing.T) { + model := NewChatInputModel() + + tests := []struct { + name string + message string + spinner bool + 
expectSpinner bool + expectTimer bool + expectCommand bool + }{ + { + name: "status without spinner", + message: "Ready", + spinner: false, + expectSpinner: false, + expectTimer: false, + expectCommand: false, + }, + { + name: "status with spinner", + message: "Processing...", + spinner: true, + expectSpinner: true, + expectTimer: true, + expectCommand: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, cmd := model.Update(SetStatusMsg{Message: tt.message, Spinner: tt.spinner}) + + if model.statusMessage != tt.message { + t.Errorf("Expected statusMessage to be %q, got %q", tt.message, model.statusMessage) + } + + if model.showSpinner != tt.expectSpinner { + t.Errorf("Expected showSpinner to be %v, got %v", tt.expectSpinner, model.showSpinner) + } + + if model.showTimer != tt.expectTimer { + t.Errorf("Expected showTimer to be %v, got %v", tt.expectTimer, model.showTimer) + } + + if (cmd != nil) != tt.expectCommand { + t.Errorf("Expected command existence to be %v, got %v", tt.expectCommand, cmd != nil) + } + }) + } +} + +func TestChatInputModel_KeyHandling(t *testing.T) { + tests := []struct { + name string + initialText string + initialCursor int + key string + expectedText string + expectedCursor int + expectSubmitted bool + expectCancelled bool + expectQuit bool + }{ + { + name: "character input", + initialText: "", + initialCursor: 0, + key: "h", + expectedText: "h", + expectedCursor: 1, + }, + { + name: "backspace removes character", + initialText: "hello", + initialCursor: 4, + key: "backspace", + expectedText: "helo", + expectedCursor: 3, + }, + { + name: "backspace at start does nothing", + initialText: "hello", + initialCursor: 0, + key: "backspace", + expectedText: "hello", + expectedCursor: 0, + }, + { + name: "left arrow moves cursor", + initialText: "hello", + initialCursor: 5, + key: "left", + expectedText: "hello", + expectedCursor: 4, + }, + { + name: "left arrow at start does nothing", + initialText: "hello", + 
initialCursor: 0, + key: "left", + expectedText: "hello", + expectedCursor: 0, + }, + { + name: "ctrl+d submits input", + initialText: "hello", + initialCursor: 5, + key: "ctrl+d", + expectedText: "hello", + expectedCursor: 5, + expectSubmitted: true, + }, + { + name: "esc cancels generation", + initialText: "hello", + initialCursor: 5, + key: "esc", + expectedText: "hello", + expectedCursor: 5, + expectCancelled: true, + }, + { + name: "ctrl+c quits", + key: "ctrl+c", + expectQuit: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model := NewChatInputModel() + + if tt.initialText != "" { + model.textarea[0] = tt.initialText + } + model.cursor = tt.initialCursor + + if tt.key == "esc" { + model.showSpinner = true + } + + var cmd tea.Cmd + if tt.key == "backspace" { + _, cmd = model.Update(tea.KeyMsg{Type: tea.KeyBackspace}) + } else if tt.key == "left" { + _, cmd = model.Update(tea.KeyMsg{Type: tea.KeyLeft}) + } else if tt.key == "ctrl+d" { + _, cmd = model.Update(tea.KeyMsg{Type: tea.KeyCtrlD}) + } else if tt.key == "esc" { + _, cmd = model.Update(tea.KeyMsg{Type: tea.KeyEsc}) + } else if tt.key == "ctrl+c" { + _, cmd = model.Update(tea.KeyMsg{Type: tea.KeyCtrlC}) + } else if len(tt.key) == 1 { + _, cmd = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(tt.key)}) + } + + if tt.expectedText != "" && model.textarea[0] != tt.expectedText { + t.Errorf("Expected text to be %q, got %q", tt.expectedText, model.textarea[0]) + } + + if model.cursor != tt.expectedCursor { + t.Errorf("Expected cursor to be %d, got %d", tt.expectedCursor, model.cursor) + } + + if model.inputSubmitted != tt.expectSubmitted { + t.Errorf("Expected inputSubmitted to be %v, got %v", tt.expectSubmitted, model.inputSubmitted) + } + + if model.cancelled != tt.expectCancelled { + t.Errorf("Expected cancelled to be %v, got %v", tt.expectCancelled, model.cancelled) + } + + if tt.expectQuit { + if cmd == nil { + t.Errorf("Expected quit command, got nil") + } 
+ if model.quit != true { + t.Errorf("Expected quit to be true, got %v", model.quit) + } + } + }) + } +} + +func TestChatInputModel_ApprovalFlow(t *testing.T) { + model := NewChatInputModel() + + _, cmd := model.Update(ApprovalRequestMsg{Command: "rm -rf /"}) + + if cmd != nil { + t.Errorf("Expected no command from ApprovalRequestMsg, got %v", cmd) + } + + if !model.approvalPending { + t.Errorf("Expected approvalPending to be true") + } + + if model.approvalCommand != "rm -rf /" { + t.Errorf("Expected approvalCommand to be 'rm -rf /', got %q", model.approvalCommand) + } + + if model.approvalSelected != 0 { + t.Errorf("Expected approvalSelected to be 0, got %d", model.approvalSelected) + } + + _, _ = model.Update(tea.KeyMsg{Type: tea.KeyDown}) + if model.approvalSelected != 1 { + t.Errorf("Expected approvalSelected to be 1 after down arrow, got %d", model.approvalSelected) + } + + _, _ = model.Update(tea.KeyMsg{Type: tea.KeyUp}) + if model.approvalSelected != 0 { + t.Errorf("Expected approvalSelected to be 0 after up arrow, got %d", model.approvalSelected) + } + + _, _ = model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + + if model.approvalPending { + t.Errorf("Expected approvalPending to be false after selection") + } + + if model.approvalResponse != 1 { // First option is "Allow" + t.Errorf("Expected approvalResponse to be 1, got %d", model.approvalResponse) + } +} + +func TestChatInputModel_InputOutput(t *testing.T) { + model := NewChatInputModel() + + // Initially no input + if model.HasInput() { + t.Errorf("Expected HasInput() to be false initially") + } + + if model.GetInput() != "" { + t.Errorf("Expected GetInput() to return empty string initially") + } + + model.textarea[0] = "test input" + model.cursor = len("test input") + _, _ = model.Update(tea.KeyMsg{Type: tea.KeyCtrlD}) + + if !model.HasInput() { + t.Errorf("Expected HasInput() to be true after submission") + } + + input := model.GetInput() + if input != "test input" { + t.Errorf("Expected GetInput() to 
return 'test input', got %q", input) + } + + if model.HasInput() { + t.Errorf("Expected HasInput() to be false after GetInput() call") + } + + if model.textarea[0] != "" { + t.Errorf("Expected textarea to be cleared after GetInput(), got %q", model.textarea[0]) + } + + if model.cursor != 0 { + t.Errorf("Expected cursor to be reset to 0, got %d", model.cursor) + } +} + +func TestChatInputModel_CancellationFlow(t *testing.T) { + model := NewChatInputModel() + + if model.IsCancelled() { + t.Errorf("Expected IsCancelled() to be false initially") + } + + model.showSpinner = true + _, _ = model.Update(tea.KeyMsg{Type: tea.KeyEsc}) + + if !model.IsCancelled() { + t.Errorf("Expected IsCancelled() to be true after Esc during spinner") + } + + model.ResetCancellation() + + if model.IsCancelled() { + t.Errorf("Expected IsCancelled() to be false after ResetCancellation()") + } +} + +func TestChatInputModel_View(t *testing.T) { + model := NewChatInputModel() + model.width = 80 + model.height = 20 + + view := model.View() + + if view == "" { + t.Errorf("Expected non-empty view") + } + + if !strings.Contains(view, "> │") { + t.Errorf("Expected view to contain input prompt '> │'") + } + + if !strings.Contains(view, "Ctrl+D") { + t.Errorf("Expected view to contain help text about Ctrl+D") + } + + model.statusMessage = "Test status" + view = model.View() + + if !strings.Contains(view, "Test status") { + t.Errorf("Expected view to contain status message") + } + + model.approvalPending = true + model.approvalCommand = "test command" + view = model.View() + + if !strings.Contains(view, "test command") { + t.Errorf("Expected view to contain approval command") + } + + if !strings.Contains(view, "Yes - Execute") { + t.Errorf("Expected view to contain approval options") + } +} + +func TestChatInputModel_WindowResize(t *testing.T) { + model := NewChatInputModel() + + _, cmd := model.Update(tea.WindowSizeMsg{Width: 120, Height: 30}) + + if cmd != nil { + t.Errorf("Expected no command from 
WindowSizeMsg, got %v", cmd) + } + + if model.width != 120 { + t.Errorf("Expected width to be 120, got %d", model.width) + } + + if model.height != 30 { + t.Errorf("Expected height to be 30, got %d", model.height) + } +} + +func TestChatInputModel_SpinnerTick(t *testing.T) { + model := NewChatInputModel() + model.showSpinner = true + + _, cmd := model.Update(SpinnerTick{}) + + if cmd == nil { + t.Errorf("Expected command to continue spinner ticking") + } + + if model.spinnerFrame == 0 { + t.Errorf("Expected spinnerFrame to be incremented") + } + + model.showSpinner = false + _, cmd = model.Update(SpinnerTick{}) + + if cmd != nil { + t.Errorf("Expected no command when spinner is off, got %v", cmd) + } +} + +// Helper function to compare string slices +func equalStringSlices(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/internal/helpviewer.go b/internal/helpviewer.go new file mode 100644 index 00000000..7f37a4de --- /dev/null +++ b/internal/helpviewer.go @@ -0,0 +1,131 @@ +package internal + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/bubbletea" +) + +// HelpViewerModel represents a professional help display interface +type HelpViewerModel struct { + width int + height int + done bool +} + +// NewHelpViewerModel creates a new help viewer +func NewHelpViewerModel() *HelpViewerModel { + return &HelpViewerModel{ + width: 80, + height: 20, + done: false, + } +} + +func (m *HelpViewerModel) Init() tea.Cmd { + return nil +} + +func (m *HelpViewerModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + return m, nil + + case tea.KeyMsg: + // Any key press closes the help + m.done = true + return m, tea.Quit + } + + return m, nil +} + +func (m *HelpViewerModel) View() string { + var b strings.Builder + + // Center the content + maxWidth := min(m.width, 
80) + padding := (m.width - maxWidth) / 2 + if padding < 0 { + padding = 0 + } + + // Title + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString("šŸ’¬ \033[1;36mChat Session Help\033[0m\n\n") + + // Commands section + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString("šŸ”§ \033[1mCommands:\033[0m\n") + commands := [][]string{ + {"/exit, /quit", "Exit chat session"}, + {"/clear", "Clear conversation history"}, + {"/history", "Show conversation history"}, + {"/models", "Show current and available models"}, + {"/switch", "Switch to a different model"}, + {"/compact", "Export conversation to markdown file"}, + {"/help, ?", "Show this help"}, + } + + for _, cmd := range commands { + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(fmt.Sprintf(" \033[36m%-15s\033[0m - %s\n", cmd[0], cmd[1])) + } + b.WriteString("\n") + + // File references section + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString("šŸ“ \033[1mFile References:\033[0m\n") + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(" \033[36m@filename.txt\033[0m - Include file contents in your message\n") + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(" \033[36m@./config.yaml\033[0m - Include contents from current directory\n") + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(" \033[36m@../README.md\033[0m - Include contents from parent directory\n") + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(" \033[90mMaximum file size: 100KB\033[0m\n\n") + + // Tools section + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString("šŸ”§ \033[1mTools:\033[0m\n") + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(" Models can invoke available tools automatically during conversation\n") + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(" Use \033[36minfer tools list\033[0m to see whitelisted commands\n") + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(" Use \033[36minfer tools 
enable/disable\033[0m to control tool access\n\n") + + // Input tips section + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString("šŸ’” \033[1mInput Tips:\033[0m\n") + inputTips := [][]string{ + {"Ctrl+D", "Send message"}, + {"Ctrl+C", "Quit the chat session"}, + {"Esc", "Cancel generation while model is responding"}, + {"Tab", "Scroll through chat history"}, + {"↑↓", "Navigate text and history"}, + {"Enter", "New line in multi-line input"}, + } + + for _, tip := range inputTips { + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(fmt.Sprintf(" \033[36m%-8s\033[0m - %s\n", tip[0], tip[1])) + } + b.WriteString("\n") + + // Footer with instructions + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString(strings.Repeat("─", min(maxWidth, 60)) + "\n") + b.WriteString(strings.Repeat(" ", padding)) + b.WriteString("šŸ’” \033[90;1mPress any key to return to chat\033[0m") + + return b.String() +} + +// IsDone returns true if help viewing is complete +func (m *HelpViewerModel) IsDone() bool { + return m.done +} diff --git a/internal/llm_tools.go b/internal/llm_tools.go index 20218ce6..c17b401c 100644 --- a/internal/llm_tools.go +++ b/internal/llm_tools.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" + "github.com/charmbracelet/bubbletea" "github.com/inference-gateway/cli/config" ) @@ -19,6 +20,13 @@ func NewLLMToolsManager(cfg *config.Config) *LLMToolsManager { } } + +// NewLLMToolsManagerWithUI creates a new LLM tools manager with UI integration +func NewLLMToolsManagerWithUI(cfg *config.Config, program *tea.Program, inputModel *ChatInputModel) *LLMToolsManager { + return &LLMToolsManager{ + toolEngine: NewToolEngineWithUI(cfg, program, inputModel), + } +} + // BashTool represents the Bash tool that LLMs can invoke type BashTool struct { manager *LLMToolsManager diff --git a/internal/modelselector.go b/internal/modelselector.go new file mode 100644 index 00000000..a6787036 --- /dev/null +++ b/internal/modelselector.go @@ -0,0 +1,202 @@ +package
internal + +import ( + "fmt" + "strings" + + "github.com/charmbracelet/bubbletea" +) + +// ModelSelectorModel represents a professional model selection interface +type ModelSelectorModel struct { + models []string + filteredModels []string + cursor int + searchQuery string + width int + height int + selected string + cancelled bool + done bool +} + +// NewModelSelectorModel creates a new model selector +func NewModelSelectorModel(models []string) *ModelSelectorModel { + return &ModelSelectorModel{ + models: models, + filteredModels: models, + cursor: 0, + searchQuery: "", + width: 80, + height: 20, + selected: "", + cancelled: false, + done: false, + } +} + +func (m *ModelSelectorModel) Init() tea.Cmd { + return nil +} + +func (m *ModelSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + return m, nil + + case tea.KeyMsg: + switch msg.String() { + case "ctrl+c", "esc": + m.cancelled = true + m.done = true + return m, tea.Quit + + case "enter": + if len(m.filteredModels) > 0 && m.cursor < len(m.filteredModels) { + m.selected = m.filteredModels[m.cursor] + m.done = true + return m, tea.Quit + } + return m, nil + + case "up": + if m.cursor > 0 { + m.cursor-- + } + return m, nil + + case "down": + if m.cursor < len(m.filteredModels)-1 { + m.cursor++ + } + return m, nil + + case "backspace": + if len(m.searchQuery) > 0 { + m.searchQuery = m.searchQuery[:len(m.searchQuery)-1] + m.filterModels() + m.adjustCursor() + } + return m, nil + + default: + if len(msg.String()) == 1 && msg.String()[0] >= 32 && msg.String()[0] <= 126 { + m.searchQuery += msg.String() + m.filterModels() + m.adjustCursor() + } + return m, nil + } + } + + return m, nil +} + +func (m *ModelSelectorModel) filterModels() { + if m.searchQuery == "" { + m.filteredModels = m.models + return + } + + var filtered []string + query := strings.ToLower(m.searchQuery) + + for _, model := range m.models { + 
modelLower := strings.ToLower(model) + if strings.Contains(modelLower, query) { + filtered = append(filtered, model) + } + } + + m.filteredModels = filtered +} + +func (m *ModelSelectorModel) adjustCursor() { + if m.cursor >= len(m.filteredModels) { + if len(m.filteredModels) > 0 { + m.cursor = len(m.filteredModels) - 1 + } else { + m.cursor = 0 + } + } +} + +func (m *ModelSelectorModel) View() string { + var b strings.Builder + + b.WriteString("šŸ¤– Select a model for the chat session\n\n") + + searchBox := fmt.Sprintf("šŸ” Search: %s", m.searchQuery) + if len(m.searchQuery) == 0 { + searchBox += "│" + } + b.WriteString(searchBox + "\n") + b.WriteString(strings.Repeat("─", min(m.width, 60)) + "\n\n") + + if len(m.filteredModels) != len(m.models) { + b.WriteString(fmt.Sprintf("Showing %d of %d models\n\n", len(m.filteredModels), len(m.models))) + } + + if len(m.filteredModels) == 0 { + b.WriteString("āŒ No models match your search\n") + } else { + maxVisible := min(10, m.height-8) + startIdx := 0 + endIdx := len(m.filteredModels) + + if len(m.filteredModels) > maxVisible { + if m.cursor >= maxVisible/2 { + startIdx = min(m.cursor-maxVisible/2, len(m.filteredModels)-maxVisible) + endIdx = startIdx + maxVisible + } else { + endIdx = maxVisible + } + } + + for i := startIdx; i < endIdx && i < len(m.filteredModels); i++ { + model := m.filteredModels[i] + if i == m.cursor { + b.WriteString(fmt.Sprintf("ā–¶ \033[36;1m%s\033[0m\n", model)) + } else { + b.WriteString(fmt.Sprintf(" %s\n", model)) + } + } + + if len(m.filteredModels) > maxVisible { + if startIdx > 0 { + b.WriteString("\n ↑ More models above\n") + } + if endIdx < len(m.filteredModels) { + b.WriteString(" ↓ More models below\n") + } + } + } + + b.WriteString("\n") + b.WriteString(strings.Repeat("─", min(m.width, 60)) + "\n") + b.WriteString("šŸ’” \033[90mType to search • ↑↓ Navigate • Enter Select • Esc Cancel\033[0m") + + return b.String() +} + +// IsSelected returns true if a model was selected +func (m 
*ModelSelectorModel) IsSelected() bool { + return m.done && !m.cancelled && m.selected != "" +} + +// IsCancelled returns true if selection was cancelled +func (m *ModelSelectorModel) IsCancelled() bool { + return m.cancelled +} + +// GetSelected returns the selected model +func (m *ModelSelectorModel) GetSelected() string { + return m.selected +} + +// IsDone returns true if selection process is complete +func (m *ModelSelectorModel) IsDone() bool { + return m.done +} diff --git a/internal/safety.go b/internal/safety.go index 1c18b7bd..41fcab98 100644 --- a/internal/safety.go +++ b/internal/safety.go @@ -2,8 +2,9 @@ package internal import ( "fmt" + "time" - "github.com/manifoldco/promptui" + "github.com/charmbracelet/bubbletea" ) // ApprovalSession tracks approval decisions for a session @@ -41,47 +42,35 @@ func (ad ApprovalDecision) String() string { } } -// PromptForApproval prompts the user for command execution approval -func (as *ApprovalSession) PromptForApproval(command string) (ApprovalDecision, error) { +// PromptForApprovalBubbleTea prompts the user for command execution approval using Bubble Tea +func (as *ApprovalSession) PromptForApprovalBubbleTea(command string, program *tea.Program, inputModel *ChatInputModel) (ApprovalDecision, error) { if as.skipApproval { return ApprovalAllow, nil } - fmt.Printf("\nāš ļø Command execution approval required:\n") - fmt.Printf("Command: %s\n\n", command) + // Send approval request to the chat interface + program.Send(ApprovalRequestMsg{Command: command}) - options := []string{ - "Yes - Execute this command", - "Yes, and don't ask again - Execute this and all future commands", - "No - Cancel command execution", - } - - prompt := promptui.Select{ - Label: "Please select an option", - Items: options, - Templates: &promptui.SelectTemplates{ - Label: "{{ . }}:", - Active: "ā–¶ {{ . | cyan | bold }}", - Inactive: " {{ . }}", - Selected: "āœ“ {{ . 
| green }}", - }, - } + // Wait for user response + for { + time.Sleep(50 * time.Millisecond) - index, _, err := prompt.Run() - if err != nil { - return ApprovalDeny, fmt.Errorf("selection failed: %w", err) - } + if !inputModel.IsApprovalPending() { + response := inputModel.GetApprovalResponse() + inputModel.ResetApproval() - switch index { - case 0: - return ApprovalAllow, nil - case 1: - as.skipApproval = true - return ApprovalAllowAll, nil - case 2: - return ApprovalDeny, nil - default: - return ApprovalDeny, fmt.Errorf("invalid selection") + switch response { + case 1: + return ApprovalAllow, nil + case 2: + as.skipApproval = true + return ApprovalAllowAll, nil + case 0: + return ApprovalDeny, nil + default: + return ApprovalDeny, fmt.Errorf("invalid or cancelled selection") + } + } } } diff --git a/internal/tools.go b/internal/tools.go index c2158b4e..194469b5 100644 --- a/internal/tools.go +++ b/internal/tools.go @@ -8,6 +8,7 @@ import ( "strings" "time" + "github.com/charmbracelet/bubbletea" "github.com/inference-gateway/cli/config" ) @@ -15,6 +16,8 @@ import ( type ToolEngine struct { config *config.Config approvalSession *ApprovalSession + program *tea.Program + inputModel *ChatInputModel } // NewToolEngine creates a new tool engine @@ -22,6 +25,18 @@ func NewToolEngine(cfg *config.Config) *ToolEngine { return &ToolEngine{ config: cfg, approvalSession: NewApprovalSession(), + program: nil, + inputModel: nil, + } +} + +// NewToolEngineWithUI creates a new tool engine with UI integration +func NewToolEngineWithUI(cfg *config.Config, program *tea.Program, inputModel *ChatInputModel) *ToolEngine { + return &ToolEngine{ + config: cfg, + approvalSession: NewApprovalSession(), + program: program, + inputModel: inputModel, } } @@ -45,7 +60,17 @@ func (te *ToolEngine) ExecuteBash(command string) (*ToolResult, error) { } if te.config.Tools.Safety.RequireApproval { - decision, err := te.approvalSession.PromptForApproval(command) + var decision ApprovalDecision + var 
err error + + if te.program != nil && te.inputModel != nil { + // Use Bubble Tea UI for approval + decision, err = te.approvalSession.PromptForApprovalBubbleTea(command, te.program, te.inputModel) + } else { + // No console fallback exists: without the approval UI we refuse to run the command rather than skip approval + return nil, fmt.Errorf("approval UI not available") + } + if err != nil { return nil, fmt.Errorf("approval prompt failed: %w", err) }