diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py index 62679de80..882b7c05f 100644 --- a/src/smolagents/default_tools.py +++ b/src/smolagents/default_tools.py @@ -104,12 +104,13 @@ class DuckDuckGoSearchTool(Tool): Args: max_results (`int`, default `10`): Maximum number of search results to return. rate_limit (`float`, default `1.0`): Maximum queries per second. Set to `None` to disable rate limiting. + timeout (`int`, default `10`): Request timeout in seconds to prevent hanging. **kwargs: Additional keyword arguments for the `DDGS` client. Examples: ```python >>> from smolagents import DuckDuckGoSearchTool - >>> web_search_tool = DuckDuckGoSearchTool(max_results=5, rate_limit=2.0) + >>> web_search_tool = DuckDuckGoSearchTool(max_results=5, rate_limit=2.0, timeout=15) >>> results = web_search_tool("Hugging Face") >>> print(results) ``` @@ -120,10 +121,11 @@ class DuckDuckGoSearchTool(Tool): inputs = {"query": {"type": "string", "description": "The search query to perform."}} output_type = "string" - def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, **kwargs): + def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, timeout: int = 10, **kwargs): super().__init__() self.max_results = max_results self.rate_limit = rate_limit + self.timeout = timeout self._min_interval = 1.0 / rate_limit if rate_limit else 0.0 self._last_request_time = 0.0 try: @@ -132,6 +134,9 @@ def __init__(self, max_results: int = 10, rate_limit: float | None = 1.0, **kwar raise ImportError( "You must install package `ddgs` to run this tool: for instance run `pip install ddgs`." ) from e + # Pass timeout to DDGS if supported + if "timeout" not in kwargs: + kwargs["timeout"] = timeout self.ddgs = DDGS(**kwargs) def forward(self, query: str) -> str: @@ -169,11 +174,12 @@ class GoogleSearchTool(Tool): } output_type = "string" - def __init__(self, provider: str = "serpapi"): + def __init__(self, provider: str = "serpapi", timeout: int = 10): super().__init__() import os self.provider = provider + self.timeout = timeout if provider == "serpapi": self.organic_key = "organic_results" api_key_env_name = "SERPAPI_API_KEY" @@ -204,7 +210,12 @@ def forward(self, query: str, filter_year: int | None = None) -> str: if filter_year is not None: params["tbs"] = f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}" - response = requests.get(base_url, params=params) + try: + response = requests.get(base_url, params=params, timeout=self.timeout) + except requests.exceptions.Timeout: + raise Exception(f"Search request timed out after {self.timeout} seconds. Please try again.") + except requests.exceptions.RequestException as e: + raise Exception(f"Search request failed: {str(e)}") if response.status_code == 200: results = response.json() @@ -257,11 +268,12 @@ class ApiWebSearchTool(Tool): headers (`dict`, *optional*): Headers for API requests. params (`dict`, *optional*): Parameters for API requests. rate_limit (`float`, default `1.0`): Maximum queries per second. Set to `None` to disable rate limiting. + timeout (`int`, default `10`): Request timeout in seconds to prevent hanging. Examples: ```python >>> from smolagents import ApiWebSearchTool - >>> web_search_tool = ApiWebSearchTool(rate_limit=50.0) + >>> web_search_tool = ApiWebSearchTool(rate_limit=50.0, timeout=15) >>> results = web_search_tool("Hugging Face") >>> print(results) ``` @@ -280,6 +292,7 @@ def __init__( headers: dict = None, params: dict = None, rate_limit: float | None = 1.0, + timeout: int = 10, ): import os @@ -290,6 +303,7 @@ def __init__( self.headers = headers or {"X-Subscription-Token": self.api_key} self.params = params or {"count": 10} self.rate_limit = rate_limit + self.timeout = timeout self._min_interval = 1.0 / rate_limit if rate_limit else 0.0 self._last_request_time = 0.0 @@ -311,8 +325,14 @@ def forward(self, query: str) -> str: self._enforce_rate_limit() params = {**self.params, "q": query} - response = requests.get(self.endpoint, headers=self.headers, params=params) - response.raise_for_status() + try: + response = requests.get(self.endpoint, headers=self.headers, params=params, timeout=self.timeout) + response.raise_for_status() + except requests.exceptions.Timeout: + raise Exception(f"Search request timed out after {self.timeout} seconds. Please try again.") + except requests.exceptions.RequestException as e: + raise Exception(f"Search request failed: {str(e)}") + data = response.json() results = self.extract_results(data) return self.format_markdown(results) @@ -342,10 +362,11 @@ class WebSearchTool(Tool): inputs = {"query": {"type": "string", "description": "The search query to perform."}} output_type = "string" - def __init__(self, max_results: int = 10, engine: str = "duckduckgo"): + def __init__(self, max_results: int = 10, engine: str = "duckduckgo", timeout: int = 10): super().__init__() self.max_results = max_results self.engine = engine + self.timeout = timeout def forward(self, query: str) -> str: results = self.search(query) @@ -369,12 +390,19 @@ def parse_results(self, results: list) -> str: def search_duckduckgo(self, query: str) -> list: import requests - response = requests.get( - "https://lite.duckduckgo.com/lite/", - params={"q": query}, - headers={"User-Agent": "Mozilla/5.0"}, - ) - response.raise_for_status() + try: + response = requests.get( + "https://lite.duckduckgo.com/lite/", + params={"q": query}, + headers={"User-Agent": "Mozilla/5.0"}, + timeout=self.timeout, + ) + response.raise_for_status() + except requests.exceptions.Timeout: + raise Exception(f"DuckDuckGo search timed out after {self.timeout} seconds. Please try again.") + except requests.exceptions.RequestException as e: + raise Exception(f"DuckDuckGo search failed: {str(e)}") + parser = self._create_duckduckgo_parser() parser.feed(response.text) return parser.results @@ -430,11 +458,18 @@ def search_bing(self, query: str) -> list: import requests - response = requests.get( - "https://www.bing.com/search", - params={"q": query, "format": "rss"}, - ) - response.raise_for_status() + try: + response = requests.get( + "https://www.bing.com/search", + params={"q": query, "format": "rss"}, + timeout=self.timeout, + ) + response.raise_for_status() + except requests.exceptions.Timeout: + raise Exception(f"Bing search timed out after {self.timeout} seconds. Please try again.") + except requests.exceptions.RequestException as e: + raise Exception(f"Bing search failed: {str(e)}") + root = ET.fromstring(response.text) items = root.findall(".//item") results = [ @@ -461,9 +496,10 @@ class VisitWebpageTool(Tool): } output_type = "string" - def __init__(self, max_output_length: int = 40000): + def __init__(self, max_output_length: int = 40000, timeout: int = 20): super().__init__() self.max_output_length = max_output_length + self.timeout = timeout def _truncate_content(self, content: str, max_length: int) -> str: if len(content) <= max_length: @@ -484,8 +520,8 @@ def forward(self, url: str) -> str: "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`." ) from e try: - # Send a GET request to the URL with a 20-second timeout - response = requests.get(url, timeout=20) + # Send a GET request to the URL with configurable timeout + response = requests.get(url, timeout=self.timeout) response.raise_for_status() # Raise an exception for bad status codes # Convert the HTML content to Markdown @@ -497,7 +533,7 @@ def forward(self, url: str) -> str: return self._truncate_content(markdown_content, self.max_output_length) except requests.exceptions.Timeout: - return "The request timed out. Please try again later or check the URL." + return f"The request timed out after {self.timeout} seconds. Please try again or check the URL." except RequestException as e: return f"Error fetching the webpage: {str(e)}" except Exception as e: @@ -515,6 +551,7 @@ class WikipediaSearchTool(Tool): See: http://meta.wikimedia.org/wiki/List_of_Wikipedias content_type (`Literal["summary", "text"]`, default `"text"`): Type of content to fetch. Can be "summary" for a short summary or "text" for the full article. extract_format (`Literal["HTML", "WIKI"]`, default `"WIKI"`): Extraction format of the output. Can be `"WIKI"` or `"HTML"`. + timeout (`int`, default `10`): Request timeout in seconds to prevent hanging. Example: ```python @@ -526,6 +563,7 @@ class WikipediaSearchTool(Tool): >>> language="en", >>> content_type="summary", # or "text" >>> extract_format="WIKI", + >>> timeout=15, >>> ) >>> ], >>> model=InferenceClientModel(), @@ -550,6 +588,7 @@ def __init__( language: str = "en", content_type: str = "text", extract_format: str = "WIKI", + timeout: int = 10, ): super().__init__() try: @@ -564,6 +603,7 @@ def __init__( self.user_agent = user_agent self.language = language self.content_type = content_type + self.timeout = timeout # Map string format to wikipediaapi.ExtractFormat extract_format_map = { @@ -577,7 +617,10 @@ def __init__( self.extract_format = extract_format_map[extract_format] self.wiki = wikipediaapi.Wikipedia( - user_agent=self.user_agent, language=self.language, extract_format=self.extract_format + user_agent=self.user_agent, + language=self.language, + extract_format=self.extract_format, + timeout=self.timeout, ) def forward(self, query: str) -> str: @@ -600,6 +643,10 @@ def forward(self, query: str) -> str: return f"āœ… **Wikipedia Page:** {title}\n\n**Content:** {text}\n\nšŸ”— **Read more:** {url}" except Exception as e: + # Check if it's a timeout-related error + error_str = str(e).lower() + if "timeout" in error_str or "timed out" in error_str: + return f"Wikipedia search timed out after {self.timeout} seconds. Please try again with a different query." return f"Error fetching Wikipedia summary: {str(e)}"