diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml
index 2e4f5c67f..774fa5296 100644
--- a/.github/workflows/quality.yml
+++ b/.github/workflows/quality.yml
@@ -16,20 +16,15 @@ jobs:
python-version: "3.12"
# Setup venv
- # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
- name: Setup venv + uv
run: |
pip install --upgrade uv
uv venv
- name: Install dependencies
- run: uv pip install "smolagents[test] @ ."
- - run: uv run ruff check tests src # linter
- - run: uv run ruff format --check tests src # formatter
+ run: uv pip install "smolagents[quality] @ ."
- # Run type checking at least on smolagents root file to check all modules
- # that can be lazy-loaded actually exist.
- # - run: uv run mypy src/smolagents/__init__.py --follow-imports=silent --show-traceback
-
- # Run mypy on full package
- # - run: uv run mypy src
\ No newline at end of file
+ # Equivalent of "make quality" but step by step
+ - run: uv run ruff check examples src tests utils # linter
+ - run: uv run ruff format --check examples src tests utils # formatter
+ - run: uv run python utils/check_tests_in_ci.py
\ No newline at end of file
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c720ec0f5..412c1e8ed 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -20,9 +20,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
-
# Setup venv
- # TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.
- name: Setup venv + uv
run: |
pip install --upgrade uv
@@ -33,33 +31,59 @@ jobs:
run: |
uv pip install "smolagents[test] @ ."
+ # Run all tests separately for individual feedback
+ # Use 'if success() || failure()' so that all tests are run even if one failed
+ # See https://stackoverflow.com/a/62112985
- name: Agent tests
run: |
- uv run pytest -sv ./tests/test_agents.py
+ uv run pytest ./tests/test_agents.py
+ if: ${{ success() || failure() }}
+
- name: Default tools tests
run: |
- uv run pytest -sv ./tests/test_default_tools.py
+ uv run pytest ./tests/test_default_tools.py
+ if: ${{ success() || failure() }}
+
+ # - name: Docs tests # Disabled for now (slow test + requires API keys)
+ # run: |
+ # uv run pytest ./tests/test_all_docs.py
+
- name: Final answer tests
run: |
- uv run pytest -sv ./tests/test_final_answer.py
+ uv run pytest ./tests/test_final_answer.py
+ if: ${{ success() || failure() }}
+
- name: Models tests
run: |
- uv run pytest -sv ./tests/test_models.py
+ uv run pytest ./tests/test_models.py
+ if: ${{ success() || failure() }}
+
- name: Monitoring tests
run: |
- uv run pytest -sv ./tests/test_monitoring.py
+ uv run pytest ./tests/test_monitoring.py
+ if: ${{ success() || failure() }}
+
- name: Python interpreter tests
run: |
- uv run pytest -sv ./tests/test_python_interpreter.py
+ uv run pytest ./tests/test_python_interpreter.py
+ if: ${{ success() || failure() }}
+
- name: Search tests
run: |
- uv run pytest -sv ./tests/test_search.py
+ uv run pytest ./tests/test_search.py
+ if: ${{ success() || failure() }}
+
- name: Tools tests
run: |
- uv run pytest -sv ./tests/test_tools.py
+ uv run pytest ./tests/test_tools.py
+ if: ${{ success() || failure() }}
+
- name: Types tests
run: |
- uv run pytest -sv ./tests/test_types.py
+ uv run pytest ./tests/test_types.py
+ if: ${{ success() || failure() }}
+
- name: Utils tests
run: |
- uv run pytest -sv ./tests/test_utils.py
\ No newline at end of file
+ uv run pytest ./tests/test_utils.py
+ if: ${{ success() || failure() }}
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index b5ffd0f2e..0e346b751 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -91,8 +91,8 @@ happy to make the changes or help you make a contribution if you're interested!
## I want to become a maintainer of the project. How do I get there?
-Smolagents is a project led and managed by Hugging Face as an initial fork of Transformers Agents. We are more than
-happy to have motivated individuals from other organizations join us as maintainers with the goal of helping Smolagents
+smolagents is a project led and managed by Hugging Face. We are more than
+happy to have motivated individuals from other organizations join us as maintainers with the goal of helping smolagents
make a dent in the world of Agents.
If you are such an individual (or organization), please reach out to us and let's collaborate.
\ No newline at end of file
diff --git a/Makefile b/Makefile
index a24c1aefd..c8e7c04f6 100644
--- a/Makefile
+++ b/Makefile
@@ -1,53 +1,18 @@
.PHONY: quality style test docs utils
-check_dirs := .
+check_dirs := examples src tests utils
-# Check that source code meets quality standards
-
-extra_quality_checks:
- python utils/check_copies.py
- python utils/check_dummies.py
- python utils/check_repo.py
- doc-builder style smolagents docs/source --max_len 119
-
-# this target runs checks on all files
+# Check code quality of the source code
quality:
ruff check $(check_dirs)
ruff format --check $(check_dirs)
- doc-builder style smolagents docs/source --max_len 119 --check_only
+ python utils/check_tests_in_ci.py
-# Format source code automatically and check is there are any problems left that need manual fixing
+# Format source code automatically
style:
ruff check $(check_dirs) --fix
ruff format $(check_dirs)
- doc-builder style smolagents docs/source --max_len 119
-# Run tests for the library
-test_big_modeling:
- python -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_big_modeling.log",)
-
-test_core:
- python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)
-
-test_cli:
- python -m pytest -s -v ./tests/test_cli.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_cli.log",)
-
-
-# Since the new version of pytest will *change* how things are collected, we need `deepspeed` to
-# run after test_core and test_cli
+# Run smolagents tests
test:
- $(MAKE) test_core
- $(MAKE) test_cli
- $(MAKE) test_big_modeling
- $(MAKE) test_deepspeed
- $(MAKE) test_fsdp
-
-test_examples:
- python -m pytest -s -v ./tests/test_examples.py $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_examples.log",)
-
-# Same as test but used to install only the base dependencies
-test_prod:
- $(MAKE) test_core
-
-test_rest:
- python -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests $(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_rest.log",)
+ pytest ./tests/
\ No newline at end of file
diff --git a/README.md b/README.md
index 74a82b58d..c0b14c77b 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,10 @@ limitations under the License.
-
🤗 smolagents - a smol library to build great agents!
+
+

+
smolagents - a smol library to build great agents!
+
`smolagents` is a library that enables you to run powerful agents in a few lines of code. It offers:
@@ -66,8 +69,8 @@ In our `CodeAgent`, the LLM engine writes its actions in code. This approach is
and [reaches higher performance on difficult benchmarks](https://huggingface.co/papers/2411.01747). Head to [our high-level intro to agents](https://huggingface.co/docs/smolagents/conceptual_guides/intro_agents) to learn more on that.
Especially, since code execution can be a security concern (arbitrary code execution!), we provide options at runtime:
- - a secure python interpreter to run code more safely in your environment
- - a sandboxed environment using [E2B](https://e2b.dev/).
+ - a secure python interpreter to run code more safely in your environment (more secure than raw code execution but still risky)
+ - a sandboxed environment using [E2B](https://e2b.dev/) (removes the risk to your own system).
## How smol is it really?
@@ -88,6 +91,36 @@ We've created [`CodeAgent`](https://huggingface.co/docs/smolagents/reference/age
This comparison shows that open source models can now take on the best closed models!
+## Contributing
+
+To contribute, follow our [contribution guide](https://github.com/huggingface/smolagents/blob/main/CONTRIBUTING.md).
+
+At any moment, feel welcome to open an issue, citing your exact error traces and package versions if it's a bug.
+It's often even better to open a PR with your proposed fixes/changes!
+
+To install dev dependencies, run:
+```
+pip install -e ".[dev]"
+```
+
+When making changes to the codebase, please check that it follows the repo's code quality requirements by running:
+To check code quality of the source code:
+```
+make quality
+```
+
+If the checks fail, you can run the formatter with:
+```
+make style
+```
+
+And commit the changes.
+
+To run tests locally, run this command:
+```bash
+pytest .
+```
+
## Citing smolagents
If you use `smolagents` in your publication, please cite it by using the following BibTeX entry.
diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index bec73b788..71faa4d92 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -8,6 +8,8 @@
sections:
- local: tutorials/building_good_agents
title: ✨ Building good agents
+ - local: tutorials/inspect_runs
+ title: 📊 Inspect your agent runs using telemetry
- local: tutorials/tools
title: 🛠️ Tools - in-depth guide
- local: tutorials/secure_code_execution
diff --git a/docs/source/en/examples/multiagents.md b/docs/source/en/examples/multiagents.md
index 7901de2b6..c4bb51413 100644
--- a/docs/source/en/examples/multiagents.md
+++ b/docs/source/en/examples/multiagents.md
@@ -64,7 +64,7 @@ model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
## 🔍 Create a web search tool
-For web browsing, we can already use our pre-existing [`DuckDuckGoSearchTool`](https://github.com/huggingface/smolagents/blob/main/src/smolagents/default_tools/search.py) tool to provide a Google search equivalent.
+For web browsing, we can already use our pre-existing [`DuckDuckGoSearchTool`](https://github.com/huggingface/smolagents/blob/main/src/smolagents/default_tools.py#L151-L176) tool to provide a Google search equivalent.
But then we will also need to be able to peak into the page found by the `DuckDuckGoSearchTool`.
To do so, we could import the library's built-in `VisitWebpageTool`, but we will build it again to see how it's done.
@@ -196,4 +196,4 @@ Seems like we'll need some sizeable powerplants if the [scaling hypothesis](http
Our agents managed to efficiently collaborate towards solving the task! ✅
-💡 You can easily extend this orchestration to more agents: one does the code execution, one the web search, one handles file loadings...
\ No newline at end of file
+💡 You can easily extend this orchestration to more agents: one does the code execution, one the web search, one handles file loadings...
diff --git a/docs/source/en/guided_tour.md b/docs/source/en/guided_tour.md
index 9db8ecdb9..dd6a8214f 100644
--- a/docs/source/en/guided_tour.md
+++ b/docs/source/en/guided_tour.md
@@ -89,8 +89,9 @@ from smolagents import CodeAgent, LiteLLMModel
model = LiteLLMModel(
model_id="ollama_chat/llama3.2", # This model is a bit weak for agentic behaviours though
- api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
+ api_base="http://localhost:11434", # replace with 127.0.0.1:11434 or remote open-ai compatible server if necessary
api_key="YOUR_API_KEY" # replace with API key if necessary
+ num_ctx=8192 # ollama default is 2048 which will fail horribly. 8192 works for easy tasks, more is better. Check https://huggingface.co/spaces/NyxKrage/LLM-Model-VRAM-Calculator to calculate how much VRAM this will need for the selected model.
)
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
@@ -336,7 +337,7 @@ from smolagents import (
)
# Import tool from Hub
-image_generation_tool = load_tool("m-ric/text-to-image")
+image_generation_tool = load_tool("m-ric/text-to-image", trust_remote_code=True)
model = HfApiModel(model_id)
diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index fbcfba065..170c0222b 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -15,6 +15,10 @@ rendered properly in your Markdown viewer.
# `smolagents`
+
+

+
+
This library is the simplest framework out there to build powerful agents! By the way, wtf are "agents"? We provide our definition [in this page](conceptual_guides/intro_agents), where you'll also find tips for when to use them or not (spoilers: you'll often be better off without agents).
This library offers:
diff --git a/docs/source/en/reference/agents.md b/docs/source/en/reference/agents.md
index 2149c0dcb..76b2ecb6b 100644
--- a/docs/source/en/reference/agents.md
+++ b/docs/source/en/reference/agents.md
@@ -136,8 +136,22 @@ messages = [
{"role": "user", "content": "No need to help, take it easy."},
]
-model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2)
-print(model(messages, max_tokens=10))
+model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2, max_tokens=10)
+print(model(messages))
```
[[autodoc]] LiteLLMModel
+
+### OpenAiServerModel
+
+This class lets you call any OpenAIServer compatible model.
+Here's how you can set it (you can customise the `api_base` url to point to another server):
+```py
+from smolagents import OpenAIServerModel
+
+model = OpenAIServerModel(
+ model_id="gpt-4o",
+ api_base="https://api.openai.com/v1",
+ api_key=os.environ["OPENAI_API_KEY"],
+)
+```
\ No newline at end of file
diff --git a/docs/source/en/reference/tools.md b/docs/source/en/reference/tools.md
index 022ad35d2..9d787740c 100644
--- a/docs/source/en/reference/tools.md
+++ b/docs/source/en/reference/tools.md
@@ -57,6 +57,10 @@ contains the API docs for the underlying classes.
[[autodoc]] VisitWebpageTool
+### UserInputTool
+
+[[autodoc]] UserInputTool
+
## ToolCollection
[[autodoc]] ToolCollection
diff --git a/docs/source/en/tutorials/building_good_agents.md b/docs/source/en/tutorials/building_good_agents.md
index f2d37a20e..6cef92d15 100644
--- a/docs/source/en/tutorials/building_good_agents.md
+++ b/docs/source/en/tutorials/building_good_agents.md
@@ -19,7 +19,7 @@ rendered properly in your Markdown viewer.
There's a world of difference between building an agent that works and one that doesn't.
How can we build agents that fall into the latter category?
-In this guide, we're going to see best practices for building agents.
+In this guide, we're going to talk about best practices for building agents.
> [!TIP]
> If you're new to building agents, make sure to first read the [intro to agents](../conceptual_guides/intro_agents) and the [guided tour of smolagents](../guided_tour).
@@ -67,7 +67,7 @@ def get_weather_report_at_coordinates(coordinates, date_time):
# Dummy function, returns a list of [temperature in °C, risk of rain on a scale 0-1, wave height in m]
return [28.0, 0.35, 0.85]
-def get_coordinates_from_location(location):
+def convert_location_to_coordinates(location):
# Returns dummy coordinates
return [3.3, -42.0]
diff --git a/docs/source/en/tutorials/inspect_runs.md b/docs/source/en/tutorials/inspect_runs.md
new file mode 100644
index 000000000..021cf7ba6
--- /dev/null
+++ b/docs/source/en/tutorials/inspect_runs.md
@@ -0,0 +1,104 @@
+
+# Inspecting runs with OpenTelemetry
+
+[[open-in-colab]]
+
+> [!TIP]
+> If you're new to building agents, make sure to first read the [intro to agents](../conceptual_guides/intro_agents) and the [guided tour of smolagents](../guided_tour).
+
+### Why log your agent runs?
+
+Agent runs are complicated to debug.
+
+Validating that a run went properly is hard, since agent workflows are [unpredictable by design](../conceptual_guides/intro_agents) (if they were predictable, you'd just be using good old code).
+
+And inspecting a run is hard as well: multi-step agents tend to quickly fill a console with logs, and most of the errors are just "LLM dumb" kind of errors, from which the LLM auto-corrects in the next step by writing better code or tool calls.
+
+So using instrumentation to record agent runs is necessary in production for later inspection and monitoring!
+
+We've adopted the [OpenTelemetry](https://opentelemetry.io/) standard for instrumenting agent runs.
+
+This means that you can just run some instrumentation code, then run your agents normally, and everything gets logged into your platform.
+
+Here's how it goes:
+First install the required packages. Here we install [Phoenix by Arize AI](https://github.com/Arize-ai/phoenix) because that's a good solution to collect and inspect the logs, but there are other OpenTelemetry-compatible platforms that you could use for this collection & inspection part.
+
+```shell
+pip install smolagents
+pip install arize-phoenix opentelemetry-sdk opentelemetry-exporter-otlp openinference-instrumentation-smolagents
+```
+
+Then run the collector in the background.
+
+```shell
+python -m phoenix.server.main serve
+```
+
+Finally, set up `SmolagentsInstrumentor` to trace your agents and send the traces to Phoenix at the endpoint defined below.
+
+```python
+from opentelemetry import trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import BatchSpanProcessor
+
+from openinference.instrumentation.smolagents import SmolagentsInstrumentor
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
+
+endpoint = "http://0.0.0.0:6006/v1/traces"
+trace_provider = TracerProvider()
+trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))
+
+SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
+```
+Then you can run your agents!
+
+```py
+from smolagents import (
+ CodeAgent,
+ ToolCallingAgent,
+ ManagedAgent,
+ DuckDuckGoSearchTool,
+ VisitWebpageTool,
+ HfApiModel,
+)
+
+model = HfApiModel()
+
+agent = ToolCallingAgent(
+ tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
+ model=model,
+)
+managed_agent = ManagedAgent(
+ agent=agent,
+ name="managed_agent",
+ description="This is an agent that can do web search.",
+)
+manager_agent = CodeAgent(
+ tools=[],
+ model=model,
+ managed_agents=[managed_agent],
+)
+manager_agent.run(
+ "If the US keeps its 2024 growth rate, how many years will it take for the GDP to double?"
+)
+```
+And you can then navigate to `http://0.0.0.0:6006/projects/` to inspect your run!
+
+
+
+You can see that the CodeAgent called its managed ToolCallingAgent (by the way, the managed agent could be have been a CodeAgent as well) to ask it to run the web search for the U.S. 2024 growth rate. Then the managed agent returned its report and the manager agent acted upon it to calculate the economy doubling time! Sweet, isn't it?
\ No newline at end of file
diff --git a/docs/source/en/tutorials/tools.md b/docs/source/en/tutorials/tools.md
index 41556fa33..d9da1e94f 100644
--- a/docs/source/en/tutorials/tools.md
+++ b/docs/source/en/tutorials/tools.md
@@ -131,7 +131,7 @@ And voilà, here's your image! 🏖️
-Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it.
+Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it. This example also shows how you can pass additional arguments to the agent.
```python
from smolagents import CodeAgent, HfApiModel
@@ -140,7 +140,7 @@ model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[image_generation_tool], model=model)
agent.run(
- "Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
+ "Improve this prompt, then generate an image of it.", additional_args={'user_prompt': 'A rabbit wearing a space suit'}
)
```
@@ -204,13 +204,17 @@ agent.run(
### Use a collection of tools
-You can leverage tool collections by using the `ToolCollection` object, with the slug of the collection you want to use.
+You can leverage tool collections by using the `ToolCollection` object. It supports loading either a collection from the Hub or an MCP server tools.
+
+#### Tool Collection from a collection in the Hub
+
+You can leverage it with the slug of the collection you want to use.
Then pass them as a list to initialize your agent, and start using them!
```py
from smolagents import ToolCollection, CodeAgent
-image_tool_collection = ToolCollection(
+image_tool_collection = ToolCollection.from_hub(
collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f",
token=""
)
@@ -220,3 +224,24 @@ agent.run("Please draw me a picture of rivers and lakes.")
```
To speed up the start, tools are loaded only if called by the agent.
+
+#### Tool Collection from any MCP server
+
+Leverage tools from the hundreds of MCP servers available on [glama.ai](https://glama.ai/mcp/servers) or [smithery.ai](https://smithery.ai/).
+
+The MCP servers tools can be loaded in a `ToolCollection` object as follow:
+
+```py
+from smolagents import ToolCollection, CodeAgent
+from mcp import StdioServerParameters
+
+server_parameters = StdioServerParameters(
+ command="uv",
+ args=["--quiet", "pubmedmcp@0.1.3"],
+ env={"UV_PYTHON": "3.12", **os.environ},
+)
+
+with ToolCollection.from_mcp(server_parameters) as tool_collection:
+ agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True)
+ agent.run("Please find a remedy for hangover.")
+```
\ No newline at end of file
diff --git a/docs/source/zh/examples/multiagents.md b/docs/source/zh/examples/multiagents.md
index 4ea4e51b2..67eed890e 100644
--- a/docs/source/zh/examples/multiagents.md
+++ b/docs/source/zh/examples/multiagents.md
@@ -13,13 +13,13 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->
-# Orchestrate a multi-agent system 🤖🤝🤖
+# 编排 multi-agent 系统 🤖🤝🤖
[[open-in-colab]]
-In this notebook we will make a **multi-agent web browser: an agentic system with several agents collaborating to solve problems using the web!**
+此notebook将构建一个 **multi-agent 网络浏览器:一个有多个代理协作,使用网络进行搜索解决问题的代理系统**
-It will be a simple hierarchy, using a `ManagedAgent` object to wrap the managed web search agent:
+`ManagedAgent` 对象将封装这些管理网络搜索的agent,形成一个简单的层次结构:
```
+----------------+
@@ -38,38 +38,39 @@ It will be a simple hierarchy, using a `ManagedAgent` object to wrap the managed
| Visit webpage tool |
+--------------------------------+
```
-Let's set up this system.
-
-Run the line below to install the required dependencies:
+我们来一起构建这个系统。运行下列代码以安装依赖包:
```
!pip install markdownify duckduckgo-search smolagents --upgrade -q
```
-Let's login in order to call the HF Inference API:
+我们需要登录Hugging Face Hub以调用HF的Inference API:
-```py
-from huggingface_hub import notebook_login
+```
+from huggingface_hub import login
-notebook_login()
+login()
```
-⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using `HfApiModel` class that uses HF's Inference API: the Inference API allows to quickly and easily run any OS model.
+⚡️ HF的Inference API 可以快速轻松地运行任何开源模型,因此我们的agent将使用HF的Inference API
+中的`HfApiModel`类来调用
+[Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct)模型。
-_Note:_ The Inference API hosts models based on various criteria, and deployed models may be updated or replaced without prior notice. Learn more about it [here](https://huggingface.co/docs/api-inference/supported-models).
+_Note:_ 基于多参数和部署模型的 Inference API 可能在没有预先通知的情况下更新或替换模型。了解更多信息,请参阅[这里](https://huggingface.co/docs/api-inference/supported-models)。
```py
model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
```
-## 🔍 Create a web search tool
-
-For web browsing, we can already use our pre-existing [`DuckDuckGoSearchTool`](https://github.com/huggingface/smolagents/blob/main/src/smolagents/default_tools/search.py) tool to provide a Google search equivalent.
+## 🔍 创建网络搜索工具
-But then we will also need to be able to peak into the page found by the `DuckDuckGoSearchTool`.
-To do so, we could import the library's built-in `VisitWebpageTool`, but we will build it again to see how it's done.
+虽然我们可以使用已经存在的
+[`DuckDuckGoSearchTool`](https://github.com/huggingface/smolagents/blob/main/src/smolagents/default_tools.py#L151-L176)
+工具作为谷歌搜索的平替进行网页浏览,然后我们也需要能够查看`DuckDuckGoSearchTool`找到的页面。为此,我
+们可以直接导入库的内置
+`VisitWebpageTool`。但是我们将重新构建它以了解其工作原理。
-So let's create our `VisitWebpageTool` tool from scratch using `markdownify`.
+我们将使用`markdownify` 来从头构建我们的`VisitWebpageTool`工具。
```py
import re
@@ -108,19 +109,19 @@ def visit_webpage(url: str) -> str:
return f"An unexpected error occurred: {str(e)}"
```
-Ok, now let's initialize and test our tool!
+现在我们初始化这个工具并测试它!
```py
print(visit_webpage("https://en.wikipedia.org/wiki/Hugging_Face")[:500])
```
-## Build our multi-agent system 🤖🤝🤖
+## 构建我们的 multi-agent 系统 🤖🤝🤖
-Now that we have all the tools `search` and `visit_webpage`, we can use them to create the web agent.
+现在我们有了所有工具`search`和`visit_webpage`,我们可以使用它们来创建web agent。
-Which configuration to choose for this agent?
-- Web browsing is a single-timeline task that does not require parallel tool calls, so JSON tool calling works well for that. We thus choose a `JsonAgent`.
-- Also, since sometimes web search requires exploring many pages before finding the correct answer, we prefer to increase the number of `max_steps` to 10.
+我们该选取什么样的配置来构建这个agent呢?
+- 网页浏览是一个单线程任务,不需要并行工具调用,因此JSON工具调用对于这个任务非常有效。因此我们选择`JsonAgent`。
+- 有时候网页搜索需要探索许多页面才能找到正确答案,所以我们更喜欢将 `max_steps` 增加到10。
```py
from smolagents import (
@@ -141,7 +142,7 @@ web_agent = ToolCallingAgent(
)
```
-We then wrap this agent into a `ManagedAgent` that will make it callable by its manager agent.
+然后我们将这个agent封装到一个`ManagedAgent`中,使其可以被其管理的agent调用。
```py
managed_web_agent = ManagedAgent(
@@ -151,11 +152,7 @@ managed_web_agent = ManagedAgent(
)
```
-Finally we create a manager agent, and upon initialization we pass our managed agent to it in its `managed_agents` argument.
-
-Since this agent is the one tasked with the planning and thinking, advanced reasoning will be beneficial, so a `CodeAgent` will be the best choice.
-
-Also, we want to ask a question that involves the current year and does additional data calculations: so let us add `additional_authorized_imports=["time", "numpy", "pandas"]`, just in case the agent needs these packages.
+最后,我们创建一个manager agent,在初始化时将我们的managed agent传递给它的`managed_agents`参数。因为这个agent负责计划和思考,所以高级推理将是有益的,因此`CodeAgent`将是最佳选择。此外,我们想要问一个涉及当前年份的问题,并进行额外的数据计算:因此让我们添加`additional_authorized_imports=["time", "numpy", "pandas"]`,以防agent需要这些包。
```py
manager_agent = CodeAgent(
@@ -166,34 +163,32 @@ manager_agent = CodeAgent(
)
```
-That's all! Now let's run our system! We select a question that requires both some calculation and research:
+可以了!现在让我们运行我们的系统!我们选择一个需要一些计算和研究的问题:
```py
answer = manager_agent.run("If LLM training continues to scale up at the current rhythm until 2030, what would be the electric power in GW required to power the biggest training runs by 2030? What would that correspond to, compared to some countries? Please provide a source for any numbers used.")
```
-We get this report as the answer:
+我们用这个report 来回答这个问题:
```
-Based on current growth projections and energy consumption estimates, if LLM trainings continue to scale up at the
+Based on current growth projections and energy consumption estimates, if LLM trainings continue to scale up at the
current rhythm until 2030:
-1. The electric power required to power the biggest training runs by 2030 would be approximately 303.74 GW, which
+1. The electric power required to power the biggest training runs by 2030 would be approximately 303.74 GW, which
translates to about 2,660,762 GWh/year.
-2. Comparing this to countries' electricity consumption:
+1. Comparing this to countries' electricity consumption:
- It would be equivalent to about 34% of China's total electricity consumption.
- It would exceed the total electricity consumption of India (184%), Russia (267%), and Japan (291%).
- It would be nearly 9 times the electricity consumption of countries like Italy or Mexico.
-3. Source of numbers:
+2. Source of numbers:
- The initial estimate of 5 GW for future LLM training comes from AWS CEO Matt Garman.
- The growth projection used a CAGR of 79.80% from market research by Springs.
- - Country electricity consumption data is from the U.S. Energy Information Administration, primarily for the year
+ - Country electricity consumption data is from the U.S. Energy Information Administration, primarily for the year
2021.
```
-Seems like we'll need some sizeable powerplants if the [scaling hypothesis](https://gwern.net/scaling-hypothesis) continues to hold true.
-
-Our agents managed to efficiently collaborate towards solving the task! ✅
+如果[scaling hypothesis](https://gwern.net/scaling-hypothesis)持续成立的话,我们需要一些庞大的动力配置。我们的agent成功地协作解决了这个任务!✅
-💡 You can easily extend this orchestration to more agents: one does the code execution, one the web search, one handles file loadings...
\ No newline at end of file
+💡 你可以轻松地将这个编排扩展到更多的agent:一个执行代码,一个进行网页搜索,一个处理文件加载⋯⋯
diff --git a/docs/source/zh/guided_tour.md b/docs/source/zh/guided_tour.md
index 07988fee0..9816a4fa3 100644
--- a/docs/source/zh/guided_tour.md
+++ b/docs/source/zh/guided_tour.md
@@ -173,9 +173,9 @@ Transformers 附带了一个用于增强 agent 的默认工具箱,您可以在
您可以通过调用 [`load_tool`] 函数和要执行的任务手动使用工具。
```python
-from smolagents import load_tool
+from smolagents import DuckDuckGoSearchTool
-search_tool = load_tool("web_search")
+search_tool = DuckDuckGoSearchTool()
print(search_tool("Who's the current president of Russia?"))
```
diff --git a/docs/source/zh/reference/agents.md b/docs/source/zh/reference/agents.md
index 9cdca7d0b..dc011d37e 100644
--- a/docs/source/zh/reference/agents.md
+++ b/docs/source/zh/reference/agents.md
@@ -136,8 +136,8 @@ messages = [
{"role": "user", "content": "No need to help, take it easy."},
]
-model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2)
-print(model(messages, max_tokens=10))
+model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest", temperature=0.2, max_tokens=10)
+print(model(messages))
```
[[autodoc]] LiteLLMModel
\ No newline at end of file
diff --git a/docs/source/zh/tutorials/tools.md b/docs/source/zh/tutorials/tools.md
index 216d93b96..a5d15eb36 100644
--- a/docs/source/zh/tutorials/tools.md
+++ b/docs/source/zh/tutorials/tools.md
@@ -186,7 +186,7 @@ from smolagents import HfApiModel
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
-agent.tools.append(model_download_tool)
+agent.tools[model_download_tool.name] = model_download_tool
```
现在我们可以利用新工具:
@@ -209,7 +209,7 @@ agent.run(
```py
from smolagents import ToolCollection, CodeAgent
-image_tool_collection = ToolCollection(
+image_tool_collection = ToolCollection.from_hub(
collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f",
token=""
)
diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb
index 8b49b0aa2..e0b59d591 100644
--- a/examples/benchmark.ipynb
+++ b/examples/benchmark.ipynb
@@ -181,7 +181,7 @@
"import datasets\n",
"import pandas as pd\n",
"\n",
- "eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"train\"]\n",
+ "eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"test\"]\n",
"pd.DataFrame(eval_ds)"
]
},
@@ -253,11 +253,14 @@
"\n",
" if is_vanilla_llm:\n",
" llm = agent\n",
- " answer = llm([{\"role\": \"user\", \"content\": question}])\n",
- " token_count = llm.last_input_token_count + llm.last_output_token_count\n",
- " intermediate_steps = []\n",
+ " answer = str(llm([{\"role\": \"user\", \"content\": question}]).content)\n",
+ " token_count = {\n",
+ " \"input\": llm.last_input_token_count,\n",
+ " \"output\": llm.last_output_token_count,\n",
+ " }\n",
+ " intermediate_steps = str([])\n",
" else:\n",
- " answer = agent.run(question)\n",
+ " answer = str(agent.run(question))\n",
" token_count = agent.monitor.get_total_token_counts()\n",
" intermediate_steps = str(agent.logs)\n",
" # Remove memory from logs to make them more compact.\n",
@@ -983,7 +986,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -1043,8 +1046,8 @@
"\n",
"\n",
"# Usage (after running your previous data processing code):\n",
- "mathjax_table = create_mathjax_table(pivot_df, formatted_df)\n",
- "print(mathjax_table)"
+ "# mathjax_table = create_mathjax_table(pivot_df, formatted_df)\n",
+ "# print(mathjax_table)"
]
}
],
diff --git a/examples/e2b_example.py b/examples/e2b_example.py
index 049dc159a..843e14406 100644
--- a/examples/e2b_example.py
+++ b/examples/e2b_example.py
@@ -4,8 +4,9 @@
load_dotenv()
+
class GetCatImageTool(Tool):
- name="get_cat_image"
+ name = "get_cat_image"
description = "Get a cat image"
inputs = {}
output_type = "image"
@@ -27,17 +28,22 @@ def forward(self):
get_cat_image = GetCatImageTool()
agent = CodeAgent(
- tools = [get_cat_image, VisitWebpageTool()],
+ tools=[get_cat_image, VisitWebpageTool()],
model=HfApiModel(),
- additional_authorized_imports=["Pillow", "requests", "markdownify"], # "duckduckgo-search",
- use_e2b_executor=True
+ additional_authorized_imports=[
+ "Pillow",
+ "requests",
+ "markdownify",
+ ], # "duckduckgo-search",
+ use_e2b_executor=True,
)
agent.run(
- "Return me an image of a cat. Directly use the image provided in your state.", additional_args={"cat_image":get_cat_image()}
-) # Asking to directly return the image from state tests that additional_args are properly sent to server.
+ "Return me an image of a cat. Directly use the image provided in your state.",
+ additional_args={"cat_image": get_cat_image()},
+) # Asking to directly return the image from state tests that additional_args are properly sent to server.
# Try the agent in a Gradio UI
from smolagents import GradioUI
-GradioUI(agent).launch()
\ No newline at end of file
+GradioUI(agent).launch()
diff --git a/examples/gradio_upload.py b/examples/gradio_upload.py
index 4b8425d83..061d22692 100644
--- a/examples/gradio_upload.py
+++ b/examples/gradio_upload.py
@@ -1,11 +1,5 @@
-from smolagents import (
- CodeAgent,
- HfApiModel,
- GradioUI
-)
+from smolagents import CodeAgent, HfApiModel, GradioUI
-agent = CodeAgent(
- tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
-)
+agent = CodeAgent(tools=[], model=HfApiModel(), max_steps=4, verbosity_level=1)
-GradioUI(agent, file_upload_folder='./data').launch()
+GradioUI(agent, file_upload_folder="./data").launch()
diff --git a/examples/inspect_runs.py b/examples/inspect_runs.py
new file mode 100644
index 000000000..3e24efaca
--- /dev/null
+++ b/examples/inspect_runs.py
@@ -0,0 +1,44 @@
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+
+from openinference.instrumentation.smolagents import SmolagentsInstrumentor
+
+from smolagents import (
+ CodeAgent,
+ DuckDuckGoSearchTool,
+ VisitWebpageTool,
+ ManagedAgent,
+ ToolCallingAgent,
+ HfApiModel,
+)
+
+# Let's setup the instrumentation first
+
+trace_provider = TracerProvider()
+trace_provider.add_span_processor(
+ SimpleSpanProcessor(OTLPSpanExporter("http://0.0.0.0:6006/v1/traces"))
+)
+
+SmolagentsInstrumentor().instrument(tracer_provider=trace_provider, skip_dep_check=True)
+
+# Then we run the agentic part!
+model = HfApiModel()
+
+agent = ToolCallingAgent(
+ tools=[DuckDuckGoSearchTool(), VisitWebpageTool()],
+ model=model,
+)
+managed_agent = ManagedAgent(
+ agent=agent,
+ name="managed_agent",
+ description="This is an agent that can do web search.",
+)
+manager_agent = CodeAgent(
+ tools=[],
+ model=model,
+ managed_agents=[managed_agent],
+)
+manager_agent.run(
+ "If the US keeps it 2024 growth rate, how many years would it take for the GDP to double?"
+)
diff --git a/examples/rag.py b/examples/rag.py
index 4096d57f0..83a201d7e 100644
--- a/examples/rag.py
+++ b/examples/rag.py
@@ -8,7 +8,9 @@
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
-knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
+knowledge_base = knowledge_base.filter(
+ lambda row: row["source"].startswith("huggingface/transformers")
+)
source_docs = [
Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
@@ -26,6 +28,7 @@
from smolagents import Tool
+
class RetrieverTool(Tool):
name = "retriever"
description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
@@ -39,9 +42,7 @@ class RetrieverTool(Tool):
def __init__(self, docs, **kwargs):
super().__init__(**kwargs)
- self.retriever = BM25Retriever.from_documents(
- docs, k=10
- )
+ self.retriever = BM25Retriever.from_documents(docs, k=10)
def forward(self, query: str) -> str:
assert isinstance(query, str), "Your search query must be a string"
@@ -56,14 +57,20 @@ def forward(self, query: str) -> str:
]
)
+
from smolagents import HfApiModel, CodeAgent
retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
- tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_steps=4, verbosity_level=2
+ tools=[retriever_tool],
+ model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"),
+ max_steps=4,
+ verbosity_level=2,
)
-agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
+agent_output = agent.run(
+ "For a transformers model training, which is slower, the forward or the backward pass?"
+)
print("Final output:")
print(agent_output)
diff --git a/examples/text_to_sql.py b/examples/text_to_sql.py
index 9a8b2014f..60b84f651 100644
--- a/examples/text_to_sql.py
+++ b/examples/text_to_sql.py
@@ -40,11 +40,14 @@
inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
-table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
+table_description = "Columns:\n" + "\n".join(
+ [f" - {name}: {col_type}" for name, col_type in columns_info]
+)
print(table_description)
from smolagents import tool
+
@tool
def sql_engine(query: str) -> str:
"""
@@ -66,10 +69,11 @@ def sql_engine(query: str) -> str:
output += "\n" + str(row)
return output
+
from smolagents import CodeAgent, HfApiModel
agent = CodeAgent(
tools=[sql_engine],
model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
)
-agent.run("Can you give me the name of the client who got the most expensive receipt?")
\ No newline at end of file
+agent.run("Can you give me the name of the client who got the most expensive receipt?")
diff --git a/examples/tool_calling_agent_from_any_llm.py b/examples/tool_calling_agent_from_any_llm.py
index 68155fe9a..05daaa50e 100644
--- a/examples/tool_calling_agent_from_any_llm.py
+++ b/examples/tool_calling_agent_from_any_llm.py
@@ -9,6 +9,7 @@
# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
model = LiteLLMModel(model_id="gpt-4o")
+
@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
@@ -21,6 +22,7 @@ def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
agent = ToolCallingAgent(tools=[get_weather], model=model)
-print(agent.run("What's the weather like in Paris?"))
\ No newline at end of file
+print(agent.run("What's the weather like in Paris?"))
diff --git a/examples/tool_calling_agent_mcp.py b/examples/tool_calling_agent_mcp.py
new file mode 100644
index 000000000..c0e613a1e
--- /dev/null
+++ b/examples/tool_calling_agent_mcp.py
@@ -0,0 +1,27 @@
+"""An example of loading a ToolCollection directly from an MCP server.
+
+Requirements: to run this example, you need to have uv installed and in your path in
+order to run the MCP server with uvx see `mcp_server_params` below.
+
+Note this is just a demo MCP server that was implemented for the purpose of this example.
+It only provide a single tool to search amongst pubmed papers abstracts.
+
+Usage:
+>>> uv run examples/tool_calling_agent_mcp.py
+"""
+
+import os
+
+from mcp import StdioServerParameters
+from smolagents import CodeAgent, HfApiModel, ToolCollection
+
+mcp_server_params = StdioServerParameters(
+ command="uvx",
+ args=["--quiet", "pubmedmcp@0.1.3"],
+ env={"UV_PYTHON": "3.12", **os.environ},
+)
+
+with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
+ # print(tool_collection.tools[0](request={"term": "efficient treatment hangover"}))
+ agent = CodeAgent(tools=tool_collection.tools, model=HfApiModel(), max_steps=4)
+ agent.run("Find me one risk associated with drinking alcohol regularly on low doses for humans.")
diff --git a/examples/tool_calling_agent_ollama.py b/examples/tool_calling_agent_ollama.py
index ad914f84e..c7198d68d 100644
--- a/examples/tool_calling_agent_ollama.py
+++ b/examples/tool_calling_agent_ollama.py
@@ -4,10 +4,11 @@
model = LiteLLMModel(
model_id="ollama_chat/llama3.2",
- api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
- api_key="your-api-key" # replace with API key if necessary
+ api_base="http://localhost:11434", # replace with remote open-ai compatible server if necessary
+ api_key="your-api-key", # replace with API key if necessary
)
+
@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
@@ -20,6 +21,7 @@ def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"""
return "The weather is UNGODLY with torrential rains and temperatures below -10°C"
+
agent = ToolCallingAgent(tools=[get_weather], model=model)
print(agent.run("What's the weather like in Paris?"))
diff --git a/pyproject.toml b/pyproject.toml
index 1fc22662f..e3ff96df5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "smolagents"
-version = "1.3.0.dev"
+version = "1.4.0.dev"
description = "🤗 smolagents: a barebones library for agents. Agents write python code to call tools or orchestrate other agents."
authors = [
{ name="Aymeric Roucher", email="aymeric@hf.co" }, { name="Thomas Wolf"},
@@ -23,30 +23,46 @@ dependencies = [
"duckduckgo-search>=6.3.7",
"python-dotenv>=1.0.1",
"e2b-code-interpreter>=1.0.3",
- "openai>=1.58.1",
]
-[tool.ruff]
-lint.ignore = ["F403"]
-
[project.optional-dependencies]
-dev = [
+audio = [
+ "soundfile",
+]
+torch = [
"torch",
- "torchaudio",
- "torchvision",
- "sqlalchemy",
"accelerate",
- "soundfile",
+]
+litellm = [
"litellm>=1.55.10",
]
+mcp = [
+ "mcpadapt>=0.0.6",
+ "mcp",
+]
+openai = [
+ "openai>=1.58.1"
+]
+quality = [
+ "ruff>=0.9.0",
+]
test = [
- "torch",
- "torchaudio",
- "torchvision",
"pytest>=8.1.0",
- "sqlalchemy",
- "ruff>=0.5.0",
- "accelerate",
- "soundfile",
- "litellm>=1.55.10",
+ "smolagents[audio,litellm,mcp,openai,torch]",
+]
+dev = [
+ "smolagents[quality,test]",
+ "sqlalchemy", # for ./examples
+]
+
+[tool.pytest.ini_options]
+# Add the specified `OPTS` to the set of command line arguments as if they had been specified by the user.
+addopts = "-sv --durations=0"
+
+[tool.ruff]
+lint.ignore = ["F403"]
+
+[tool.ruff.lint.per-file-ignores]
+"examples/*" = [
+ "E402", # module-import-not-at-top-of-file
]
diff --git a/src/smolagents/__init__.py b/src/smolagents/__init__.py
index 055fba7fc..f457b7e34 100644
--- a/src/smolagents/__init__.py
+++ b/src/smolagents/__init__.py
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__ = "1.3.0.dev"
+__version__ = "1.4.0.dev"
from typing import TYPE_CHECKING
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index cfa8a6ff4..da791aa43 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -75,12 +75,12 @@ class ToolCall:
id: str
-class AgentStep:
+class AgentStepLog:
pass
@dataclass
-class ActionStep(AgentStep):
+class ActionStep(AgentStepLog):
agent_memory: List[Dict[str, str]] | None = None
tool_calls: List[ToolCall] | None = None
start_time: float | None = None
@@ -94,18 +94,18 @@ class ActionStep(AgentStep):
@dataclass
-class PlanningStep(AgentStep):
+class PlanningStep(AgentStepLog):
plan: str
facts: str
@dataclass
-class TaskStep(AgentStep):
+class TaskStep(AgentStepLog):
task: str
@dataclass
-class SystemPromptStep(AgentStep):
+class SystemPromptStep(AgentStepLog):
system_prompt: str
@@ -891,34 +891,22 @@ def __init__(
additional_authorized_imports: Optional[List[str]] = None,
planning_interval: Optional[int] = None,
use_e2b_executor: bool = False,
+ max_print_outputs_length: Optional[int] = None,
**kwargs,
):
if system_prompt is None:
system_prompt = CODE_SYSTEM_PROMPT
- super().__init__(
- tools=tools,
- model=model,
- system_prompt=system_prompt,
- grammar=grammar,
- planning_interval=planning_interval,
- **kwargs,
- )
+
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
)
- if "{{authorized_imports}}" not in self.system_prompt:
+ if "{{authorized_imports}}" not in system_prompt:
raise AgentError(
"Tag '{{authorized_imports}}' should be provided in the prompt."
)
- self.system_prompt = self.system_prompt.replace(
- "{{authorized_imports}}",
- "You can import from any package you want."
- if "*" in self.authorized_imports
- else str(self.authorized_imports),
- )
if "*" in self.additional_authorized_imports:
self.logger.log(
@@ -926,6 +914,14 @@ def __init__(
0,
)
+ super().__init__(
+ tools=tools,
+ model=model,
+ system_prompt=system_prompt,
+ grammar=grammar,
+ planning_interval=planning_interval,
+ **kwargs,
+ )
if use_e2b_executor and len(self.managed_agents) > 0:
raise Exception(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
@@ -942,8 +938,19 @@ def __init__(
self.python_executor = LocalPythonInterpreter(
self.additional_authorized_imports,
all_tools,
+ max_print_outputs_length=max_print_outputs_length,
)
+ def initialize_system_prompt(self):
+ super().initialize_system_prompt()
+ self.system_prompt = self.system_prompt.replace(
+ "{{authorized_imports}}",
+ "You can import from any package you want."
+ if "*" in self.authorized_imports
+ else str(self.authorized_imports),
+ )
+ return self.system_prompt
+
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""
Perform one step in the ReAct framework: the agent thinks, acts, and observes the result.
diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py
index 59f6820f2..14a46ae24 100644
--- a/src/smolagents/default_tools.py
+++ b/src/smolagents/default_tools.py
@@ -31,7 +31,6 @@
)
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio
-from .utils import truncate_content
if is_torch_available():
from transformers.models.whisper import (
@@ -145,7 +144,7 @@ class UserInputTool(Tool):
output_type = "string"
def forward(self, question):
- user_input = input(f"{question} => ")
+ user_input = input(f"{question} => Type your answer here:")
return user_input
@@ -278,6 +277,7 @@ def forward(self, url: str) -> str:
import requests
from markdownify import markdownify
from requests.exceptions import RequestException
+ from smolagents.utils import truncate_content
except ImportError:
raise ImportError(
"You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
@@ -293,7 +293,7 @@ def forward(self, url: str) -> str:
# Remove multiple line breaks
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
- return truncate_content(markdown_content)
+ return truncate_content(markdown_content, 10000)
except RequestException as e:
return f"Error fetching the webpage: {str(e)}"
diff --git a/src/smolagents/e2b_executor.py b/src/smolagents/e2b_executor.py
index e8cc89347..8a20a9e27 100644
--- a/src/smolagents/e2b_executor.py
+++ b/src/smolagents/e2b_executor.py
@@ -43,7 +43,7 @@ def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
# )
# print("Installation of agents package finished.")
self.logger = logger
- additional_imports = additional_imports + ["pickle5"]
+ additional_imports = additional_imports + ["pickle5", "smolagents"]
if len(additional_imports) > 0:
execution = self.sbx.commands.run(
"pip install " + " ".join(additional_imports)
diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py
index 42a4183dc..e04056dbe 100644
--- a/src/smolagents/gradio_ui.py
+++ b/src/smolagents/gradio_ui.py
@@ -19,11 +19,13 @@
import mimetypes
import re
-from .agents import ActionStep, AgentStep, MultiStepAgent
+from typing import Optional
+
+from .agents import ActionStep, AgentStepLog, MultiStepAgent
from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
-def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
+def pull_messages_from_step(step_log: AgentStepLog, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
@@ -53,11 +55,13 @@ def stream_to_gradio(
task: str,
test_mode: bool = False,
reset_agent_memory: bool = False,
- **kwargs,
+ additional_args: Optional[dict] = None,
):
"""Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
- for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
+ for step_log in agent.run(
+ task, stream=True, reset=reset_agent_memory, additional_args=additional_args
+ ):
for message in pull_messages_from_step(step_log, test_mode=test_mode):
yield message
@@ -116,15 +120,15 @@ def upload_file(
"""
if file is None:
- return "No file uploaded"
+ return gr.Textbox("No file uploaded", visible=True), file_uploads_log
try:
mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e:
- return f"Error: {e}"
+ return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log
if mime_type not in allowed_file_types:
- return "File type disallowed"
+ return gr.Textbox("File type disallowed", visible=True), file_uploads_log
# Sanitize file name
original_name = os.path.basename(file.name)
@@ -155,9 +159,11 @@ def upload_file(
def log_user_message(self, text_input, file_uploads_log):
return (
text_input
- + f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
- if len(file_uploads_log) > 0
- else "",
+ + (
+ f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
+ if len(file_uploads_log) > 0
+ else ""
+ ),
"",
)
@@ -170,12 +176,13 @@ def launch(self):
type="messages",
avatar_images=(
None,
- "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
+ "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
),
+ resizeable=True,
)
# If an upload folder is provided, enable the upload feature
if self.file_upload_folder is not None:
- upload_file = gr.File(label="Upload a file", height=1)
+ upload_file = gr.File(label="Upload a file")
upload_status = gr.Textbox(
label="Upload Status", interactive=False, visible=False
)
diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py
index e54e0594c..0c7b5bc38 100644
--- a/src/smolagents/local_python_executor.py
+++ b/src/smolagents/local_python_executor.py
@@ -21,6 +21,7 @@
import re
from collections.abc import Mapping
from importlib import import_module
+from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
@@ -45,7 +46,7 @@ class InterpreterError(ValueError):
and issubclass(getattr(builtins, name), BaseException)
}
-PRINT_OUTPUTS, MAX_LEN_OUTPUT = "", 50000
+PRINT_OUTPUTS, DEFAULT_MAX_LEN_OUTPUT = "", 50000
OPERATIONS_COUNT, MAX_OPERATIONS = 0, 10000000
@@ -591,7 +592,11 @@ def evaluate_call(
custom_tools: Dict[str, Callable],
authorized_imports: List[str],
) -> Any:
- if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
+ if not (
+ isinstance(call.func, ast.Attribute)
+ or isinstance(call.func, ast.Name)
+ or isinstance(call.func, ast.Subscript)
+ ):
raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute):
obj = evaluate_ast(
@@ -617,6 +622,23 @@ def evaluate_call(
f"It is not permitted to evaluate other functions than the provided tools or functions defined in previous code (tried to execute {call.func.id})."
)
+ elif isinstance(call.func, ast.Subscript):
+ value = evaluate_ast(
+ call.func.value, state, static_tools, custom_tools, authorized_imports
+ )
+ index = evaluate_ast(
+ call.func.slice, state, static_tools, custom_tools, authorized_imports
+ )
+ if isinstance(value, (list, tuple)):
+ func = value[index]
+ else:
+ raise InterpreterError(
+ f"Cannot subscript object of type {type(value).__name__}"
+ )
+
+ if not callable(func):
+ raise InterpreterError(f"This is not a correct function: {call.func}).")
+ func_name = None
args = []
for arg in call.args:
if isinstance(arg, ast.Starred):
@@ -726,6 +748,8 @@ def evaluate_name(
return state[name.id]
elif name.id in static_tools:
return static_tools[name.id]
+ elif name.id in custom_tools:
+ return custom_tools[name.id]
elif name.id in ERRORS:
return ERRORS[name.id]
close_matches = difflib.get_close_matches(name.id, list(state.keys()))
@@ -1023,12 +1047,75 @@ def evaluate_with(
context.__exit__(None, None, None)
+def get_safe_module(unsafe_module, dangerous_patterns, visited=None):
+ """Creates a safe copy of a module or returns the original if it's a function"""
+ # If it's a function or non-module object, return it directly
+ if not isinstance(unsafe_module, ModuleType):
+ return unsafe_module
+
+ # Handle circular references: Initialize visited set for the first call
+ if visited is None:
+ visited = set()
+
+ module_id = id(unsafe_module)
+ if module_id in visited:
+ return unsafe_module # Return original for circular refs
+
+ visited.add(module_id)
+
+ # Create new module for actual modules
+ safe_module = ModuleType(unsafe_module.__name__)
+
+ # Copy all attributes by reference, recursively checking modules
+ for attr_name in dir(unsafe_module):
+ # Skip dangerous patterns at any level
+ if any(
+ pattern in f"{unsafe_module.__name__}.{attr_name}"
+ for pattern in dangerous_patterns
+ ):
+ continue
+
+ attr_value = getattr(unsafe_module, attr_name)
+
+ # Recursively process nested modules, passing visited set
+ if isinstance(attr_value, ModuleType):
+ attr_value = get_safe_module(
+ attr_value, dangerous_patterns, visited=visited
+ )
+
+ setattr(safe_module, attr_name, attr_value)
+
+ return safe_module
+
+
def import_modules(expression, state, authorized_imports):
+ dangerous_patterns = (
+ "_os",
+ "os",
+ "subprocess",
+ "_subprocess",
+ "pty",
+ "system",
+ "popen",
+ "spawn",
+ "shutil",
+ "sys",
+ "pathlib",
+ "io",
+ "socket",
+ "compile",
+ "eval",
+ "exec",
+ "multiprocessing",
+ )
+
def check_module_authorized(module_name):
if "*" in authorized_imports:
return True
else:
module_path = module_name.split(".")
+ if any([module in dangerous_patterns for module in module_path]):
+ return False
module_subpaths = [
".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
]
@@ -1037,8 +1124,10 @@ def check_module_authorized(module_name):
if isinstance(expression, ast.Import):
for alias in expression.names:
if check_module_authorized(alias.name):
- module = import_module(alias.name)
- state[alias.asname or alias.name] = module
+ raw_module = import_module(alias.name)
+ state[alias.asname or alias.name] = get_safe_module(
+ raw_module, dangerous_patterns
+ )
else:
raise InterpreterError(
f"Import of {alias.name} is not allowed. Authorized imports are: {str(authorized_imports)}"
@@ -1046,11 +1135,13 @@ def check_module_authorized(module_name):
return None
elif isinstance(expression, ast.ImportFrom):
if check_module_authorized(expression.module):
- module = __import__(
+ raw_module = __import__(
expression.module, fromlist=[alias.name for alias in expression.names]
)
for alias in expression.names:
- state[alias.asname or alias.name] = getattr(module, alias.name)
+ state[alias.asname or alias.name] = get_safe_module(
+ getattr(raw_module, alias.name), dangerous_patterns
+ )
else:
raise InterpreterError(f"Import from {expression.module} is not allowed.")
return None
@@ -1344,15 +1435,6 @@ def evaluate_ast(
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
-def truncate_print_outputs(
- print_outputs: str, max_len_outputs: int = MAX_LEN_OUTPUT
-) -> str:
- if len(print_outputs) < max_len_outputs:
- return print_outputs
- else:
- return f"Print outputs:\n{print_outputs[:max_len_outputs]}\n_Print outputs have been truncated over the limit of {max_len_outputs} characters._\n"
-
-
class FinalAnswerException(Exception):
def __init__(self, value):
self.value = value
@@ -1364,6 +1446,7 @@ def evaluate_python_code(
custom_tools: Optional[Dict[str, Callable]] = None,
state: Optional[Dict[str, Any]] = None,
authorized_imports: List[str] = BASE_BUILTIN_MODULES,
+ max_print_outputs_length: int = DEFAULT_MAX_LEN_OUTPUT,
):
"""
Evaluate a python expression using the content of the variables stored in a state and only evaluating a given set
@@ -1409,26 +1492,34 @@ def final_answer(value):
node, state, static_tools, custom_tools, authorized_imports
)
state["print_outputs"] = truncate_content(
- PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
+ PRINT_OUTPUTS, max_length=max_print_outputs_length
)
is_final_answer = False
return result, is_final_answer
except FinalAnswerException as e:
state["print_outputs"] = truncate_content(
- PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT
+ PRINT_OUTPUTS, max_length=max_print_outputs_length
)
is_final_answer = True
return e.value, is_final_answer
except InterpreterError as e:
- msg = truncate_content(PRINT_OUTPUTS, max_length=MAX_LEN_OUTPUT)
+ msg = truncate_content(PRINT_OUTPUTS, max_length=max_print_outputs_length)
msg += f"Code execution failed at line '{ast.get_source_segment(code, node)}' because of the following error:\n{e}"
raise InterpreterError(msg)
class LocalPythonInterpreter:
- def __init__(self, additional_authorized_imports: List[str], tools: Dict):
+ def __init__(
+ self,
+ additional_authorized_imports: List[str],
+ tools: Dict,
+ max_print_outputs_length: Optional[int] = None,
+ ):
self.custom_tools = {}
self.state = {}
+ self.max_print_outputs_length = max_print_outputs_length
+ if max_print_outputs_length is None:
+ self.max_print_outputs_length = DEFAULT_MAX_LEN_OUTPUT
self.additional_authorized_imports = additional_authorized_imports
self.authorized_imports = list(
set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
@@ -1450,6 +1541,7 @@ def __call__(
custom_tools=self.custom_tools,
state=self.state,
authorized_imports=self.authorized_imports,
+ max_print_outputs_length=self.max_print_outputs_length,
)
logs = self.state["print_outputs"]
return output, logs, is_final_answer
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index f25ced9c6..ca234f2a3 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from dataclasses import dataclass
+from dataclasses import dataclass, asdict
import json
import logging
import os
@@ -32,9 +32,6 @@
StoppingCriteriaList,
is_torch_available,
)
-from transformers.utils.import_utils import _is_package_available
-
-import openai
from .tools import Tool
@@ -50,8 +47,14 @@
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```",
}
-if _is_package_available("litellm"):
- import litellm
+
+def get_dict_from_nested_dataclasses(obj):
+ def convert(obj):
+ if hasattr(obj, "__dataclass_fields__"):
+ return {k: convert(v) for k, v in asdict(obj).items()}
+ return obj
+
+ return convert(obj)
@dataclass
@@ -60,6 +63,14 @@ class ChatMessageToolCallDefinition:
name: str
description: Optional[str] = None
+ @classmethod
+ def from_hf_api(cls, tool_call_definition) -> "ChatMessageToolCallDefinition":
+ return cls(
+ arguments=tool_call_definition.arguments,
+ name=tool_call_definition.name,
+ description=tool_call_definition.description,
+ )
+
@dataclass
class ChatMessageToolCall:
@@ -67,6 +78,14 @@ class ChatMessageToolCall:
id: str
type: str
+ @classmethod
+ def from_hf_api(cls, tool_call) -> "ChatMessageToolCall":
+ return cls(
+ function=ChatMessageToolCallDefinition.from_hf_api(tool_call.function),
+ id=tool_call.id,
+ type=tool_call.type,
+ )
+
@dataclass
class ChatMessage:
@@ -74,6 +93,19 @@ class ChatMessage:
content: Optional[str] = None
tool_calls: Optional[List[ChatMessageToolCall]] = None
+ def model_dump_json(self):
+ return json.dumps(get_dict_from_nested_dataclasses(self))
+
+ @classmethod
+ def from_hf_api(cls, message) -> "ChatMessage":
+ tool_calls = None
+ if getattr(message, "tool_calls", None) is not None:
+ tool_calls = [
+ ChatMessageToolCall.from_hf_api(tool_call)
+ for tool_call in message.tool_calls
+ ]
+ return cls(role=message.role, content=message.content, tool_calls=tool_calls)
+
class MessageRole(str, Enum):
USER = "user"
@@ -184,7 +216,6 @@ def __call__(
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
- max_tokens: int = 1500,
) -> ChatMessage:
"""Process the input messages and return the model's response.
@@ -195,8 +226,6 @@ def __call__(
A list of strings that will stop the generation if encountered in the model's output.
grammar (`str`, *optional*):
The grammar or formatting structure to use in the model's response.
- max_tokens (`int`, *optional*):
- The maximum count of tokens to generate.
Returns:
`str`: The text content of the model's response.
"""
@@ -206,7 +235,7 @@ def __call__(
class HfApiModel(Model):
"""A class to interact with Hugging Face's Inference API for language model interaction.
- This engine allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
+ This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
Parameters:
model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
@@ -227,9 +256,10 @@ class HfApiModel(Model):
>>> engine = HfApiModel(
... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
... token="your_hf_token_here",
+ ... max_tokens=5000,
... )
>>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
- >>> response = engine(messages, stop_sequences=["END"], max_tokens=1500)
+ >>> response = engine(messages, stop_sequences=["END"])
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."
```
@@ -241,6 +271,7 @@ def __init__(
token: Optional[str] = None,
timeout: Optional[int] = 120,
temperature: float = 0.5,
+ **kwargs,
):
super().__init__()
self.model_id = model_id
@@ -248,13 +279,13 @@ def __init__(
token = os.getenv("HF_TOKEN")
self.client = InferenceClient(self.model_id, token=token, timeout=timeout)
self.temperature = temperature
+ self.kwargs = kwargs
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
- max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
"""
@@ -270,33 +301,64 @@ def __call__(
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="auto",
stop=stop_sequences,
- max_tokens=max_tokens,
temperature=self.temperature,
+ **self.kwargs,
)
else:
response = self.client.chat.completions.create(
model=self.model_id,
messages=messages,
stop=stop_sequences,
- max_tokens=max_tokens,
temperature=self.temperature,
+ **self.kwargs,
)
self.last_input_token_count = response.usage.prompt_tokens
self.last_output_token_count = response.usage.completion_tokens
- return response.choices[0].message
+ return ChatMessage.from_hf_api(response.choices[0].message)
class TransformersModel(Model):
- """This engine initializes a model and tokenizer from the given `model_id`.
+ """A class to interact with Hugging Face's Inference API for language model interaction.
+
+ This model allows you to communicate with Hugging Face's models using the Inference API. It can be used in both serverless mode or with a dedicated endpoint, supporting features like stop sequences and grammar customization.
Parameters:
- model_id (`str`, *optional*, defaults to `"HuggingFaceTB/SmolLM2-1.7B-Instruct"`):
+ model_id (`str`, *optional*, defaults to `"Qwen/Qwen2.5-Coder-32B-Instruct"`):
The Hugging Face model ID to be used for inference. This can be a path or model identifier from the Hugging Face model hub.
- device (`str`, optional, defaults to `"cuda"` if available, else `"cpu"`.):
- The device to load the model on (`"cpu"` or `"cuda"`).
+ device_map (`str`, *optional*):
+ The device_map to initialize your model with.
+ torch_dtype (`str`, *optional*):
+ The torch_dtype to initialize your model with.
+ trust_remote_code (bool):
+ Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
+ kwargs (dict, *optional*):
+ Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
+ Raises:
+ ValueError:
+ If the model name is not provided.
+
+ Example:
+ ```python
+ >>> engine = TransformersModel(
+ ... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+ ... device="cuda",
+ ... max_new_tokens=5000,
+ ... )
+ >>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
+ >>> response = engine(messages, stop_sequences=["END"])
+ >>> print(response)
+ "Quantum mechanics is the branch of physics that studies..."
+ ```
"""
- def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
+ def __init__(
+ self,
+ model_id: Optional[str] = None,
+ device_map: Optional[str] = None,
+ torch_dtype: Optional[str] = None,
+ trust_remote_code: bool = False,
+ **kwargs,
+ ):
super().__init__()
if not is_torch_available():
raise ImportError("Please install torch in order to use TransformersModel.")
@@ -309,20 +371,27 @@ def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None)
f"`model_id`not provided, using this default tokenizer for token counts: '{model_id}'"
)
self.model_id = model_id
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- self.device = device
- logger.info(f"Using device: {self.device}")
+ self.kwargs = kwargs
+ if device_map is None:
+ device_map = "cuda" if torch.cuda.is_available() else "cpu"
+ logger.info(f"Using device: {device_map}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
- self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ device_map=device_map,
+ torch_dtype=torch_dtype,
+ trust_remote_code=trust_remote_code,
+ )
except Exception as e:
logger.warning(
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
)
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
- self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_id, device_map=device_map, torch_dtype=torch_dtype
+ )
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):
@@ -355,7 +424,6 @@ def __call__(
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
- max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
@@ -380,10 +448,10 @@ def __call__(
out = self.model.generate(
**prompt_tensor,
- max_new_tokens=max_tokens,
stopping_criteria=(
self.make_stopping_criteria(stop_sequences) if stop_sequences else None
),
+ **self.kwargs,
)
generated_tokens = out[0, count_prompt_tokens:]
output = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
@@ -416,6 +484,19 @@ def __call__(
class LiteLLMModel(Model):
+ """This model connects to [LiteLLM](https://www.litellm.ai/) as a gateway to hundreds of LLMs.
+
+ Parameters:
+ model_id (`str`):
+ The model identifier to use on the server (e.g. "gpt-3.5-turbo").
+ api_base (`str`):
+ The base URL of the OpenAI-compatible API server.
+ api_key (`str`):
+ The API key to use for authentication.
+ **kwargs:
+ Additional keyword arguments to pass to the OpenAI API.
+ """
+
def __init__(
self,
model_id="anthropic/claude-3-5-sonnet-20240620",
@@ -423,9 +504,11 @@ def __init__(
api_key=None,
**kwargs,
):
- if not _is_package_available("litellm"):
- raise ImportError(
- "litellm not found. Install it with `pip install litellm`"
+ try:
+ import litellm
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "Please install 'litellm' extra to use LiteLLMModel: `pip install 'smolagents[litellm]'`"
)
super().__init__()
self.model_id = model_id
@@ -440,12 +523,13 @@ def __call__(
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
- max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
+ import litellm
+
if tools_to_call_from:
response = litellm.completion(
model=self.model_id,
@@ -453,7 +537,6 @@ def __call__(
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="required",
stop=stop_sequences,
- max_tokens=max_tokens,
api_base=self.api_base,
api_key=self.api_key,
**self.kwargs,
@@ -463,7 +546,6 @@ def __call__(
model=self.model_id,
messages=messages,
stop=stop_sequences,
- max_tokens=max_tokens,
api_base=self.api_base,
api_key=self.api_key,
**self.kwargs,
@@ -474,17 +556,18 @@ def __call__(
class OpenAIServerModel(Model):
- """This engine connects to an OpenAI-compatible API server.
+ """This model connects to an OpenAI-compatible API server.
Parameters:
model_id (`str`):
The model identifier to use on the server (e.g. "gpt-3.5-turbo").
- api_base (`str`):
+ api_base (`str`, *optional*):
The base URL of the OpenAI-compatible API server.
- api_key (`str`):
+ api_key (`str`, *optional*):
The API key to use for authentication.
- temperature (`float`, *optional*, defaults to 0.7):
- Controls randomness in the model's responses. Values between 0 and 2.
+ custom_role_conversions (`Dict{str, str]`, *optional*):
+ Custom role conversion mapping to convert message roles in others.
+ Useful for specific models that do not support specific message roles like "system".
**kwargs:
Additional keyword arguments to pass to the OpenAI API.
"""
@@ -492,30 +575,40 @@ class OpenAIServerModel(Model):
def __init__(
self,
model_id: str,
- api_base: str,
- api_key: str,
- temperature: float = 0.7,
+ api_base: Optional[str] = None,
+ api_key: Optional[str] = None,
+ custom_role_conversions: Optional[Dict[str, str]] = None,
**kwargs,
):
+ try:
+ import openai
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(
+ "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`"
+ ) from None
super().__init__()
self.model_id = model_id
self.client = openai.OpenAI(
base_url=api_base,
api_key=api_key,
)
- self.temperature = temperature
self.kwargs = kwargs
+ self.custom_role_conversions = custom_role_conversions
def __call__(
self,
messages: List[Dict[str, str]],
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
- max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
) -> ChatMessage:
messages = get_clean_message_list(
- messages, role_conversions=tool_role_conversions
+ messages,
+ role_conversions=(
+ self.custom_role_conversions
+ if self.custom_role_conversions
+ else tool_role_conversions
+ ),
)
if tools_to_call_from:
response = self.client.chat.completions.create(
@@ -524,8 +617,6 @@ def __call__(
tools=[get_json_schema(tool) for tool in tools_to_call_from],
tool_choice="auto",
stop=stop_sequences,
- max_tokens=max_tokens,
- temperature=self.temperature,
**self.kwargs,
)
else:
@@ -533,8 +624,6 @@ def __call__(
model=self.model_id,
messages=messages,
stop=stop_sequences,
- max_tokens=max_tokens,
- temperature=self.temperature,
**self.kwargs,
)
self.last_input_token_count = response.usage.prompt_tokens
@@ -551,4 +640,5 @@ def __call__(
"HfApiModel",
"LiteLLMModel",
"OpenAIServerModel",
+ "ChatMessage",
]
diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py
index 04a203d92..fc85979ee 100644
--- a/src/smolagents/tools.py
+++ b/src/smolagents/tools.py
@@ -23,9 +23,10 @@
import sys
import tempfile
import textwrap
+from contextlib import contextmanager
from functools import lru_cache, wraps
from pathlib import Path
-from typing import Callable, Dict, Optional, Union, get_type_hints
+from typing import Callable, Dict, List, Optional, Union, get_type_hints
from huggingface_hub import (
create_repo,
@@ -35,6 +36,7 @@
upload_folder,
)
from huggingface_hub.utils import RepositoryNotFoundError
+
from packaging import version
from transformers.dynamic_module_utils import get_imports
from transformers.utils import (
@@ -275,7 +277,8 @@ def save(self, output_dir):
raise (ValueError("\n".join(method_checker.errors)))
forward_source_code = inspect.getsource(self.forward)
- tool_code = textwrap.dedent(f"""
+ tool_code = textwrap.dedent(
+ f"""
from smolagents import Tool
from typing import Optional
@@ -284,7 +287,8 @@ class {class_name}(Tool):
description = "{self.description}"
inputs = {json.dumps(self.inputs, separators=(",", ":"))}
output_type = "{self.output_type}"
- """).strip()
+ """
+ ).strip()
import re
def add_self_argument(source_code: str) -> str:
@@ -325,7 +329,8 @@ def replacement(match):
app_file = os.path.join(output_dir, "app.py")
with open(app_file, "w", encoding="utf-8") as f:
f.write(
- textwrap.dedent(f"""
+ textwrap.dedent(
+ f"""
from smolagents import launch_gradio_demo
from typing import Optional
from tool import {class_name}
@@ -333,7 +338,8 @@ def replacement(match):
tool = {class_name}()
launch_gradio_demo(tool)
- """).lstrip()
+ """
+ ).lstrip()
)
# Save requirements file
@@ -553,21 +559,21 @@ def from_space(
The Space, as a tool.
Examples:
+ ```py
+ >>> image_generator = Tool.from_space(
+ ... space_id="black-forest-labs/FLUX.1-schnell",
+ ... name="image-generator",
+ ... description="Generate an image from a prompt"
+ ... )
+ >>> image = image_generator("Generate an image of a cool surfer in Tahiti")
```
- image_generator = Tool.from_space(
- space_id="black-forest-labs/FLUX.1-schnell",
- name="image-generator",
- description="Generate an image from a prompt"
- )
- image = image_generator("Generate an image of a cool surfer in Tahiti")
- ```
- ```
- face_swapper = Tool.from_space(
- "tuan2308/face-swap",
- "face_swapper",
- "Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
- )
- image = face_swapper('./aymeric.jpeg', './ruth.jpg')
+ ```py
+ >>> face_swapper = Tool.from_space(
+ ... "tuan2308/face-swap",
+ ... "face_swapper",
+ ... "Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
+ ... )
+ >>> image = face_swapper('./aymeric.jpeg', './ruth.jpg')
```
"""
from gradio_client import Client, handle_file
@@ -870,42 +876,105 @@ def inner(func):
class ToolCollection:
"""
- Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
+ Tool collections enable loading a collection of tools in the agent's toolbox.
- > [!NOTE]
- > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
- > like for this collection to showcase them.
+ Collections can be loaded from a collection in the Hub or from an MCP server, see:
+ - [`ToolCollection.from_hub`]
+ - [`ToolCollection.from_mcp`]
- Args:
- collection_slug (str):
- The collection slug referencing the collection.
- token (str, *optional*):
- The authentication token if the collection is private.
+ For example and usage, see: [`ToolCollection.from_hub`] and [`ToolCollection.from_mcp`]
+ """
- Example:
+ def __init__(self, tools: List[Tool]):
+ self.tools = tools
- ```py
- >>> from transformers import ToolCollection, CodeAgent
+ @classmethod
+ def from_hub(
+ cls,
+ collection_slug: str,
+ token: Optional[str] = None,
+ trust_remote_code: bool = False,
+ ) -> "ToolCollection":
+ """Loads a tool collection from the Hub.
- >>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
- >>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
+ it adds a collection of tools from all Spaces in the collection to the agent's toolbox
- >>> agent.run("Please draw me a picture of rivers and lakes.")
- ```
- """
+ > [!NOTE]
+ > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
+ > like for this collection to showcase them.
- def __init__(
- self, collection_slug: str, token: Optional[str] = None, trust_remote_code=False
- ):
- self._collection = get_collection(collection_slug, token=token)
- self._hub_repo_ids = {
- item.item_id for item in self._collection.items if item.item_type == "space"
+ Args:
+ collection_slug (str): The collection slug referencing the collection.
+ token (str, *optional*): The authentication token if the collection is private.
+ trust_remote_code (bool, *optional*, defaults to False): Whether to trust the remote code.
+
+ Returns:
+ ToolCollection: A tool collection instance loaded with the tools.
+
+ Example:
+ ```py
+ >>> from smolagents import ToolCollection, CodeAgent
+
+ >>> image_tool_collection = ToolCollection.from_hub("huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
+ >>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
+
+ >>> agent.run("Please draw me a picture of rivers and lakes.")
+ ```
+ """
+ _collection = get_collection(collection_slug, token=token)
+ _hub_repo_ids = {
+ item.item_id for item in _collection.items if item.item_type == "space"
}
- self.tools = {
+
+ tools = {
Tool.from_hub(repo_id, token, trust_remote_code)
- for repo_id in self._hub_repo_ids
+ for repo_id in _hub_repo_ids
}
+ return cls(tools)
+
+ @classmethod
+ @contextmanager
+ def from_mcp(cls, server_parameters) -> "ToolCollection":
+ """Automatically load a tool collection from an MCP server.
+
+ Note: a separate thread will be spawned to run an asyncio event loop handling
+ the MCP server.
+
+ Args:
+ server_parameters (mcp.StdioServerParameters): The server parameters to use to
+ connect to the MCP server.
+
+ Returns:
+ ToolCollection: A tool collection instance.
+
+ Example:
+ ```py
+ >>> from smolagents import ToolCollection, CodeAgent
+ >>> from mcp import StdioServerParameters
+
+ >>> server_parameters = StdioServerParameters(
+ >>> command="uv",
+ >>> args=["--quiet", "pubmedmcp@0.1.3"],
+ >>> env={"UV_PYTHON": "3.12", **os.environ},
+ >>> )
+
+ >>> with ToolCollection.from_mcp(server_parameters) as tool_collection:
+ >>> agent = CodeAgent(tools=[*tool_collection.tools], add_base_tools=True)
+ >>> agent.run("Please find a remedy for hangover.")
+ ```
+ """
+ try:
+ from mcpadapt.core import MCPAdapt
+ from mcpadapt.smolagents_adapter import SmolAgentsAdapter
+ except ImportError:
+ raise ImportError(
+ """Please install 'mcp' extra to use ToolCollection.from_mcp: `pip install "smolagents[mcp]"`."""
+ )
+
+ with MCPAdapt(server_parameters, SmolAgentsAdapter()) as tools:
+ yield cls(tools)
+
def tool(tool_function: Callable) -> Tool:
"""
diff --git a/src/smolagents/types.py b/src/smolagents/types.py
index d88293f41..038885f88 100644
--- a/src/smolagents/types.py
+++ b/src/smolagents/types.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import importlib.util
import logging
import os
import pathlib
@@ -25,7 +26,6 @@
is_torch_available,
is_vision_available,
)
-from transformers.utils.import_utils import _is_package_available
logger = logging.getLogger(__name__)
@@ -41,9 +41,6 @@
else:
Tensor = object
-if _is_package_available("soundfile"):
- import soundfile as sf
-
class AgentType:
"""
@@ -187,11 +184,12 @@ class AgentAudio(AgentType, str):
"""
def __init__(self, value, samplerate=16_000):
+ if importlib.util.find_spec("soundfile") is None:
+ raise ModuleNotFoundError(
+ "Please install 'audio' extra to use AgentAudio: `pip install 'smolagents[audio]'`"
+ )
super().__init__(value)
- if not _is_package_available("soundfile"):
- raise ImportError("soundfile must be installed in order to handle audio.")
-
self._path = None
self._tensor = None
@@ -221,6 +219,8 @@ def to_raw(self):
"""
Returns the "raw" version of that object. It is a `torch.Tensor` object.
"""
+ import soundfile as sf
+
if self._tensor is not None:
return self._tensor
@@ -239,6 +239,8 @@ def to_string(self):
Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
version of the audio.
"""
+ import soundfile as sf
+
if self._path is not None:
return self._path
diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py
index e3ea23c82..ac4565f3d 100644
--- a/src/smolagents/utils.py
+++ b/src/smolagents/utils.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
+import importlib.util
import inspect
import json
import re
@@ -22,16 +23,13 @@
from typing import Dict, Tuple, Union
from rich.console import Console
-from transformers.utils.import_utils import _is_package_available
-
-_pygments_available = _is_package_available("pygments")
def is_pygments_available():
- return _pygments_available
+ return importlib.util.find_spec("soundfile") is not None
-console = Console(width=200)
+console = Console()
BASE_BUILTIN_MODULES = [
"collections",
@@ -171,9 +169,9 @@ def truncate_content(
return content
else:
return (
- content[: MAX_LENGTH_TRUNCATE_CONTENT // 2]
+ content[: max_length // 2]
+ f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
- + content[-MAX_LENGTH_TRUNCATE_CONTENT // 2 :]
+ + content[-max_length // 2 :]
)
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 1cd0a6750..4a031374b 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -35,6 +35,7 @@
ChatMessageToolCall,
ChatMessageToolCallDefinition,
)
+from smolagents.utils import BASE_BUILTIN_MODULES
def get_new_path(suffix="") -> str:
@@ -381,6 +382,12 @@ def test_tool_descriptions_get_baked_in_system_prompt(self):
assert tool.name in agent.system_prompt
assert tool.description in agent.system_prompt
+ def test_module_imports_get_baked_in_system_prompt(self):
+ agent = CodeAgent(tools=[], model=fake_code_model)
+ agent.run("Empty task")
+ for module in BASE_BUILTIN_MODULES:
+ assert module in agent.system_prompt
+
def test_init_agent_with_different_toolsets(self):
toolset_1 = []
agent = CodeAgent(tools=toolset_1, model=fake_code_model)
diff --git a/tests/test_all_docs.py b/tests/test_all_docs.py
index 3ad901d30..d1adabd73 100644
--- a/tests/test_all_docs.py
+++ b/tests/test_all_docs.py
@@ -85,7 +85,7 @@ class TestDocs:
def setup_class(cls):
cls._tmpdir = tempfile.mkdtemp()
cls.launch_args = ["python3"]
- cls.docs_dir = Path(__file__).parent.parent / "docs" / "source"
+ cls.docs_dir = Path(__file__).parent.parent / "docs" / "source" / "en"
cls.extractor = DocCodeExtractor()
if not cls.docs_dir.exists():
@@ -115,6 +115,7 @@ def test_single_doc(self, doc_path: Path):
"while llm_should_continue(memory):", # This is pseudo code
"ollama_chat/llama3.2", # Exclude ollama building in guided tour
"model = TransformersModel(model_id=model_id)", # Exclude testing with transformers model
+ "SmolagentsInstrumentor", # Exclude telemetry since it needs additional installs
]
code_blocks = [
block
diff --git a/tests/test_models.py b/tests/test_models.py
index dbd93ce14..992163194 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
+import json
from typing import Optional
-from smolagents import models, tool
+from smolagents import models, tool, ChatMessage, HfApiModel, TransformersModel
class ModelTests(unittest.TestCase):
@@ -38,3 +39,24 @@ def get_weather(location: str, celsius: Optional[bool] = False) -> str:
"properties"
]["celsius"]
)
+
+ def test_chatmessage_has_model_dumps_json(self):
+ message = ChatMessage("user", "Hello!")
+ data = json.loads(message.model_dump_json())
+ assert data["content"] == "Hello!"
+
+ def test_get_hfapi_message_no_tool(self):
+ model = HfApiModel(max_tokens=10)
+ messages = [{"role": "user", "content": "Hello!"}]
+ model(messages, stop_sequences=["great"])
+
+ def test_transformers_message_no_tool(self):
+ model = TransformersModel(
+ model_id="HuggingFaceTB/SmolLM2-135M-Instruct",
+ max_new_tokens=5,
+ device_map="auto",
+ do_sample=False,
+ )
+ messages = [{"role": "user", "content": "Hello!"}]
+ output = model(messages, stop_sequences=["great"]).content
+ assert output == "assistant\nHello"
diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py
index e55afb43d..bd8b14884 100644
--- a/tests/test_monitoring.py
+++ b/tests/test_monitoring.py
@@ -22,7 +22,7 @@
ToolCallingAgent,
stream_to_gradio,
)
-from huggingface_hub import (
+from smolagents.models import (
ChatMessage,
ChatMessageToolCall,
ChatMessageToolCallDefinition,
diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py
index 58f250cfc..4976c56af 100644
--- a/tests/test_python_interpreter.py
+++ b/tests/test_python_interpreter.py
@@ -60,6 +60,14 @@ def test_assignment_cannot_overwrite_tool(self):
in str(e)
)
+ def test_subscript_call(self):
+ code = """def foo(x,y):return x*y\n\ndef boo(y):\n\treturn y**3\nfun = [foo, boo]\nresult_foo = fun[0](4,2)\nresult_boo = fun[1](4)"""
+ state = {}
+ result, _ = evaluate_python_code(code, BASE_PYTHON_TOOLS, state=state)
+ assert result == 64
+ assert state["result_foo"] == 8
+ assert state["result_boo"] == 64
+
def test_evaluate_call(self):
code = "y = add_two(x)"
state = {"x": 3}
@@ -912,3 +920,19 @@ def test_fix_final_answer_code(self):
Expected: {expected}
Got: {result}
"""
+
+ def test_dangerous_subpackage_access_blocked(self):
+ # Direct imports with dangerous patterns should fail
+ code = "import random._os"
+ with pytest.raises(InterpreterError):
+ evaluate_python_code(code)
+
+ # Import of whitelisted modules should succeed but dangerous submodules should not exist
+ code = "import random;random._os.system('echo bad command passed')"
+ with pytest.raises(AttributeError) as e:
+ evaluate_python_code(code)
+ assert "module 'random' has no attribute '_os'" in str(e)
+
+ code = "import doctest;doctest.inspect.os.system('echo bad command passed')"
+ with pytest.raises(AttributeError):
+ evaluate_python_code(code, authorized_imports=["doctest"])
diff --git a/tests/test_tools.py b/tests/test_tools.py
index cfa61c19c..5b2dc0e1f 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -14,14 +14,17 @@
# limitations under the License.
import unittest
from pathlib import Path
+from textwrap import dedent
from typing import Dict, Optional, Union
+from unittest.mock import patch, MagicMock
+import mcp
import numpy as np
import pytest
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import get_tests_dir
-from smolagents.tools import AUTHORIZED_TYPES, Tool, tool
+from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, tool
from smolagents.types import (
AGENT_TYPE_MAPPING,
AgentAudio,
@@ -385,3 +388,61 @@ def forward(self, location, celsius: str) -> str:
GetWeatherTool3()
assert "Nullable" in str(e)
+
+
+@pytest.fixture
+def mock_server_parameters():
+ return MagicMock()
+
+
+@pytest.fixture
+def mock_mcp_adapt():
+ with patch("mcpadapt.core.MCPAdapt") as mock:
+ mock.return_value.__enter__.return_value = ["tool1", "tool2"]
+ mock.return_value.__exit__.return_value = None
+ yield mock
+
+
+@pytest.fixture
+def mock_smolagents_adapter():
+ with patch("mcpadapt.smolagents_adapter.SmolAgentsAdapter") as mock:
+ yield mock
+
+
+class TestToolCollection:
+ def test_from_mcp(
+ self, mock_server_parameters, mock_mcp_adapt, mock_smolagents_adapter
+ ):
+ with ToolCollection.from_mcp(mock_server_parameters) as tool_collection:
+ assert isinstance(tool_collection, ToolCollection)
+ assert len(tool_collection.tools) == 2
+ assert "tool1" in tool_collection.tools
+ assert "tool2" in tool_collection.tools
+
+ def test_integration_from_mcp(self):
+ # define the most simple mcp server with one tool that echoes the input text
+ mcp_server_script = dedent("""\
+ from mcp.server.fastmcp import FastMCP
+
+ mcp = FastMCP("Echo Server")
+
+ @mcp.tool()
+ def echo_tool(text: str) -> str:
+ return text
+
+ mcp.run()
+ """).strip()
+
+ mcp_server_params = mcp.StdioServerParameters(
+ command="python",
+ args=["-c", mcp_server_script],
+ )
+
+ with ToolCollection.from_mcp(mcp_server_params) as tool_collection:
+ assert len(tool_collection.tools) == 1, "Expected 1 tool"
+ assert tool_collection.tools[0].name == "echo_tool", (
+ "Expected tool name to be 'echo_tool'"
+ )
+ assert tool_collection.tools[0](text="Hello") == "Hello", (
+ "Expected tool to echo the input text"
+ )
diff --git a/tests/test_types.py b/tests/test_types.py
index aa58a8f07..244875cfc 100644
--- a/tests/test_types.py
+++ b/tests/test_types.py
@@ -24,15 +24,9 @@
require_torch,
require_vision,
)
-from transformers.utils.import_utils import (
- _is_package_available,
-)
from smolagents.types import AgentAudio, AgentImage, AgentText
-if _is_package_available("soundfile"):
- import soundfile as sf
-
def get_new_path(suffix="") -> str:
directory = tempfile.mkdtemp()
@@ -43,6 +37,7 @@ def get_new_path(suffix="") -> str:
@require_torch
class AgentAudioTests(unittest.TestCase):
def test_from_tensor(self):
+ import soundfile as sf
import torch
tensor = torch.rand(12, dtype=torch.float64) - 0.5
@@ -62,6 +57,7 @@ def test_from_tensor(self):
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
def test_from_string(self):
+ import soundfile as sf
import torch
tensor = torch.rand(12, dtype=torch.float64) - 0.5
diff --git a/utils/check_tests_in_ci.py b/utils/check_tests_in_ci.py
new file mode 100644
index 000000000..4c55ef098
--- /dev/null
+++ b/utils/check_tests_in_ci.py
@@ -0,0 +1,58 @@
+# coding=utf-8
+# Copyright 2025-present, the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Check that all tests are called in CI."""
+
+from pathlib import Path
+
+ROOT = Path(__file__).parent.parent
+
+TESTS_FOLDER = ROOT / "tests"
+CI_WORKFLOW_FILE = ROOT / ".github" / "workflows" / "tests.yml"
+
+
+def check_tests_in_ci():
+ """List all test files in `./tests/` and check if they are listed in the CI workflow.
+
+ Since each test file is triggered separately in the CI workflow, it is easy to forget a new one when adding new
+ tests, hence this check.
+
+ NOTE: current implementation is quite naive but should work for now. Must be updated if one want to ignore some
+ tests or if file naming is updated (currently only files starting by `test_*` are cheked)
+ """
+ test_files = [
+ path.relative_to(TESTS_FOLDER).as_posix()
+ for path in TESTS_FOLDER.glob("**/*.py")
+ if path.name.startswith("test_")
+ ]
+ ci_workflow_file_content = CI_WORKFLOW_FILE.read_text()
+ missing_test_files = [
+ test_file
+ for test_file in test_files
+ if test_file not in ci_workflow_file_content
+ ]
+ if missing_test_files:
+ print(
+ "❌ Some test files seem to be ignored in the CI:\n"
+ + "\n".join(f" - {test_file}" for test_file in missing_test_files)
+ + f"\n Please add them manually in {CI_WORKFLOW_FILE}."
+ )
+ exit(1)
+ else:
+ print("✅ All good!")
+ exit(0)
+
+
+if __name__ == "__main__":
+ check_tests_in_ci()