-
Notifications
You must be signed in to change notification settings - Fork 479
feat: add Tavily as configurable search engine in core search framework #891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,3 @@ | ||
| # flake8: noqa | ||
| from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult | ||
| from ms_agent.tools.search.tavily.search import TavilySearch |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,101 @@ | ||||||
| # flake8: noqa | ||||||
| from dataclasses import dataclass, field | ||||||
| from typing import Any, Dict, List, Optional | ||||||
|
|
||||||
| import json | ||||||
|
|
||||||
| from ms_agent.tools.search.search_base import BaseResult, SearchResponse | ||||||
|
|
||||||
|
|
||||||
@dataclass
class TavilySearchRequest:
    """Parameters describing a single Tavily search call.

    Every field except ``query`` is optional. A value of ``None`` means
    "not specified": the corresponding key is left out of the dictionary
    produced by :meth:`to_dict` instead of being forwarded as ``None``.
    """

    # The search query string
    query: str

    # Number of results to return, default is 5
    num_results: Optional[int] = 5

    # Search depth: 'basic' or 'advanced'
    search_depth: Optional[str] = 'basic'

    # Topic category: 'general', 'news', or 'finance'
    topic: Optional[str] = 'general'

    # Domains to include in search
    include_domains: Optional[List[str]] = None

    # Domains to exclude from search
    exclude_domains: Optional[List[str]] = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the request parameters to a dictionary.

        ``num_results`` is emitted under the key ``max_results``. Scalar
        options set to ``None`` and empty/``None`` domain lists are omitted.

        :return: Keyword arguments for the search call; always contains 'query'.
        """
        payload: Dict[str, Any] = {'query': self.query}

        scalar_options = {
            'max_results': self.num_results,
            'search_depth': self.search_depth,
            'topic': self.topic,
        }
        # Skip unset options so an explicit None is never passed downstream.
        payload.update({
            key: value
            for key, value in scalar_options.items() if value is not None
        })

        if self.include_domains:
            payload['include_domains'] = self.include_domains
        if self.exclude_domains:
            payload['exclude_domains'] = self.exclude_domains
        return payload

    def to_json(self) -> str:
        """
        Convert the request parameters to a JSON string.

        Non-ASCII characters are written as-is (``ensure_ascii=False``).
        """
        return json.dumps(self.to_dict(), ensure_ascii=False)
|
|
||||||
|
|
||||||
@dataclass
class TavilySearchResult:
    """Container pairing a search query with the response it produced."""

    # The original search query string
    query: str

    # Optional arguments for the search request
    arguments: Dict[str, Any] = field(default_factory=dict)

    # The search response (an object exposing a ``results`` sequence);
    # stays None until a search has populated it.
    response: Optional['SearchResponse'] = None

    @staticmethod
    def _get_logger():
        """Return the module logger (imported lazily to keep the block self-contained)."""
        import logging
        return logging.getLogger(__name__)

    def to_list(self) -> List[Dict[str, Any]]:
        """
        Convert the search results to a list of dictionaries.

        :return: One dict per result with the keys 'url', 'id', 'title',
            'highlights', 'highlight_scores', 'summary' and 'markdown'.
            An empty list is returned (and a warning logged) when there is
            no response, the response has no results, or the query is empty.
        """
        if not self.response or not self.response.results:
            self._get_logger().warning('No search results found.')
            return []

        if not self.query:
            self._get_logger().warning(
                'No query provided for search results.')
            return []

        return [{
            'url': res.url,
            'id': res.id,
            'title': res.title,
            'highlights': res.highlights,
            'highlight_scores': res.highlight_scores,
            'summary': res.summary,
            'markdown': res.markdown,
        } for res in self.response.results]

    @staticmethod
    def load_from_disk(file_path: str) -> List[Dict[str, Any]]:
        """
        Load search results from a local file.

        :param file_path: Path of a UTF-8 encoded JSON file.
        :return: The decoded JSON content, returned unchanged.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        TavilySearchResult._get_logger().info(
            'Search results loaded from %s', file_path)

        return data
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,125 @@ | ||||||||
| # flake8: noqa | ||||||||
| import os | ||||||||
| from typing import TYPE_CHECKING | ||||||||
|
|
||||||||
| from tavily import TavilyClient | ||||||||
| from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult | ||||||||
| from ms_agent.tools.search.search_base import (BaseResult, SearchEngine, | ||||||||
| SearchEngineType, | ||||||||
| SearchResponse) | ||||||||
|
|
||||||||
| if TYPE_CHECKING: | ||||||||
| from ms_agent.llm.utils import Tool | ||||||||
|
|
||||||||
|
|
||||||||
class TavilySearch(SearchEngine):
    """
    Search engine using Tavily API.

    Best for: AI-optimized web search, general and news queries,
    high relevance results with built-in content extraction.
    """

    engine_type = SearchEngineType.TAVILY

    def __init__(self, api_key: str = None):
        """
        Create the Tavily client.

        :param api_key: Tavily API key. When omitted or empty, the
            ``TAVILY_API_KEY`` environment variable is used instead.
        :raises ValueError: If no API key is available from either source.
        """
        api_key = api_key or os.getenv('TAVILY_API_KEY')
        # Raise explicitly: an ``assert`` is removed under ``python -O`` and
        # would let a missing key slip through to the client.
        if not api_key:
            raise ValueError(
                'TAVILY_API_KEY must be set either as an argument or as an environment variable'
            )

        self.client = TavilyClient(api_key=api_key)

    def search(self, search_request: TavilySearchRequest) -> TavilySearchResult:
        """
        Perform a search using the Tavily API with the provided search request parameters.

        :param search_request: An instance of TavilySearchRequest containing search parameters.
        :return: An instance of TavilySearchResult containing the search results.
        :raises RuntimeError: If the client call or the result mapping fails;
            the original exception is chained as the cause.
        """
        search_args: dict = search_request.to_dict()
        search_result: TavilySearchResult = TavilySearchResult(
            query=search_request.query,
            arguments=search_args,
        )
        try:
            raw_response = self.client.search(**search_args)
            # Map Tavily results to BaseResult schema. The URL doubles as the
            # result id; 'content' becomes the summary and 'raw_content'
            # (None when absent) becomes the markdown body.
            results = [
                BaseResult(
                    url=item.get('url', ''),
                    id=item.get('url', ''),
                    title=item.get('title', ''),
                    summary=item.get('content', ''),
                    markdown=item.get('raw_content'),
                ) for item in raw_response.get('results', [])
            ]
            search_result.response = SearchResponse(results=results)
        except Exception as e:
            raise RuntimeError(f'Failed to perform search: {e}') from e

        return search_result

    @classmethod
    def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool':
        """
        Return the tool definition for Tavily search engine.

        :param server_name: Server name recorded on the returned Tool.
        :return: A Tool whose JSON-schema parameters are 'query' (required),
            'num_results', 'search_depth' and 'topic'.
        """
        # Imported here because the module-level import is TYPE_CHECKING-only.
        from ms_agent.llm.utils import Tool
        return Tool(
            tool_name=cls.get_tool_name(),
            server_name=server_name,
            description=(
                'Search the web using Tavily AI-optimized search engine. '
                'Best for: general web queries, news, and finance topics. '
                'Returns highly relevant results with content extraction.'),
            parameters={
                'type': 'object',
                'properties': {
                    'query': {
                        'type':
                        'string',
                        'description':
                        'The search query. Use natural language for best results.',
                    },
                    'num_results': {
                        'type':
                        'integer',
                        'minimum':
                        1,
                        'maximum':
                        10,
                        'description':
                        'Number of results to return. Default is 5.',
                    },
                    'search_depth': {
                        'type':
                        'string',
                        'enum': ['basic', 'advanced'],
                        'description':
                        ('Search depth. "basic" for fast results, '
                         '"advanced" for higher relevance. Default is "basic".'
                         ),
                    },
                    'topic': {
                        'type':
                        'string',
                        'enum': ['general', 'news', 'finance'],
                        'description':
                        ('Topic category for the search. '
                         'Default is "general".'),
                    },
                },
                'required': ['query'],
            },
        )

    @classmethod
    def build_request_from_args(cls, **kwargs) -> TavilySearchRequest:
        """
        Build TavilySearchRequest from tool call arguments.

        :param kwargs: Tool-call arguments. 'query' is required; 'num_results',
            'search_depth' and 'topic' default to 5, 'basic' and 'general';
            'include_domains' and 'exclude_domains' default to None.
        :raises KeyError: If 'query' is missing from ``kwargs``.
        """
        return TavilySearchRequest(
            query=kwargs['query'],
            num_results=kwargs.get('num_results', 5),
            search_depth=kwargs.get('search_depth', 'basic'),
            topic=kwargs.get('topic', 'general'),
            include_domains=kwargs.get('include_domains'),
            exclude_domains=kwargs.get('exclude_domains'),
        )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,3 +14,4 @@ Pillow | |
| python-dotenv | ||
| requests | ||
| rich | ||
| tavily-python | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type hint for `response` should specify the generic type of `SearchResponse`. Since the `TavilySearch.search` method populates this with `SearchResponse[BaseResult]`, the type hint here should reflect that for better type safety and clarity.