Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ms_agent/tools/search/search_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ class SearchEngineType(enum.Enum):
EXA = 'exa'
SERPAPI = 'serpapi'
ARXIV = 'arxiv'
TAVILY = 'tavily'


# Mapping from engine type to tool name
ENGINE_TOOL_NAMES: Dict[str, str] = {
'exa': 'exa_search',
'serpapi': 'serpapi_search',
'arxiv': 'arxiv_search',
'tavily': 'tavily_search',
}


Expand Down
3 changes: 3 additions & 0 deletions ms_agent/tools/search/tavily/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa
from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult
from ms_agent.tools.search.tavily.search import TavilySearch
101 changes: 101 additions & 0 deletions ms_agent/tools/search/tavily/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# flake8: noqa
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import json

from ms_agent.tools.search.search_base import BaseResult, SearchResponse


@dataclass
class TavilySearchRequest:

# The search query string
query: str

# Number of results to return, default is 5
num_results: Optional[int] = 5

# Search depth: 'basic' or 'advanced'
search_depth: Optional[str] = 'basic'

# Topic category: 'general', 'news', or 'finance'
topic: Optional[str] = 'general'

# Domains to include in search
include_domains: Optional[List[str]] = None

# Domains to exclude from search
exclude_domains: Optional[List[str]] = None

def to_dict(self) -> Dict[str, Any]:
"""
Convert the request parameters to a dictionary.
"""
d = {
'query': self.query,
'max_results': self.num_results,
'search_depth': self.search_depth,
'topic': self.topic,
}
if self.include_domains:
d['include_domains'] = self.include_domains
if self.exclude_domains:
d['exclude_domains'] = self.exclude_domains
return d

def to_json(self) -> str:
"""
Convert the request parameters to a JSON string.
"""
return json.dumps(self.to_dict(), ensure_ascii=False)


@dataclass
class TavilySearchResult:

# The original search query string
query: str

# Optional arguments for the search request
arguments: Dict[str, Any] = field(default_factory=dict)

# The response from the Tavily search API (dict with 'results' key)
response: SearchResponse = None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The type hint for response should specify the generic type of SearchResponse. Since the TavilySearch.search method populates this with SearchResponse[BaseResult], the type hint here should reflect that for better type safety and clarity.

Suggested change
response: SearchResponse = None
response: Optional[SearchResponse[BaseResult]] = None


def to_list(self):
"""
Convert the search results to a list of dictionaries.
"""
if not self.response or not self.response.results:
print('***Warning: No search results found.')
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using print for warnings is not ideal in a library context. It's better to use the logger module for consistent logging practices, which allows for configurable output and severity levels.

Suggested change
print('***Warning: No search results found.')
logger.warning('No search results found.')

return []

if not self.query:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using print for warnings is not ideal in a library context. It's better to use the logger module for consistent logging practices, which allows for configurable output and severity levels.

Suggested change
if not self.query:
logger.warning('No query provided for search results.')

print('***Warning: No query provided for search results.')
return []

res_list: List[Any] = []
for res in self.response.results:
res_list.append({
'url': res.url,
'id': res.id,
'title': res.title,
'highlights': res.highlights,
'highlight_scores': res.highlight_scores,
'summary': res.summary,
'markdown': res.markdown,
})

return res_list

@staticmethod
def load_from_disk(file_path: str) -> List[Dict[str, Any]]:
"""
Load search results from a local file.
"""
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(f'Search results loaded from {file_path}')
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using print for informational messages is not ideal in a library context. It's better to use the logger module for consistent logging practices, which allows for configurable output and severity levels.

Suggested change
print(f'Search results loaded from {file_path}')
logger.info(f'Search results loaded from {file_path}')


return data
125 changes: 125 additions & 0 deletions ms_agent/tools/search/tavily/search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# flake8: noqa
import os
from typing import TYPE_CHECKING

from tavily import TavilyClient
from ms_agent.tools.search.tavily.schema import TavilySearchRequest, TavilySearchResult
from ms_agent.tools.search.search_base import (BaseResult, SearchEngine,
SearchEngineType,
SearchResponse)

if TYPE_CHECKING:
from ms_agent.llm.utils import Tool


class TavilySearch(SearchEngine):
"""
Search engine using Tavily API.

Best for: AI-optimized web search, general and news queries,
high relevance results with built-in content extraction.
"""

engine_type = SearchEngineType.TAVILY

def __init__(self, api_key: str = None):

api_key = api_key or os.getenv('TAVILY_API_KEY')
assert api_key, 'TAVILY_API_KEY must be set either as an argument or as an environment variable'
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using assert for validating API keys is not recommended in production code, as assert statements can be optimized out by the Python interpreter, leading to unexpected behavior if the key is missing. A ValueError or RuntimeError provides a more robust and explicit error handling mechanism.

Suggested change
assert api_key, 'TAVILY_API_KEY must be set either as an argument or as an environment variable'
if not api_key:
raise ValueError('TAVILY_API_KEY must be set either as an argument or as an environment variable')


self.client = TavilyClient(api_key=api_key)

def search(self, search_request: TavilySearchRequest) -> TavilySearchResult:
"""
Perform a search using the Tavily API with the provided search request parameters.

:param search_request: An instance of TavilySearchRequest containing search parameters.
:return: An instance of TavilySearchResult containing the search results.
"""
search_args: dict = search_request.to_dict()
search_result: TavilySearchResult = TavilySearchResult(
query=search_request.query,
arguments=search_args,
)
try:
raw_response = self.client.search(**search_args)
# Map Tavily results to BaseResult schema
results = []
for item in raw_response.get('results', []):
results.append(
BaseResult(
url=item.get('url', ''),
id=item.get('url', ''),
title=item.get('title', ''),
summary=item.get('content', ''),
markdown=item.get('raw_content'),
))
search_result.response = SearchResponse(results=results)
except Exception as e:
raise RuntimeError(f'Failed to perform search: {e}') from e

return search_result

@classmethod
def get_tool_definition(cls, server_name: str = 'web_search') -> 'Tool':
"""Return the tool definition for Tavily search engine."""
from ms_agent.llm.utils import Tool
return Tool(
tool_name=cls.get_tool_name(),
server_name=server_name,
description=(
'Search the web using Tavily AI-optimized search engine. '
'Best for: general web queries, news, and finance topics. '
'Returns highly relevant results with content extraction.'),
parameters={
'type': 'object',
'properties': {
'query': {
'type':
'string',
'description':
'The search query. Use natural language for best results.',
},
'num_results': {
'type':
'integer',
'minimum':
1,
'maximum':
10,
'description':
'Number of results to return. Default is 5.',
},
'search_depth': {
'type':
'string',
'enum': ['basic', 'advanced'],
'description':
('Search depth. "basic" for fast results, '
'"advanced" for higher relevance. Default is "basic".'
),
},
'topic': {
'type':
'string',
'enum': ['general', 'news', 'finance'],
'description':
('Topic category for the search. '
'Default is "general".'),
},
},
'required': ['query'],
},
)

@classmethod
def build_request_from_args(cls, **kwargs) -> TavilySearchRequest:
"""Build TavilySearchRequest from tool call arguments."""
return TavilySearchRequest(
query=kwargs['query'],
num_results=kwargs.get('num_results', 5),
search_depth=kwargs.get('search_depth', 'basic'),
topic=kwargs.get('topic', 'general'),
include_domains=kwargs.get('include_domains'),
exclude_domains=kwargs.get('exclude_domains'),
)
14 changes: 13 additions & 1 deletion ms_agent/tools/search/websearch_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ def get_search_engine_class(engine_type: str) -> Type[SearchEngine]:
elif engine_type == 'arxiv':
from ms_agent.tools.search.arxiv import ArxivSearch
return ArxivSearch
elif engine_type == 'tavily':
from ms_agent.tools.search.tavily import TavilySearch
return TavilySearch
else:
logger.warning(
f"Unknown search engine '{engine_type}', falling back to arxiv")
Expand Down Expand Up @@ -238,6 +241,9 @@ def get_search_engine(engine_type: str,
api_key=api_key or os.getenv('SERPAPI_API_KEY'),
provider=kwargs.get('provider', default_provider),
)
elif engine_type == 'tavily':
from ms_agent.tools.search.tavily import TavilySearch
return TavilySearch(api_key=api_key or os.getenv('TAVILY_API_KEY'))
elif engine_type == 'arxiv':
from ms_agent.tools.search.arxiv import ArxivSearch
return ArxivSearch()
Expand Down Expand Up @@ -296,7 +302,7 @@ class WebSearchTool(ToolBase):
SERVER_NAME = 'web_search'

# Registry of supported search engines
SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv')
SUPPORTED_ENGINES = ('exa', 'serpapi', 'arxiv', 'tavily')

# Process-wide (class-level) usage tracking for summarization calls.
# This is intentionally separate from LLMAgent usage totals.
Expand Down Expand Up @@ -404,6 +410,9 @@ def __init__(self, config, **kwargs):
'serpapi': (getattr(tool_cfg, 'serpapi_api_key', None)
or os.getenv('SERPAPI_API_KEY'))
if tool_cfg else os.getenv('SERPAPI_API_KEY'),
'tavily': (getattr(tool_cfg, 'tavily_api_key', None)
or os.getenv('TAVILY_API_KEY'))
if tool_cfg else os.getenv('TAVILY_API_KEY'),
}

# SerpApi provider (google, bing, baidu)
Expand Down Expand Up @@ -508,6 +517,9 @@ async def connect(self) -> None:
api_key=self._api_keys.get('serpapi'),
provider=self._serpapi_provider,
)
elif engine_type == 'tavily':
self._engines[engine_type] = engine_cls(
api_key=self._api_keys.get('tavily'))
else: # arxiv
self._engines[engine_type] = engine_cls()

Expand Down
12 changes: 10 additions & 2 deletions ms_agent/tools/search_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ms_agent.tools.search.exa import ExaSearch
from ms_agent.tools.search.search_base import SearchEngineType
from ms_agent.tools.search.serpapi import SerpApiSearch
from ms_agent.tools.search.tavily import TavilySearch
from ms_agent.utils.logger import get_logger

logger = get_logger()
Expand All @@ -23,7 +24,8 @@ def set_search_env_overrides(env_overrides: Optional[Dict[str, str]]) -> None:
Expected keys (all optional):
- 'EXA_API_KEY'
- 'SERPAPI_API_KEY'
- SEARCH_ENGINE_OVERRIDE_ENV (e.g. 'exa' / 'serpapi' / 'arxiv')
- 'TAVILY_API_KEY'
- SEARCH_ENGINE_OVERRIDE_ENV (e.g. 'exa' / 'serpapi' / 'arxiv' / 'tavily')
"""
if not env_overrides:
if hasattr(_search_env_local, 'overrides'):
Expand Down Expand Up @@ -135,14 +137,16 @@ def get_web_search_tool(config_file: str):
or '')).strip().lower()
if engine_override and engine_override in (SearchEngineType.EXA.value,
SearchEngineType.SERPAPI.value,
SearchEngineType.ARXIV.value):
SearchEngineType.ARXIV.value,
SearchEngineType.TAVILY.value):
search_config['engine'] = engine_override

engine_name = (search_config.get('engine', '') or '').lower()

# Per-request API key overrides (thread-local) take precedence
override_exa_key = local_env.get('EXA_API_KEY')
override_serp_key = local_env.get('SERPAPI_API_KEY')
override_tavily_key = local_env.get('TAVILY_API_KEY')

if engine_name == SearchEngineType.EXA.value:
return ExaSearch(
Expand All @@ -153,6 +157,10 @@ def get_web_search_tool(config_file: str):
api_key=override_serp_key or search_config.get(
'serpapi_api_key', os.getenv('SERPAPI_API_KEY', None)),
provider=search_config.get('provider', 'google').lower())
elif engine_name == SearchEngineType.TAVILY.value:
return TavilySearch(
api_key=override_tavily_key or search_config.get(
'tavily_api_key', os.getenv('TAVILY_API_KEY', None)))
elif engine_name == SearchEngineType.ARXIV.value:
return ArxivSearch()
else:
Expand Down
1 change: 1 addition & 0 deletions requirements/research.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ Pillow
python-dotenv
requests
rich
tavily-python