diff --git a/cmd/prompt.go b/cmd/prompt.go new file mode 100644 index 00000000..fc19fe10 --- /dev/null +++ b/cmd/prompt.go @@ -0,0 +1,336 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/inference-gateway/cli/config" + "github.com/inference-gateway/cli/internal/container" + "github.com/inference-gateway/cli/internal/domain" + "github.com/inference-gateway/cli/internal/logger" + "github.com/inference-gateway/cli/internal/ui" + sdk "github.com/inference-gateway/sdk" + "github.com/spf13/cobra" +) + +var promptCmd = &cobra.Command{ + Use: "prompt [prompt_text]", + Short: "Execute a one-off prompt in background mode", + Long: `Execute a one-off prompt that runs in background mode until the task is complete. +This command can work with URLs (including GitHub issues) using the Fetch tool. + +Examples: + infer prompt "Please analyze https://github.com/owner/repo/issues/123" + infer prompt "Help me understand this issue: https://github.com/owner/repo/issues/456" + infer prompt "Optimize the database queries in the user service"`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + promptText := args[0] + return executeBackgroundPrompt(promptText) + }, +} + +// BackgroundExecutor handles background execution of prompts +type BackgroundExecutor struct { + services *container.ServiceContainer + maxIterations int +} + +// executeBackgroundPrompt executes a prompt in background mode +func executeBackgroundPrompt(promptText string) error { + cfg, err := config.LoadConfig("") + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + serviceContainer := container.NewServiceContainer(cfg) + + executor := &BackgroundExecutor{ + services: serviceContainer, + maxIterations: 10, + } + + logger.Info("background_execution_starting", "prompt_text", promptText) + return executor.Execute(promptText) +} + +// Execute runs the background prompt execution +func (e *BackgroundExecutor) Execute(promptText string) error { + ctx := context.Background() + + model, err := e.selectModelRobust(ctx) + if err != nil { + return fmt.Errorf("failed to select model: %w", err) + } + + logger.Info("model_selected", "model", model) + + // Use system prompt from config + cfg := e.services.GetConfig() + systemPrompt := cfg.Chat.SystemPrompt + + logger.Debug("background_execution_started", + "model", model, + "prompt_text", promptText, + "system_prompt", systemPrompt) + + return e.executeIteratively(ctx, model, systemPrompt, promptText) +} + + + +// selectModelRobust selects the configured default model only +func (e *BackgroundExecutor) selectModelRobust(ctx context.Context) (string, error) { + cfg := e.services.GetConfig() + + if cfg.Chat.DefaultModel == "" { + return "", fmt.Errorf("no default model configured in .infer/config.yaml") + } + + if err := e.services.GetModelService().SelectModel(cfg.Chat.DefaultModel); err != nil { + return "", fmt.Errorf("failed to select configured default model '%s': %w", cfg.Chat.DefaultModel, err) + } + + return cfg.Chat.DefaultModel, nil +} + + +// sendMessageDirectWithToolCalls sends a message and returns both content and tool calls +func (e *BackgroundExecutor) sendMessageDirectWithToolCalls(ctx context.Context, model string, messages []sdk.Message) (string, []sdk.ChatCompletionMessageToolCall, error) { + parts := strings.SplitN(model, "/", 2) + if len(parts) != 2 { + return "", nil, fmt.Errorf("invalid model format, expected 'provider/model'") + } + provider := parts[0] + modelName := parts[1] + + cfg := e.services.GetConfig() + client := sdk.NewClient(&sdk.ClientOptions{ + BaseURL: strings.TrimSuffix(cfg.Gateway.URL, "/") + "/v1", + APIKey: cfg.Gateway.APIKey, + }) + + messages = e.addToolsIfAvailable(messages) + + providerType := sdk.Provider(provider) + response, err := client.GenerateContent(ctx, providerType, modelName, messages) + if err != nil { + return "", nil, fmt.Errorf("failed to generate content: %w", err) + } + + if len(response.Choices) == 0 { + return "", nil, fmt.Errorf("no choices in response") + } + + choice := response.Choices[0] + content := choice.Message.Content + var toolCalls []sdk.ChatCompletionMessageToolCall + + if choice.Message.ToolCalls != nil { + toolCalls = *choice.Message.ToolCalls + } + + return content, toolCalls, nil +} + +// addToolsIfAvailable adds tools to messages if tool service is available +func (e *BackgroundExecutor) addToolsIfAvailable(messages []sdk.Message) []sdk.Message { + toolService := e.services.GetToolService() + if toolService == nil { + return messages + } + + availableTools := toolService.ListTools() + if len(availableTools) == 0 { + return messages + } + + toolsMessage := e.createToolsSystemMessage(availableTools) + + var result []sdk.Message + systemAdded := false + + for _, msg := range messages { + if msg.Role == sdk.System && !systemAdded { + result = append(result, msg, toolsMessage) + systemAdded = true + } else { + result = append(result, msg) + } + } + + if !systemAdded { + result = append([]sdk.Message{toolsMessage}, result...) + } + + return result +} + +// createToolsSystemMessage creates a system message describing available tools +func (e *BackgroundExecutor) createToolsSystemMessage(tools []domain.ToolDefinition) sdk.Message { + content := "You have access to the following tools:\n\n" + + for _, tool := range tools { + content += fmt.Sprintf("- **%s**: %s\n", tool.Name, tool.Description) + } + + content += "\nTo use a tool, respond with a tool call using the proper format. The system will execute the tool and provide you with the results." + + return sdk.Message{ + Role: sdk.System, + Content: content, + } +} + +// executeIteratively executes the prompt iteratively until completion +func (e *BackgroundExecutor) executeIteratively(ctx context.Context, model, systemPrompt, promptText string) error { + messages := []sdk.Message{ + {Role: sdk.System, Content: systemPrompt}, + {Role: sdk.User, Content: promptText}, + } + + for iteration := 1; iteration <= e.maxIterations; iteration++ { + logger.Info("iteration_starting", "iteration", iteration, "max_iterations", e.maxIterations) + + logger.Debug("sending_message_to_model", + "iteration", iteration, + "model", model, + "message_count", len(messages), + "messages", messages) + + response, toolCalls, err := e.sendMessageDirectWithToolCalls(ctx, model, messages) + if err != nil { + logger.Error("failed_to_send_message", "error", err, "model", model) + return fmt.Errorf("failed to send message: %w", err) + } + + logger.Debug("received_assistant_response", + "iteration", iteration, + "response_length", len(response), + "tool_calls_count", len(toolCalls)) + + logger.Info("assistant_response", "iteration", iteration, "content", response) + + assistantMsg := sdk.Message{ + Role: sdk.Assistant, + Content: response, + } + + if len(toolCalls) > 0 { + assistantMsg.ToolCalls = &toolCalls + } + + messages = append(messages, assistantMsg) + + if len(toolCalls) > 0 { + logger.Info("processing_tool_calls", "count", len(toolCalls), "iteration", iteration) + + toolResultsProcessed := false + for _, toolCall := range toolCalls { + logger.Info("executing_tool", "tool_name", toolCall.Function.Name, "iteration", iteration) + + toolResult, err := e.executeToolCall(ctx, toolCall) + if err != nil { + logger.Error("tool_execution_failed", "tool_name", toolCall.Function.Name, "error", err) + toolResult = fmt.Sprintf("Tool execution failed: %v", err) + } + + toolResultMsg := sdk.Message{ + Role: sdk.Tool, + Content: toolResult, + ToolCallId: &toolCall.Id, + } + messages = append(messages, toolResultMsg) + toolResultsProcessed = true + + logger.Info("tool_result", "tool_name", toolCall.Function.Name, "result", toolResult) + } + + if toolResultsProcessed { + continue + } + } + + if e.isTaskCompleted(response) { + logger.Info("task_completed", "iteration", iteration) + return nil + } + + followUpPrompt := e.generateFollowUpPrompt(response, iteration) + logger.Debug("generated_follow_up_prompt", + "iteration", iteration, + "follow_up_prompt", followUpPrompt) + + messages = append(messages, sdk.Message{ + Role: sdk.User, + Content: followUpPrompt, + }) + } + + logger.Warn("max_iterations_reached", "max_iterations", e.maxIterations) + return nil +} + +// executeToolCall executes a single tool call and returns the result +func (e *BackgroundExecutor) executeToolCall(ctx context.Context, toolCall sdk.ChatCompletionMessageToolCall) (string, error) { + toolService := e.services.GetToolService() + if toolService == nil { + return "", fmt.Errorf("tool service not available") + } + + var args map[string]interface{} + if toolCall.Function.Arguments != "" { + if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &args); err != nil { + return "", fmt.Errorf("failed to parse tool arguments: %w", err) + } + } + + result, err := toolService.ExecuteTool(ctx, toolCall.Function.Name, args) + if err != nil { + return "", fmt.Errorf("tool execution failed: %w", err) + } + + return ui.FormatToolResultForLLM(result), nil +} + + +// isTaskCompleted checks if the task appears to be completed based on the response +func (e *BackgroundExecutor) isTaskCompleted(response string) bool { + completionIndicators := []string{ + "task completed", + "solution implemented", + "issue resolved", + "implementation complete", + "problem solved", + "finished", + "done", + } + + responseLower := strings.ToLower(response) + for _, indicator := range completionIndicators { + if strings.Contains(responseLower, indicator) { + return true + } + } + + return false +} + +// generateFollowUpPrompt generates a follow-up prompt to continue the task +func (e *BackgroundExecutor) generateFollowUpPrompt(response string, iteration int) string { + prompts := []string{ + "Please continue with the next steps to complete this task.", + "What additional work is needed to fully resolve this issue?", + "Are there any remaining steps or considerations for this task?", + "Please provide any additional implementation details or next steps.", + } + + promptIndex := (iteration - 1) % len(prompts) + return prompts[promptIndex] +} + +func init() { + rootCmd.AddCommand(promptCmd) +} diff --git a/go.mod b/go.mod index 8f9d8b73..1cf95bd7 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,8 @@ require ( github.com/charmbracelet/bubbles v0.21.0 github.com/charmbracelet/bubbletea v1.3.6 github.com/charmbracelet/lipgloss v1.1.0 - github.com/inference-gateway/sdk v1.11.1 + github.com/go-resty/resty/v2 v2.16.5 + github.com/inference-gateway/sdk v1.11.0 github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 github.com/muesli/reflow v0.3.0 github.com/spf13/cobra v1.9.1 @@ -26,7 +27,6 @@ require ( github.com/charmbracelet/x/term v0.2.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect - github.com/go-resty/resty/v2 v2.16.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kr/pretty v0.2.1 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect diff --git a/go.sum b/go.sum index e450ae83..d2e944a4 100644 --- a/go.sum +++ b/go.sum @@ -21,14 +21,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4= github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM= -github.com/go-resty/resty/v2 v2.16.3 h1:zacNT7lt4b8M/io2Ahj6yPypL7bqx9n1iprfQuodV+E= -github.com/go-resty/resty/v2 v2.16.3/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ07xAwp/fiA= +github.com/go-resty/resty/v2 v2.16.5 h1:hBKqmWrr7uRc3euHVqmh1HTHcKn99Smr7o5spptdhTM= +github.com/go-resty/resty/v2 v2.16.5/go.mod h1:hkJtXbA2iKHzJheXYvQ8snQES5ZLGKMwQ07xAwp/fiA= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/inference-gateway/sdk v1.11.1 h1:3MS6isTvfy8efZky1BmaDfmeltUn5m9krdKORuV3nU0= -github.com/inference-gateway/sdk v1.11.1/go.mod h1:3TTD7Kbr7FRt+9ZbCPAm3u0tXUIWx7flZuwrRgZgrdk= +github.com/inference-gateway/sdk v1.11.0 h1:eeq/VE8/2m+kFajwXGOFnDNvskkyfAwFZDxOLiIEv2A= +github.com/inference-gateway/sdk v1.11.0/go.mod h1:3TTD7Kbr7FRt+9ZbCPAm3u0tXUIWx7flZuwrRgZgrdk= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= diff --git a/internal/services/github.go b/internal/services/github.go new file mode 100644 index 00000000..84a6ba29 --- /dev/null +++ b/internal/services/github.go @@ -0,0 +1,192 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/go-resty/resty/v2" +) + +// GitHubIssue represents a GitHub issue +type GitHubIssue struct { + Number int `json:"number"` + Title string `json:"title"` + Body string `json:"body"` + State string `json:"state"` + HTMLURL string `json:"html_url"` + User User `json:"user"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Labels []Label `json:"labels"` + Comments int `json:"comments"` + Repository string `json:"-"` // Will be set manually +} + +// User represents a GitHub user +type User struct { + Login string `json:"login"` +} + +// Label represents a GitHub label +type Label struct { + Name string `json:"name"` + Color string `json:"color"` +} + +// GitHubService handles GitHub API operations +type GitHubService struct { + client *resty.Client +} + +// NewGitHubService creates a new GitHub service +func NewGitHubService() *GitHubService { + client := resty.New(). + SetTimeout(30 * time.Second). + SetRetryCount(3). + SetRetryWaitTime(1 * time.Second). + SetRetryMaxWaitTime(5 * time.Second) + + return &GitHubService{ + client: client, + } +} + +// FetchIssue fetches a GitHub issue by repository and issue number +func (g *GitHubService) FetchIssue(ctx context.Context, repository string, issueNumber int) (*GitHubIssue, error) { + if repository == "" { + return nil, fmt.Errorf("repository cannot be empty") + } + if issueNumber <= 0 { + return nil, fmt.Errorf("issue number must be positive") + } + + url := fmt.Sprintf("https://api.github.com/repos/%s/issues/%d", repository, issueNumber) + + resp, err := g.client.R(). + SetContext(ctx). + SetHeader("Accept", "application/vnd.github.v3+json"). + SetHeader("User-Agent", "inference-gateway-cli/1.0"). + Get(url) + + if err != nil { + return nil, fmt.Errorf("failed to fetch issue: %w", err) + } + + if resp.StatusCode() == http.StatusNotFound { + return nil, fmt.Errorf("issue #%d not found in repository %s", issueNumber, repository) + } + + if resp.StatusCode() != http.StatusOK { + return nil, fmt.Errorf("GitHub API returned status %d: %s", resp.StatusCode(), resp.String()) + } + + var issue GitHubIssue + if err := json.Unmarshal(resp.Body(), &issue); err != nil { + return nil, fmt.Errorf("failed to parse issue response: %w", err) + } + + issue.Repository = repository + return &issue, nil +} + +// ParseIssueReference parses various GitHub issue reference formats +func (g *GitHubService) ParseIssueReference(reference string) (repository string, issueNumber int, err error) { + reference = strings.TrimSpace(reference) + + if reference == "" { + return "", 0, fmt.Errorf("issue reference cannot be empty") + } + + // Handle different formats: + // 1. "123" - just number (requires repository context) + // 2. "owner/repo#123" - full reference + // 3. "https://github.com/owner/repo/issues/123" - full URL + // 4. "#123" - number with hash (requires repository context) + + // Handle URL format + if strings.HasPrefix(reference, "https://github.com/") { + return g.parseFromURL(reference) + } + + // Handle owner/repo#123 format + if strings.Contains(reference, "/") && strings.Contains(reference, "#") { + parts := strings.Split(reference, "#") + if len(parts) != 2 { + return "", 0, fmt.Errorf("invalid issue reference format: %s", reference) + } + + repository = parts[0] + number, err := strconv.Atoi(parts[1]) + if err != nil { + return "", 0, fmt.Errorf("invalid issue number: %s", parts[1]) + } + + return repository, number, nil + } + + // Handle #123 or 123 format (number only) + numberStr := strings.TrimPrefix(reference, "#") + number, err := strconv.Atoi(numberStr) + if err != nil { + return "", 0, fmt.Errorf("invalid issue number: %s", numberStr) + } + + // Return empty repository - caller needs to provide context + return "", number, nil +} + +// parseFromURL parses issue information from a GitHub URL +func (g *GitHubService) parseFromURL(url string) (repository string, issueNumber int, err error) { + // Expected format: https://github.com/owner/repo/issues/123 + url = strings.TrimPrefix(url, "https://github.com/") + parts := strings.Split(url, "/") + + if len(parts) < 4 || parts[2] != "issues" { + return "", 0, fmt.Errorf("invalid GitHub issue URL format") + } + + repository = fmt.Sprintf("%s/%s", parts[0], parts[1]) + number, err := strconv.Atoi(parts[3]) + if err != nil { + return "", 0, fmt.Errorf("invalid issue number in URL: %s", parts[3]) + } + + return repository, number, nil +} + +// FormatIssueForPrompt formats a GitHub issue into a prompt-friendly string +func (g *GitHubService) FormatIssueForPrompt(issue *GitHubIssue) string { + var builder strings.Builder + + builder.WriteString(fmt.Sprintf("# GitHub Issue #%d: %s\n\n", issue.Number, issue.Title)) + builder.WriteString(fmt.Sprintf("**Repository:** %s\n", issue.Repository)) + builder.WriteString(fmt.Sprintf("**Status:** %s\n", issue.State)) + builder.WriteString(fmt.Sprintf("**Created by:** %s\n", issue.User.Login)) + builder.WriteString(fmt.Sprintf("**URL:** %s\n\n", issue.HTMLURL)) + + if len(issue.Labels) > 0 { + builder.WriteString("**Labels:** ") + for i, label := range issue.Labels { + if i > 0 { + builder.WriteString(", ") + } + builder.WriteString(label.Name) + } + builder.WriteString("\n\n") + } + + builder.WriteString("## Description\n\n") + if issue.Body != "" { + builder.WriteString(issue.Body) + } else { + builder.WriteString("*No description provided*") + } + builder.WriteString("\n") + + return builder.String() +}