Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 71 additions & 24 deletions src/smolagents/default_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,13 @@ class DuckDuckGoSearchTool(Tool):
Args:
max_results (`int`, default `10`): Maximum number of search results to return.
rate_limit (`float`, default `1.0`): Maximum queries per second. Set to `None` to disable rate limiting.
timeout (`int`, default `10`): Request timeout in seconds to prevent hanging.
**kwargs: Additional keyword arguments for the `DDGS` client.

Examples:
```python
>>> from smolagents import DuckDuckGoSearchTool
>>> web_search_tool = DuckDuckGoSearchTool(max_results=5, rate_limit=2.0)
>>> web_search_tool = DuckDuckGoSearchTool(max_results=5, rate_limit=2.0, timeout=15)
>>> results = web_search_tool("Hugging Face")
>>> print(results)
```
Expand All @@ -120,10 +121,11 @@ class DuckDuckGoSearchTool(Tool):
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
output_type = "string"

def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, **kwargs):
def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, timeout: int = 10, **kwargs):
super().__init__()
self.max_results = max_results
self.rate_limit = rate_limit
self.timeout = timeout
self._min_interval = 1.0 / rate_limit if rate_limit else 0.0
self._last_request_time = 0.0
try:
Expand All @@ -132,6 +134,9 @@ def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, **kwar
raise ImportError(
"You must install package `ddgs` to run this tool: for instance run `pip install ddgs`."
) from e
# Pass timeout to DDGS if supported
if "timeout" not in kwargs:
kwargs["timeout"] = timeout
self.ddgs = DDGS(**kwargs)

def forward(self, query: str) -> str:
Expand Down Expand Up @@ -169,11 +174,12 @@ class GoogleSearchTool(Tool):
}
output_type = "string"

def __init__(self, provider: str = "serpapi"):
def __init__(self, provider: str = "serpapi", timeout: int = 10):
super().__init__()
import os

self.provider = provider
self.timeout = timeout
if provider == "serpapi":
self.organic_key = "organic_results"
api_key_env_name = "SERPAPI_API_KEY"
Expand Down Expand Up @@ -204,7 +210,12 @@ def forward(self, query: str, filter_year: int | None = None) -> str:
if filter_year is not None:
params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"

response = requests.get(base_url, params=params)
try:
response = requests.get(base_url, params=params, timeout=self.timeout)
except requests.exceptions.Timeout:
raise Exception(f"Search request timed out after {self.timeout} seconds. Please try again.")
except requests.exceptions.RequestException as e:
raise Exception(f"Search request failed: {str(e)}")

if response.status_code == 200:
results = response.json()
Expand Down Expand Up @@ -257,11 +268,12 @@ class ApiWebSearchTool(Tool):
headers (`dict`, *optional*): Headers for API requests.
params (`dict`, *optional*): Parameters for API requests.
rate_limit (`float`, default `1.0`): Maximum queries per second. Set to `None` to disable rate limiting.
timeout (`int`, default `10`): Request timeout in seconds to prevent hanging.

Examples:
```python
>>> from smolagents import ApiWebSearchTool
>>> web_search_tool = ApiWebSearchTool(rate_limit=50.0)
>>> web_search_tool = ApiWebSearchTool(rate_limit=50.0, timeout=15)
>>> results = web_search_tool("Hugging Face")
>>> print(results)
```
Expand All @@ -280,6 +292,7 @@ def __init__(
headers: dict = None,
params: dict = None,
rate_limit: float | None = 1.0,
timeout: int = 10,
):
import os

Expand All @@ -290,6 +303,7 @@ def __init__(
self.headers = headers or {"X-Subscription-Token": self.api_key}
self.params = params or {"count": 10}
self.rate_limit = rate_limit
self.timeout = timeout
self._min_interval = 1.0 / rate_limit if rate_limit else 0.0
self._last_request_time = 0.0

Expand All @@ -311,8 +325,14 @@ def forward(self, query: str) -> str:

self._enforce_rate_limit()
params = {**self.params, "q": query}
response = requests.get(self.endpoint, headers=self.headers, params=params)
response.raise_for_status()
try:
response = requests.get(self.endpoint, headers=self.headers, params=params, timeout=self.timeout)
response.raise_for_status()
except requests.exceptions.Timeout:
raise Exception(f"Search request timed out after {self.timeout} seconds. Please try again.")
except requests.exceptions.RequestException as e:
raise Exception(f"Search request failed: {str(e)}")

data = response.json()
results = self.extract_results(data)
return self.format_markdown(results)
Expand Down Expand Up @@ -342,10 +362,11 @@ class WebSearchTool(Tool):
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
output_type = "string"

def __init__(self, max_results: int = 10, engine: str = "duckduckgo"):
def __init__(self, max_results: int = 10, engine: str = "duckduckgo", timeout: int = 10):
super().__init__()
self.max_results = max_results
self.engine = engine
self.timeout = timeout

def forward(self, query: str) -> str:
results = self.search(query)
Expand All @@ -369,12 +390,19 @@ def parse_results(self, results: list) -> str:
def search_duckduckgo(self, query: str) -> list:
import requests

response = requests.get(
"https://lite.duckduckgo.com/lite/",
params={"q": query},
headers={"User-Agent": "Mozilla/5.0"},
)
response.raise_for_status()
try:
response = requests.get(
"https://lite.duckduckgo.com/lite/",
params={"q": query},
headers={"User-Agent": "Mozilla/5.0"},
timeout=self.timeout,
)
response.raise_for_status()
except requests.exceptions.Timeout:
raise Exception(f"DuckDuckGo search timed out after {self.timeout} seconds. Please try again.")
except requests.exceptions.RequestException as e:
raise Exception(f"DuckDuckGo search failed: {str(e)}")

parser = self._create_duckduckgo_parser()
parser.feed(response.text)
return parser.results
Expand Down Expand Up @@ -430,11 +458,18 @@ def search_bing(self, query: str) -> list:

import requests

response = requests.get(
"https://www.bing.com/search",
params={"q": query, "format": "rss"},
)
response.raise_for_status()
try:
response = requests.get(
"https://www.bing.com/search",
params={"q": query, "format": "rss"},
timeout=self.timeout,
)
response.raise_for_status()
except requests.exceptions.Timeout:
raise Exception(f"Bing search timed out after {self.timeout} seconds. Please try again.")
except requests.exceptions.RequestException as e:
raise Exception(f"Bing search failed: {str(e)}")

root = ET.fromstring(response.text)
items = root.findall(".//item")
results = [
Expand All @@ -461,9 +496,10 @@ class VisitWebpageTool(Tool):
}
output_type = "string"

def __init__(self, max_output_length: int = 40000):
def __init__(self, max_output_length: int = 40000, timeout: int = 20):
super().__init__()
self.max_output_length = max_output_length
self.timeout = timeout

def _truncate_content(self, content: str, max_length: int) -> str:
if len(content) <= max_length:
Expand All @@ -484,8 +520,8 @@ def forward(self, url: str) -> str:
"You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
) from e
try:
# Send a GET request to the URL with a 20-second timeout
response = requests.get(url, timeout=20)
# Send a GET request to the URL with configurable timeout
response = requests.get(url, timeout=self.timeout)
response.raise_for_status() # Raise an exception for bad status codes

# Convert the HTML content to Markdown
Expand All @@ -497,7 +533,7 @@ def forward(self, url: str) -> str:
return self._truncate_content(markdown_content, self.max_output_length)

except requests.exceptions.Timeout:
return "The request timed out. Please try again later or check the URL."
return f"The request timed out after {self.timeout} seconds. Please try again or check the URL."
except RequestException as e:
return f"Error fetching the webpage: {str(e)}"
except Exception as e:
Expand All @@ -515,6 +551,7 @@ class WikipediaSearchTool(Tool):
See: http://meta.wikimedia.org/wiki/List_of_Wikipedias
content_type (`Literal["summary", "text"]`, default `"text"`): Type of content to fetch. Can be "summary" for a short summary or "text" for the full article.
extract_format (`Literal["HTML", "WIKI"]`, default `"WIKI"`): Extraction format of the output. Can be `"WIKI"` or `"HTML"`.
timeout (`int`, default `10`): Request timeout in seconds to prevent hanging.

Example:
```python
Expand All @@ -526,6 +563,7 @@ class WikipediaSearchTool(Tool):
>>> language="en",
>>> content_type="summary", # or "text"
>>> extract_format="WIKI",
>>> timeout=15,
>>> )
>>> ],
>>> model=InferenceClientModel(),
Expand All @@ -550,6 +588,7 @@ def __init__(
language: str = "en",
content_type: str = "text",
extract_format: str = "WIKI",
timeout: int = 10,
):
super().__init__()
try:
Expand All @@ -564,6 +603,7 @@ def __init__(
self.user_agent = user_agent
self.language = language
self.content_type = content_type
self.timeout = timeout

# Map string format to wikipediaapi.ExtractFormat
extract_format_map = {
Expand All @@ -577,7 +617,10 @@ def __init__(
self.extract_format = extract_format_map[extract_format]

self.wiki = wikipediaapi.Wikipedia(
user_agent=self.user_agent, language=self.language, extract_format=self.extract_format
user_agent=self.user_agent,
language=self.language,
extract_format=self.extract_format,
timeout=self.timeout,
)

def forward(self, query: str) -> str:
Expand All @@ -600,6 +643,10 @@ def forward(self, query: str) -> str:
return f"✅ **Wikipedia Page:** {title}\n\n**Content:** {text}\n\n🔗 **Read more:** {url}"

except Exception as e:
# Check if it's a timeout-related error
error_str = str(e).lower()
if "timeout" in error_str or "timed out" in error_str:
return f"Wikipedia search timed out after {self.timeout} seconds. Please try again with a different query."
return f"Error fetching Wikipedia summary: {str(e)}"


Expand Down