diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 59f84e524cc7..0a8a9bce30aa 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -11,6 +11,7 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/websearch" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/core/templates" @@ -289,6 +290,45 @@ func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator return echo.ErrBadRequest } + // [Deep Research] Web Search Logic + if input.WebSearchOptions != nil && len(input.Messages) > 0 { + lastIdx := len(input.Messages) - 1 + lastMsg := input.Messages[lastIdx] + + // Only search if the last message is from the user + if lastMsg.Role == "user" && lastMsg.Content != nil { + var query string + + // Safe Type Assertion: Handle Content as string or *string + switch v := lastMsg.Content.(type) { + case string: + query = v + case *string: + if v != nil { + query = *v + } + } + + // Only proceed if we successfully extracted a query + if query != "" { + xlog.Debug("Web Search requested", "query", query) + + searcher := websearch.New() + citations, err := searcher.Search(c.Request().Context(), query) + if err != nil { + xlog.Error("Web search failed", "error", err) + } else if len(citations) > 0 { + // Augment the prompt with search results + newContent := websearch.AugmmentedSystemPrompt(query, citations) + + // Assign back to the interface{} field + // used a pointer to string to be safe with most LocalAI versions + input.Messages[lastIdx].Content = &newContent + } + } + } + } + extraUsage := c.Request().Header.Get("Extra-Usage") != "" config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) diff --git a/core/schema/message.go b/core/schema/message.go index b488ea49991b..d7a4fde98f16 100644 --- a/core/schema/message.go +++ b/core/schema/message.go @@ -11,17 +11,18 @@ import ( type Message struct { // The message role Role string `json:"role,omitempty" yaml:"role"` - // The message name (used for tools calls) Name string `json:"name,omitempty" yaml:"name"` // The message content Content interface{} `json:"content" yaml:"content"` - StringContent string `json:"string_content,omitempty" yaml:"string_content,omitempty"` - StringImages []string `json:"string_images,omitempty" yaml:"string_images,omitempty"` - StringVideos []string `json:"string_videos,omitempty" yaml:"string_videos,omitempty"` - StringAudios []string `json:"string_audios,omitempty" yaml:"string_audios,omitempty"` + // Annotations for citations + Annotations []Annotation `json:"annotations,omitempty"` + StringContent string `json:"string_content,omitempty" yaml:"string_content,omitempty"` + StringImages []string `json:"string_images,omitempty" yaml:"string_images,omitempty"` + StringVideos []string `json:"string_videos,omitempty" yaml:"string_videos,omitempty"` + StringAudios []string `json:"string_audios,omitempty" yaml:"string_audios,omitempty"` // A result of a function call FunctionCall interface{} `json:"function_call,omitempty" yaml:"function_call,omitempty"` diff --git a/core/schema/openai.go b/core/schema/openai.go index 74ed2859e3e2..4d7995b82673 100644 --- a/core/schema/openai.go +++ b/core/schema/openai.go @@ -6,6 +6,33 @@ import ( functions "github.com/mudler/LocalAI/pkg/functions" ) +type ApproximateLocation struct { + Country string `json:"country,omitempty"` + City string `json:"city,omitempty"` + Region string `json:"region,omitempty"` +} + +type UserLocation struct { + Type string `json:"type,omitempty"` + Approximate *ApproximateLocation `json:"approximate,omitempty"` +} + +type WebSearchOptions struct { + UserLocation *UserLocation `json:"user_location,omitempty"` +} + +type UrlCitation struct { + EndIndex int `json:"end_index,omitempty"` + StartIndex int `json:"start_index,omitempty"` + Title string `json:"title,omitempty"` + URL string `json:"url,omitempty"` +} + +type Annotation struct { + Type string `json:"type,omitempty"` + UrlCitation *UrlCitation `json:"url_citation,omitempty"` +} + // APIError provides error information returned by the OpenAI API. type APIError struct { Code any `json:"code,omitempty"` @@ -150,6 +177,9 @@ type OpenAIRequest struct { // Messages is read only by chat/completion API calls Messages []Message `json:"messages" yaml:"messages"` + // WebSearchOptions for Deep Research + WebSearchOptions *WebSearchOptions `json:"web_search_options,omitempty"` + // A list of available functions to call Functions functions.Functions `json:"functions" yaml:"functions"` FunctionCall interface{} `json:"function_call" yaml:"function_call"` // might be a string or an object diff --git a/core/services/websearch/seach.go b/core/services/websearch/seach.go new file mode 100644 index 000000000000..dbffe973b56f --- /dev/null +++ b/core/services/websearch/seach.go @@ -0,0 +1,59 @@ +package websearch + +import ( + "context" + "fmt" + "strings" + + "github.com/mudler/LocalAI/core/schema" +) + +type Searcher interface { + Search(ctx context.Context, query string) ([]schema.UrlCitation, error) +} + +type SimpleSearch struct{} + +func New() *SimpleSearch { + return &SimpleSearch{} +} + +func (s *SimpleSearch) Search(ctx context.Context, query string) ([]schema.UrlCitation, error) { + // TODO: Implement actual DuckDuckGo, Google Custom Search, or browsing logic here. + // For now, returned a placeholder to validate the API schema flow. + + fmt.Printf("[WebSearch] Searching for :%s\n", query) + + results := []schema.UrlCitation{ + { + Title: "LocalAI Documentation", + URL: "https://localai.io", + StartIndex: 0, + EndIndex: 100, // Arbitrary indices for citation highlighting + }, + { + Title: "Github - LocalAI", + URL: "https://github.com/mudler/LocalAI", + StartIndex: 0, + EndIndex: 0, + }, + } + return results, nil +} + +// AugmmentedSystemPrompt adds search results to the context +func AugmmentedSystemPrompt(originalPrompt string, citations []schema.UrlCitation) string { + + var sb strings.Builder + sb.WriteString("I found the following information from the web:\n\n") + + for i, c := range citations { + sb.WriteString(fmt.Sprintf("[%d] %s (%s)\n", i+1, c.Title, c.URL)) + } + + sb.WriteString("\nPlease use this information to answer the user's question.\n\n") + sb.WriteString("Original System Promt:\n") + sb.WriteString(originalPrompt) + + return sb.String() +}