diff --git a/.infer/config.yaml b/.infer/config.yaml index 7a008561..239b4927 100644 --- a/.infer/config.yaml +++ b/.infer/config.yaml @@ -32,3 +32,19 @@ compact: chat: default_model: "" system_prompt: "" +fetch: + enabled: false + whitelisted_domains: + - github.com + github: + enabled: false + token: "" + base_url: https://api.github.com + safety: + max_size: 8192 + timeout: 30 + allow_redirect: true + cache: + enabled: true + ttl: 3600 + max_size: 52428800 diff --git a/cmd/config.go b/cmd/config.go index 69330477..93ea751d 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "os" + "strings" "time" "github.com/inference-gateway/cli/config" @@ -171,6 +172,98 @@ var configToolsExcludePathRemoveCmd = &cobra.Command{ RunE: removeExcludedPath, } +var configFetchCmd = &cobra.Command{ + Use: "fetch", + Short: "Manage fetch tool settings", + Long: `Manage the fetch tool that allows LLMs to retrieve content from whitelisted URLs. +The fetch tool supports GitHub integration and URL pattern matching for secure content retrieval.`, +} + +var configFetchEnableCmd = &cobra.Command{ + Use: "enable", + Short: "Enable the fetch tool", + Long: `Enable the fetch tool to allow LLMs to retrieve content from whitelisted sources.`, + RunE: enableFetch, +} + +var configFetchDisableCmd = &cobra.Command{ + Use: "disable", + Short: "Disable the fetch tool", + Long: `Disable the fetch tool to prevent LLMs from retrieving any external content.`, + RunE: disableFetch, +} + +var configFetchListCmd = &cobra.Command{ + Use: "list", + Short: "List whitelisted domains", + Long: `Display all whitelisted domains that can be fetched by LLMs.`, + RunE: listFetchDomains, +} + +var configFetchAddDomainCmd = &cobra.Command{ + Use: "add-domain ", + Short: "Add a domain to the whitelist", + Long: `Add a domain to the whitelist of allowed fetch sources (e.g., github.com, example.org).`, + Args: cobra.ExactArgs(1), + RunE: addFetchDomain, +} + +var 
configFetchRemoveDomainCmd = &cobra.Command{ + Use: "remove-domain ", + Short: "Remove a domain from the whitelist", + Long: `Remove a domain from the whitelist of allowed fetch sources.`, + Args: cobra.ExactArgs(1), + RunE: removeFetchDomain, +} + +var configFetchGitHubCmd = &cobra.Command{ + Use: "github", + Short: "Manage GitHub integration settings", + Long: `Manage GitHub-specific fetch settings including API access and optimization features.`, +} + +var configFetchGitHubEnableCmd = &cobra.Command{ + Use: "enable", + Short: "Enable GitHub integration", + Long: `Enable GitHub API integration for optimized fetching of GitHub issues and pull requests.`, + RunE: enableGitHubFetch, +} + +var configFetchGitHubDisableCmd = &cobra.Command{ + Use: "disable", + Short: "Disable GitHub integration", + Long: `Disable GitHub API integration, falling back to regular HTTP fetching.`, + RunE: disableGitHubFetch, +} + +var configFetchGitHubTokenCmd = &cobra.Command{ + Use: "set-token ", + Short: "Set GitHub API token", + Long: `Set the GitHub API token for authenticated requests to increase rate limits.`, + Args: cobra.ExactArgs(1), + RunE: setGitHubToken, +} + +var configFetchCacheCmd = &cobra.Command{ + Use: "cache", + Short: "Manage fetch cache settings", + Long: `Manage caching settings for fetched content to improve performance.`, +} + +var configFetchCacheStatusCmd = &cobra.Command{ + Use: "status", + Short: "Show cache status and statistics", + Long: `Display current cache status, statistics, and configuration.`, + RunE: fetchCacheStatus, +} + +var configFetchCacheClearCmd = &cobra.Command{ + Use: "clear", + Short: "Clear the fetch cache", + Long: `Clear all cached content to free up memory and force fresh fetches.`, + RunE: fetchCacheClear, +} + func setDefaultModel(modelName string) error { cfg, err := config.LoadConfig("") if err != nil { @@ -211,6 +304,7 @@ func init() { configCmd.AddCommand(setSystemCmd) configCmd.AddCommand(configInitCmd) 
configCmd.AddCommand(configToolsCmd) + configCmd.AddCommand(configFetchCmd) configToolsCmd.AddCommand(configToolsEnableCmd) configToolsCmd.AddCommand(configToolsDisableCmd) @@ -228,9 +322,25 @@ func init() { configToolsExcludePathCmd.AddCommand(configToolsExcludePathAddCmd) configToolsExcludePathCmd.AddCommand(configToolsExcludePathRemoveCmd) + configFetchCmd.AddCommand(configFetchEnableCmd) + configFetchCmd.AddCommand(configFetchDisableCmd) + configFetchCmd.AddCommand(configFetchListCmd) + configFetchCmd.AddCommand(configFetchAddDomainCmd) + configFetchCmd.AddCommand(configFetchRemoveDomainCmd) + configFetchCmd.AddCommand(configFetchGitHubCmd) + configFetchCmd.AddCommand(configFetchCacheCmd) + + configFetchGitHubCmd.AddCommand(configFetchGitHubEnableCmd) + configFetchGitHubCmd.AddCommand(configFetchGitHubDisableCmd) + configFetchGitHubCmd.AddCommand(configFetchGitHubTokenCmd) + + configFetchCacheCmd.AddCommand(configFetchCacheStatusCmd) + configFetchCacheCmd.AddCommand(configFetchCacheClearCmd) + configInitCmd.Flags().Bool("overwrite", false, "Overwrite existing configuration file") configToolsListCmd.Flags().StringP("format", "f", "text", "Output format (text, json)") configToolsExecCmd.Flags().StringP("format", "f", "text", "Output format (text, json)") + configFetchListCmd.Flags().StringP("format", "f", "text", "Output format (text, json)") rootCmd.AddCommand(configCmd) } @@ -522,3 +632,247 @@ func removeExcludedPath(cmd *cobra.Command, args []string) error { fmt.Printf("Tools can now access this path again\n") return nil } + +func enableFetch(cmd *cobra.Command, args []string) error { + _, err := loadAndUpdateConfig(func(c *config.Config) { + c.Fetch.Enabled = true + }) + if err != nil { + return err + } + + fmt.Printf("%s\n", ui.FormatSuccess("Fetch tool enabled successfully")) + fmt.Printf("Configuration saved to: %s\n", getConfigPath()) + fmt.Println("You can now configure whitelisted sources with 'infer config fetch add-source '") + return nil +} + +func 
disableFetch(cmd *cobra.Command, args []string) error { + _, err := loadAndUpdateConfig(func(c *config.Config) { + c.Fetch.Enabled = false + }) + if err != nil { + return err + } + + fmt.Printf("%s\n", ui.FormatErrorCLI("Fetch tool disabled successfully")) + fmt.Printf("Configuration saved to: %s\n", getConfigPath()) + return nil +} + +func listFetchDomains(cmd *cobra.Command, args []string) error { + cfg, err := config.LoadConfig("") + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + format, _ := cmd.Flags().GetString("format") + + if format == "json" { + data := map[string]interface{}{ + "enabled": cfg.Fetch.Enabled, + "whitelisted_domains": cfg.Fetch.WhitelistedDomains, + "github": map[string]interface{}{ + "enabled": cfg.Fetch.GitHub.Enabled, + "base_url": cfg.Fetch.GitHub.BaseURL, + "has_token": cfg.Fetch.GitHub.Token != "", + }, + "safety": map[string]interface{}{ + "max_size": cfg.Fetch.Safety.MaxSize, + "timeout": cfg.Fetch.Safety.Timeout, + "allow_redirect": cfg.Fetch.Safety.AllowRedirect, + }, + "cache": map[string]interface{}{ + "enabled": cfg.Fetch.Cache.Enabled, + "ttl": cfg.Fetch.Cache.TTL, + "max_size": cfg.Fetch.Cache.MaxSize, + }, + } + jsonOutput, err := json.MarshalIndent(data, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal output: %w", err) + } + fmt.Println(string(jsonOutput)) + return nil + } + + fmt.Printf("Fetch Tool Status: ") + if cfg.Fetch.Enabled { + fmt.Printf("%s\n", ui.FormatSuccess("Enabled")) + } else { + fmt.Printf("%s\n", ui.FormatErrorCLI("Disabled")) + } + + fmt.Printf("\nWhitelisted Domains (%d):\n", len(cfg.Fetch.WhitelistedDomains)) + if len(cfg.Fetch.WhitelistedDomains) == 0 { + fmt.Printf(" • None configured\n") + } else { + for _, domain := range cfg.Fetch.WhitelistedDomains { + fmt.Printf(" • %s\n", domain) + } + } + + fmt.Printf("\nGitHub Integration:\n") + if cfg.Fetch.GitHub.Enabled { + fmt.Printf(" • Status: %s\n", ui.FormatSuccess("Enabled")) + fmt.Printf(" • Base URL: 
%s\n", cfg.Fetch.GitHub.BaseURL) + if cfg.Fetch.GitHub.Token != "" { + fmt.Printf(" • Token: %s\n", ui.FormatSuccess("Configured")) + } else { + fmt.Printf(" • Token: %s\n", ui.FormatWarning("Not configured")) + } + } else { + fmt.Printf(" • Status: %s\n", ui.FormatErrorCLI("Disabled")) + } + + fmt.Printf("\nSafety Settings:\n") + fmt.Printf(" • Max size: %d bytes (%.1f MB)\n", cfg.Fetch.Safety.MaxSize, float64(cfg.Fetch.Safety.MaxSize)/(1024*1024)) + fmt.Printf(" • Timeout: %d seconds\n", cfg.Fetch.Safety.Timeout) + fmt.Printf(" • Allow redirects: %t\n", cfg.Fetch.Safety.AllowRedirect) + + fmt.Printf("\nCache Settings:\n") + if cfg.Fetch.Cache.Enabled { + fmt.Printf(" • Status: %s\n", ui.FormatSuccess("Enabled")) + fmt.Printf(" • TTL: %d seconds\n", cfg.Fetch.Cache.TTL) + fmt.Printf(" • Max size: %d bytes (%.1f MB)\n", cfg.Fetch.Cache.MaxSize, float64(cfg.Fetch.Cache.MaxSize)/(1024*1024)) + } else { + fmt.Printf(" • Status: %s\n", ui.FormatErrorCLI("Disabled")) + } + + return nil +} + +func addFetchDomain(cmd *cobra.Command, args []string) error { + domainToAdd := args[0] + + // Basic domain validation + if strings.Contains(domainToAdd, "://") { + return fmt.Errorf("please provide just the domain (e.g., 'github.com'), not a full URL") + } + + _, err := loadAndUpdateConfig(func(c *config.Config) { + for _, existingDomain := range c.Fetch.WhitelistedDomains { + if existingDomain == domainToAdd { + return + } + } + c.Fetch.WhitelistedDomains = append(c.Fetch.WhitelistedDomains, domainToAdd) + }) + if err != nil { + return err + } + + fmt.Printf("%s\n", ui.FormatSuccess(fmt.Sprintf("Added '%s' to whitelisted domains", domainToAdd))) + fmt.Printf("LLMs can now fetch content from this domain and its subdomains\n") + return nil +} + +func removeFetchDomain(cmd *cobra.Command, args []string) error { + domainToRemove := args[0] + var found bool + + _, err := loadAndUpdateConfig(func(c *config.Config) { + for i, existingDomain := range c.Fetch.WhitelistedDomains { + if 
existingDomain == domainToRemove { + c.Fetch.WhitelistedDomains = append(c.Fetch.WhitelistedDomains[:i], c.Fetch.WhitelistedDomains[i+1:]...) + found = true + return + } + } + }) + if err != nil { + return err + } + + if !found { + fmt.Printf("%s\n", ui.FormatWarning(fmt.Sprintf("Domain '%s' was not in the whitelist", domainToRemove))) + return nil + } + + fmt.Printf("%s\n", ui.FormatSuccess(fmt.Sprintf("Removed '%s' from whitelisted domains", domainToRemove))) + fmt.Printf("LLMs can no longer fetch content from this domain\n") + return nil +} + +func enableGitHubFetch(cmd *cobra.Command, args []string) error { + _, err := loadAndUpdateConfig(func(c *config.Config) { + c.Fetch.GitHub.Enabled = true + }) + if err != nil { + return err + } + + fmt.Printf("%s\n", ui.FormatSuccess("GitHub integration enabled")) + fmt.Printf("LLMs can now use optimized GitHub fetching with 'github:owner/repo#123' syntax\n") + fmt.Printf("Set a GitHub token with 'infer config fetch github set-token ' for higher rate limits\n") + return nil +} + +func disableGitHubFetch(cmd *cobra.Command, args []string) error { + _, err := loadAndUpdateConfig(func(c *config.Config) { + c.Fetch.GitHub.Enabled = false + }) + if err != nil { + return err + } + + fmt.Printf("%s\n", ui.FormatErrorCLI("GitHub integration disabled")) + fmt.Printf("GitHub URLs will now be fetched using regular HTTP requests\n") + return nil +} + +func setGitHubToken(cmd *cobra.Command, args []string) error { + token := args[0] + + _, err := loadAndUpdateConfig(func(c *config.Config) { + c.Fetch.GitHub.Token = token + }) + if err != nil { + return err + } + + fmt.Printf("%s\n", ui.FormatSuccess("GitHub token configured successfully")) + fmt.Printf("GitHub API requests will now use authentication for higher rate limits\n") + return nil +} + +func fetchCacheStatus(cmd *cobra.Command, args []string) error { + cfg, err := config.LoadConfig("") + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + services := 
container.NewServiceContainer(cfg) + fetchService := services.GetFetchService() + stats := fetchService.GetCacheStats() + + fmt.Printf("Cache Status: ") + if cfg.Fetch.Cache.Enabled { + fmt.Printf("%s\n", ui.FormatSuccess("Enabled")) + } else { + fmt.Printf("%s\n", ui.FormatErrorCLI("Disabled")) + } + + fmt.Printf("\nCache Statistics:\n") + fmt.Printf(" • Entries: %d\n", stats["entries"]) + fmt.Printf(" • Total size: %d bytes (%.1f KB)\n", stats["total_size"], float64(stats["total_size"].(int64))/1024) + fmt.Printf(" • Max size: %d bytes (%.1f MB)\n", stats["max_size"], float64(stats["max_size"].(int64))/(1024*1024)) + fmt.Printf(" • TTL: %d seconds\n", stats["ttl"]) + + return nil +} + +func fetchCacheClear(cmd *cobra.Command, args []string) error { + cfg, err := config.LoadConfig("") + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + services := container.NewServiceContainer(cfg) + fetchService := services.GetFetchService() + fetchService.ClearCache() + + fmt.Printf("%s\n", ui.FormatSuccess("Fetch cache cleared successfully")) + fmt.Printf("All cached content has been removed\n") + return nil +} diff --git a/config/config.go b/config/config.go index 155f0db9..b4b1d98e 100644 --- a/config/config.go +++ b/config/config.go @@ -17,6 +17,7 @@ type Config struct { Tools ToolsConfig `yaml:"tools"` Compact CompactConfig `yaml:"compact"` Chat ChatConfig `yaml:"chat"` + Fetch FetchConfig `yaml:"fetch"` } // GatewayConfig contains gateway connection settings @@ -62,6 +63,36 @@ type ChatConfig struct { SystemPrompt string `yaml:"system_prompt"` } +// FetchConfig contains settings for content fetching +type FetchConfig struct { + Enabled bool `yaml:"enabled"` + WhitelistedDomains []string `yaml:"whitelisted_domains"` + GitHub GitHubFetchConfig `yaml:"github"` + Safety FetchSafetyConfig `yaml:"safety"` + Cache FetchCacheConfig `yaml:"cache"` +} + +// GitHubFetchConfig contains GitHub-specific fetch settings +type GitHubFetchConfig struct { + 
// FetchSafetyConfig contains safety settings for fetch operations
type FetchSafetyConfig struct {
	// MaxSize is the largest response body, in bytes, that a URL fetch will
	// accept; the fetch service compares it with the response Content-Length
	// and uses it as the read limit.
	MaxSize int64 `yaml:"max_size"`
	// Timeout is the HTTP client timeout in seconds (the fetch service
	// multiplies it by time.Second).
	Timeout int `yaml:"timeout"`
	// AllowRedirect is meant to control whether redirects are followed.
	// NOTE(review): no visible code in this change reads it — confirm it is
	// honoured by the HTTP client.
	AllowRedirect bool `yaml:"allow_redirect"`
}

// FetchCacheConfig contains cache settings for fetch operations
type FetchCacheConfig struct {
	// Enabled turns the fetch cache on; when false nothing is stored or
	// served from the cache.
	Enabled bool `yaml:"enabled"`
	// TTL is the lifetime of a cache entry in seconds (compared against the
	// entry timestamp as TTL*time.Second).
	TTL int `yaml:"ttl"`
	// MaxSize is the intended cache capacity in bytes.
	// NOTE(review): in the visible code it is only reported in cache
	// statistics — confirm whether it is enforced anywhere.
	MaxSize int64 `yaml:"max_size"`
}
// GetFetchService returns the fetch service stored on the container.
func (c *ServiceContainer) GetFetchService() domain.FetchService {
	return c.fetchService
}
a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -113,3 +113,22 @@ type FileInfo struct { Size int64 IsDir bool } + +// FetchResult represents the result of a fetch operation +type FetchResult struct { + Content string `json:"content"` + URL string `json:"url"` + Status int `json:"status"` + Size int64 `json:"size"` + ContentType string `json:"content_type"` + Cached bool `json:"cached"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// FetchService handles content fetching operations +type FetchService interface { + ValidateURL(url string) error + FetchContent(ctx context.Context, target string) (*FetchResult, error) + ClearCache() + GetCacheStats() map[string]interface{} +} diff --git a/internal/services/fetch.go b/internal/services/fetch.go new file mode 100644 index 00000000..15642ec3 --- /dev/null +++ b/internal/services/fetch.go @@ -0,0 +1,391 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/inference-gateway/cli/config" + "github.com/inference-gateway/cli/internal/domain" + "github.com/inference-gateway/cli/internal/logger" +) + +// FetchService handles content fetching operations +type FetchService struct { + config *config.Config + client *http.Client + cache map[string]CacheEntry +} + +// CacheEntry represents a cached fetch result +type CacheEntry struct { + Content string `json:"content"` + Timestamp time.Time `json:"timestamp"` + URL string `json:"url"` +} + +// GitHubReference represents a GitHub issue or PR reference +type GitHubReference struct { + Owner string + Repo string + Number int + Type string // "issue" or "pull" +} + +// NewFetchService creates a new FetchService +func NewFetchService(cfg *config.Config) *FetchService { + return &FetchService{ + config: cfg, + client: &http.Client{ + Timeout: time.Duration(cfg.Fetch.Safety.Timeout) * time.Second, + }, + cache: make(map[string]CacheEntry), + } +} + +// 
ValidateURL checks if a URL's domain is whitelisted for fetching +func (f *FetchService) ValidateURL(targetURL string) error { + if !f.config.Fetch.Enabled { + return fmt.Errorf("fetch tool is not enabled - use 'infer config fetch enable' to enable it") + } + + if len(f.config.Fetch.WhitelistedDomains) == 0 { + return fmt.Errorf("no whitelisted domains configured - use 'infer config fetch add-domain' to configure allowed domains") + } + + parsedURL, err := url.Parse(targetURL) + if err != nil { + return fmt.Errorf("invalid URL format: %w", err) + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("only HTTP and HTTPS URLs are allowed") + } + + domain := parsedURL.Hostname() + if domain == "" { + return fmt.Errorf("unable to extract domain from URL: %s", targetURL) + } + + for _, whitelistedDomain := range f.config.Fetch.WhitelistedDomains { + if domain == whitelistedDomain || strings.HasSuffix(domain, "."+whitelistedDomain) { + logger.Debug("URL domain matches whitelist", "url", targetURL, "domain", domain, "whitelisted", whitelistedDomain) + return nil + } + } + + return fmt.Errorf("domain not whitelisted: %s (from URL: %s)", domain, targetURL) +} + +// FetchContent fetches content from a URL or GitHub reference +func (f *FetchService) FetchContent(ctx context.Context, target string) (*domain.FetchResult, error) { + if err := f.ValidateURL(target); err != nil { + if githubRef, parseErr := f.parseGitHubReference(target); parseErr == nil { + return f.fetchGitHubContent(ctx, githubRef) + } + return nil, err + } + + if entry, found := f.getCachedContent(target); found { + logger.Debug("Returning cached content", "url", target) + return &domain.FetchResult{ + Content: entry.Content, + URL: entry.URL, + Cached: true, + ContentType: "text/plain", + }, nil + } + + return f.fetchURL(ctx, target) +} + +// fetchURL performs the actual HTTP request +func (f *FetchService) fetchURL(ctx context.Context, targetURL string) 
(*domain.FetchResult, error) { + req, err := http.NewRequestWithContext(ctx, "GET", targetURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("User-Agent", "Inference-Gateway-CLI/1.0") + + resp, err := f.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch URL: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + logger.Error("Failed to close response body", "error", err) + } + }() + + if resp.StatusCode != http.StatusOK { + return &domain.FetchResult{ + URL: targetURL, + Status: resp.StatusCode, + }, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status) + } + + if resp.ContentLength > f.config.Fetch.Safety.MaxSize { + return nil, fmt.Errorf("content too large: %d bytes (max: %d bytes)", resp.ContentLength, f.config.Fetch.Safety.MaxSize) + } + + limitedReader := io.LimitReader(resp.Body, f.config.Fetch.Safety.MaxSize) + content, err := io.ReadAll(limitedReader) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + result := &domain.FetchResult{ + Content: string(content), + URL: targetURL, + Status: resp.StatusCode, + Size: int64(len(content)), + ContentType: resp.Header.Get("Content-Type"), + Cached: false, + } + + f.cacheContent(targetURL, string(content)) + + logger.Debug("Successfully fetched content", "url", targetURL, "size", len(content), "status", resp.StatusCode) + return result, nil +} + +// parseGitHubReference parses GitHub-style references like "github:owner/repo#123" or "source=github ticket_nr=123" +func (f *FetchService) parseGitHubReference(target string) (*GitHubReference, error) { + if strings.HasPrefix(target, "github:") { + return f.parseGitHubColonReference(target) + } + + if strings.Contains(target, "source=github") { + return f.parseGitHubParameterReference(target) + } + + return nil, fmt.Errorf("not a GitHub reference") +} + +// parseGitHubColonReference parses "github:owner/repo#123" format +func (f 
*FetchService) parseGitHubColonReference(target string) (*GitHubReference, error) { + reference := strings.TrimPrefix(target, "github:") + + parts := strings.Split(reference, "#") + if len(parts) != 2 { + return nil, fmt.Errorf("invalid GitHub reference format, expected: github:owner/repo#123") + } + + repoPath := parts[0] + numberStr := parts[1] + + repoParts := strings.Split(repoPath, "/") + if len(repoParts) != 2 { + return nil, fmt.Errorf("invalid repository format, expected: owner/repo") + } + + number, err := strconv.Atoi(numberStr) + if err != nil { + return nil, fmt.Errorf("invalid issue/PR number: %w", err) + } + + return &GitHubReference{ + Owner: repoParts[0], + Repo: repoParts[1], + Number: number, + Type: "issue", // Will be determined by GitHub API response + }, nil +} + +// parseGitHubParameterReference parses "source=github ticket_nr=123" format +func (f *FetchService) parseGitHubParameterReference(target string) (*GitHubReference, error) { + var ticketNr int + var owner, repo string + + parts := strings.Fields(target) + for _, part := range parts { + if strings.HasPrefix(part, "ticket_nr=") { + numberStr := strings.TrimPrefix(part, "ticket_nr=") + var err error + ticketNr, err = strconv.Atoi(numberStr) + if err != nil { + return nil, fmt.Errorf("invalid ticket number: %w", err) + } + } else if strings.HasPrefix(part, "owner=") { + owner = strings.TrimPrefix(part, "owner=") + } else if strings.HasPrefix(part, "repo=") { + repo = strings.TrimPrefix(part, "repo=") + } + } + + if ticketNr == 0 { + return nil, fmt.Errorf("missing ticket_nr parameter") + } + + if owner == "" { + owner = "inference-gateway" + } + if repo == "" { + repo = "cli" + } + + return &GitHubReference{ + Owner: owner, + Repo: repo, + Number: ticketNr, + Type: "issue", // Will be determined by GitHub API response + }, nil +} + +// fetchGitHubContent fetches content from GitHub API +func (f *FetchService) fetchGitHubContent(ctx context.Context, ref *GitHubReference) 
(*domain.FetchResult, error) { + if !f.config.Fetch.GitHub.Enabled { + return nil, fmt.Errorf("GitHub integration is not enabled - use 'infer config fetch github enable' to enable it") + } + + apiURL := fmt.Sprintf("%s/repos/%s/%s/issues/%d", f.config.Fetch.GitHub.BaseURL, ref.Owner, ref.Repo, ref.Number) + + req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create GitHub API request: %w", err) + } + + req.Header.Set("User-Agent", "Inference-Gateway-CLI/1.0") + req.Header.Set("Accept", "application/vnd.github.v3+json") + + if f.config.Fetch.GitHub.Token != "" { + req.Header.Set("Authorization", "token "+f.config.Fetch.GitHub.Token) + } + + resp, err := f.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch from GitHub API: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + logger.Error("Failed to close GitHub API response body", "error", err) + } + }() + + if resp.StatusCode != http.StatusOK { + return &domain.FetchResult{ + URL: apiURL, + Status: resp.StatusCode, + }, fmt.Errorf("GitHub API error: HTTP %d", resp.StatusCode) + } + + var issue struct { + Title string `json:"title"` + Body string `json:"body"` + State string `json:"state"` + Number int `json:"number"` + HTMLURL string `json:"html_url"` + User struct { + Login string `json:"login"` + } `json:"user"` + PullRequest *struct{} `json:"pull_request,omitempty"` + } + + if err := json.NewDecoder(resp.Body).Decode(&issue); err != nil { + return nil, fmt.Errorf("failed to parse GitHub API response: %w", err) + } + + ref.Type = "issue" + if issue.PullRequest != nil { + ref.Type = "pull" + } + + var typeTitle string + switch ref.Type { + case "issue": + typeTitle = "Issue" + case "pull": + typeTitle = "Pull Request" + default: + typeTitle = ref.Type + } + + content := fmt.Sprintf("# %s #%d: %s\n\n**Author:** @%s\n**State:** %s\n**URL:** %s\n\n%s", + typeTitle, issue.Number, issue.Title, 
issue.User.Login, issue.State, issue.HTMLURL, issue.Body) + + result := &domain.FetchResult{ + Content: content, + URL: apiURL, + Status: resp.StatusCode, + Size: int64(len(content)), + ContentType: "application/json", + Cached: false, + Metadata: map[string]string{ + "github_type": ref.Type, + "github_owner": ref.Owner, + "github_repo": ref.Repo, + "github_number": strconv.Itoa(ref.Number), + "github_title": issue.Title, + "github_state": issue.State, + }, + } + + cacheKey := fmt.Sprintf("github:%s/%s#%d", ref.Owner, ref.Repo, ref.Number) + f.cacheContent(cacheKey, content) + + logger.Debug("Successfully fetched GitHub content", "type", ref.Type, "owner", ref.Owner, "repo", ref.Repo, "number", ref.Number) + return result, nil +} + +// getCachedContent retrieves content from cache if available and not expired +func (f *FetchService) getCachedContent(url string) (CacheEntry, bool) { + if !f.config.Fetch.Cache.Enabled { + return CacheEntry{}, false + } + + entry, exists := f.cache[url] + if !exists { + return CacheEntry{}, false + } + + if time.Since(entry.Timestamp) > time.Duration(f.config.Fetch.Cache.TTL)*time.Second { + delete(f.cache, url) + return CacheEntry{}, false + } + + return entry, true +} + +// cacheContent stores content in cache +func (f *FetchService) cacheContent(url, content string) { + if !f.config.Fetch.Cache.Enabled { + return + } + + f.cache[url] = CacheEntry{ + Content: content, + Timestamp: time.Now(), + URL: url, + } + + logger.Debug("Content cached", "url", url, "size", len(content)) +} + +// ClearCache clears all cached content +func (f *FetchService) ClearCache() { + f.cache = make(map[string]CacheEntry) + logger.Debug("Cache cleared") +} + +// GetCacheStats returns cache statistics +func (f *FetchService) GetCacheStats() map[string]interface{} { + totalSize := int64(0) + for _, entry := range f.cache { + totalSize += int64(len(entry.Content)) + } + + return map[string]interface{}{ + "entries": len(f.cache), + "total_size": totalSize, + 
"enabled": f.config.Fetch.Cache.Enabled, + "max_size": f.config.Fetch.Cache.MaxSize, + "ttl": f.config.Fetch.Cache.TTL, + } +} diff --git a/internal/services/tool.go b/internal/services/tool.go index 4bf651ff..f635ad6d 100644 --- a/internal/services/tool.go +++ b/internal/services/tool.go @@ -34,17 +34,19 @@ type FileReadResult struct { // LLMToolService implements ToolService with direct tool execution type LLMToolService struct { - config *config.Config - fileService domain.FileService - enabled bool + config *config.Config + fileService domain.FileService + fetchService domain.FetchService + enabled bool } // NewLLMToolService creates a new LLM tool service -func NewLLMToolService(cfg *config.Config, fileService domain.FileService) *LLMToolService { +func NewLLMToolService(cfg *config.Config, fileService domain.FileService, fetchService domain.FetchService) *LLMToolService { return &LLMToolService{ - config: cfg, - fileService: fileService, - enabled: cfg.Tools.Enabled, + config: cfg, + fileService: fileService, + fetchService: fetchService, + enabled: cfg.Tools.Enabled, } } @@ -53,7 +55,7 @@ func (s *LLMToolService) ListTools() []domain.ToolDefinition { return []domain.ToolDefinition{} } - return []domain.ToolDefinition{ + tools := []domain.ToolDefinition{ { Name: "Bash", Description: "Execute whitelisted bash commands securely", @@ -105,6 +107,31 @@ func (s *LLMToolService) ListTools() []domain.ToolDefinition { }, }, } + + if s.config.Fetch.Enabled { + tools = append(tools, domain.ToolDefinition{ + Name: "Fetch", + Description: "Fetch content from whitelisted URLs or GitHub references. 
Supports 'github:owner/repo#123' syntax for GitHub issues/PRs.", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "url": map[string]interface{}{ + "type": "string", + "description": "The URL to fetch content from, or GitHub reference (github:owner/repo#123)", + }, + "format": map[string]interface{}{ + "type": "string", + "description": "Output format (text or json)", + "enum": []string{"text", "json"}, + "default": "text", + }, + }, + "required": []string{"url"}, + }, + }) + } + + return tools } func (s *LLMToolService) ExecuteTool(ctx context.Context, name string, args map[string]interface{}) (string, error) { @@ -117,6 +144,8 @@ func (s *LLMToolService) ExecuteTool(ctx context.Context, name string, args map[ return s.executeBashTool(ctx, args) case "Read": return s.executeReadTool(args) + case "Fetch": + return s.executeFetchTool(ctx, args) default: return "", fmt.Errorf("unknown tool: %s", name) } @@ -150,6 +179,8 @@ func (s *LLMToolService) ValidateTool(name string, args map[string]interface{}) return s.validateBashTool(args) case "Read": return s.validateReadTool(args) + case "Fetch": + return s.validateFetchTool(args) default: return nil } @@ -405,3 +436,80 @@ func (s *LLMToolService) formatReadResult(result *FileReadResult) string { output += fmt.Sprintf("Content:\n%s", result.Content) return output } + +// executeFetchTool handles Fetch tool execution +func (s *LLMToolService) executeFetchTool(ctx context.Context, args map[string]interface{}) (string, error) { + if !s.config.Fetch.Enabled { + return "", fmt.Errorf("fetch tool is not enabled") + } + + url, ok := args["url"].(string) + if !ok { + return "", fmt.Errorf("url parameter is required and must be a string") + } + + format, ok := args["format"].(string) + if !ok { + format = "text" + } + + result, err := s.fetchService.FetchContent(ctx, url) + if err != nil { + return "", fmt.Errorf("fetch failed: %w", err) + } + + if format == "json" { + jsonOutput, 
err := json.MarshalIndent(result, "", " ") + if err != nil { + return "", fmt.Errorf("failed to marshal result: %w", err) + } + return string(jsonOutput), nil + } + + return s.formatFetchResult(result), nil +} + +// validateFetchTool validates Fetch tool arguments +func (s *LLMToolService) validateFetchTool(args map[string]interface{}) error { + if !s.config.Fetch.Enabled { + return fmt.Errorf("fetch tool is not enabled") + } + + url, ok := args["url"].(string) + if !ok { + return fmt.Errorf("url parameter is required and must be a string") + } + + if err := s.fetchService.ValidateURL(url); err != nil { + return fmt.Errorf("URL validation failed: %w", err) + } + + return nil +} + +// formatFetchResult formats fetch result for text output +func (s *LLMToolService) formatFetchResult(result *domain.FetchResult) string { + output := fmt.Sprintf("URL: %s\n", result.URL) + if result.Status > 0 { + output += fmt.Sprintf("Status: %d\n", result.Status) + } + output += fmt.Sprintf("Size: %d bytes\n", result.Size) + if result.ContentType != "" { + output += fmt.Sprintf("Content-Type: %s\n", result.ContentType) + } + if result.Cached { + output += "Source: Cache\n" + } else { + output += "Source: Live\n" + } + + if len(result.Metadata) > 0 { + output += "Metadata:\n" + for key, value := range result.Metadata { + output += fmt.Sprintf(" %s: %s\n", key, value) + } + } + + output += fmt.Sprintf("Content:\n%s", result.Content) + return output +}