From c611dfc7e5711f6c6f6b2e604bd89c0b809484cc Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Mon, 13 Jan 2025 17:23:03 +0100
Subject: [PATCH 1/9] Clean local python interpreter: propagate imports (#175)
---
examples/benchmark.ipynb | 270 ++++++++--
src/smolagents/agents.py | 50 +-
src/smolagents/e2b_executor.py | 11 +-
src/smolagents/gradio_ui.py | 29 +-
src/smolagents/local_python_executor.py | 638 ++++++++++++++++++------
src/smolagents/models.py | 2 +-
tests/test_agents.py | 8 +-
7 files changed, 763 insertions(+), 245 deletions(-)
diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb
index 7a7b776e5..1009f2807 100644
--- a/examples/benchmark.ipynb
+++ b/examples/benchmark.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -16,20 +16,21 @@
}
],
"source": [
- "!pip install -e .. sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
+ "!pip install -e .. datasets sympy numpy matplotlib seaborn -q # Install dev version of smolagents + some packages"
]
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "Using the latest cached version of the dataset since m-ric/smolagentsbenchmark couldn't be found on the Hugging Face Hub\n",
- "Found the latest cached dataset configuration 'default' at /Users/aymeric/.cache/huggingface/datasets/m-ric___smolagentsbenchmark/default/0.0.0/0ad5fb2293ab185eece723a4ac0e4a7188f71add (last modified on Wed Jan 8 17:50:13 2025).\n"
+ "/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n"
]
},
{
@@ -172,7 +173,7 @@
"[132 rows x 4 columns]"
]
},
- "execution_count": 3,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -181,7 +182,7 @@
"import datasets\n",
"import pandas as pd\n",
"\n",
- "eval_ds = datasets.load_dataset(\"m-ric/smolagentsbenchmark\")[\"train\"]\n",
+ "eval_ds = datasets.load_dataset(\"m-ric/smol_agents_benchmark\")[\"train\"]\n",
"pd.DataFrame(eval_ds)"
]
},
@@ -195,9 +196,19 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 6,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n",
+ "* 'fields' has been removed\n",
+ " warnings.warn(message, UserWarning)\n"
+ ]
+ }
+ ],
"source": [
"import time\n",
"import json\n",
@@ -351,6 +362,7 @@
" model_answer: str,\n",
" ground_truth: str,\n",
") -> bool:\n",
+ " \"\"\"Scoring function used to score answers from the GAIA benchmark\"\"\"\n",
" if is_float(ground_truth):\n",
" normalized_answer = normalize_number_str(str(model_answer))\n",
" return normalized_answer == float(ground_truth)\n",
@@ -396,9 +408,100 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n"
+ ]
+ }
+ ],
"source": [
"open_model_ids = [\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
@@ -451,9 +554,42 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating 'gpt-4o'...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n",
+ "100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n"
+ ]
+ }
+ ],
"source": [
"from smolagents import LiteLLMModel\n",
"\n",
@@ -495,7 +631,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -534,14 +670,14 @@
},
{
"cell_type": "code",
- "execution_count": 31,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_74415/3026956094.py:163: UserWarning: Answer lists have different lengths, returning False.\n",
+ "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\n"
]
}
@@ -552,9 +688,25 @@
"\n",
"res = []\n",
"for file_path in glob.glob(\"output/*.jsonl\"):\n",
- " smoldf = pd.read_json(file_path, lines=True)\n",
- " smoldf[\"action_type\"] = \"vanilla\" if \"-vanilla-\" in file_path else \"code\"\n",
- " res.append(smoldf)\n",
+ " data = []\n",
+ " with open(file_path) as f:\n",
+ " for line in f:\n",
+ " try:\n",
+ " # Use the standard json module instead of pd.read_json to handle large numbers better\n",
+ " record = json.loads(line)\n",
+ " data.append(record)\n",
+ " except json.JSONDecodeError as e:\n",
+ " print(f\"Error parsing line in {file_path}: {e}\")\n",
+ " continue\n",
+ "\n",
+ " try:\n",
+ " smoldf = pd.DataFrame(data)\n",
+ " smoldf[\"action_type\"] = \"vanilla\" if \"-vanilla-\" in file_path else \"code\"\n",
+ " res.append(smoldf)\n",
+ " except Exception as e:\n",
+ " print(f\"Error creating DataFrame from {file_path}: {e}\")\n",
+ " continue\n",
+ "\n",
"result_df = pd.concat(res)\n",
"\n",
"\n",
@@ -579,7 +731,7 @@
},
{
"cell_type": "code",
- "execution_count": 33,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -600,7 +752,7 @@
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 15,
"metadata": {},
"outputs": [
{
@@ -643,7 +795,7 @@
"
Qwen/Qwen2.5-72B-Instruct | \n",
" MATH | \n",
" 74.0 | \n",
- " 31.9 | \n",
+ " 30.0 | \n",
" \n",
" \n",
" | 2 | \n",
@@ -778,33 +930,57 @@
" 84.0 | \n",
" 12.0 | \n",
"
\n",
+ " \n",
+ " | 21 | \n",
+ " mistralai/Mistral-Nemo-Instruct-2407 | \n",
+ " GAIA | \n",
+ " 3.1 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " mistralai/Mistral-Nemo-Instruct-2407 | \n",
+ " MATH | \n",
+ " 20.0 | \n",
+ " 22.0 | \n",
+ "
\n",
+ " \n",
+ " | 23 | \n",
+ " mistralai/Mistral-Nemo-Instruct-2407 | \n",
+ " SimpleQA | \n",
+ " 30.0 | \n",
+ " 0.0 | \n",
+ "
\n",
" \n",
"\n",
""
],
"text/plain": [
- "action_type model_id source code vanilla\n",
- "0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
- "1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 31.9\n",
- "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n",
- "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n",
- "4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n",
- "5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
- "6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n",
- "7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n",
- "8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n",
- "9 gpt-4o GAIA 28.1 3.1\n",
- "10 gpt-4o MATH 70.0 40.0\n",
- "11 gpt-4o SimpleQA 88.0 6.0\n",
- "12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n",
- "13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n",
- "14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n",
- "15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
- "16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n",
- "17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n",
- "18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n",
- "19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n",
- "20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0"
+ "action_type model_id source code vanilla\n",
+ "0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
+ "1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 30.0\n",
+ "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n",
+ "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n",
+ "4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n",
+ "5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
+ "6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n",
+ "7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n",
+ "8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n",
+ "9 gpt-4o GAIA 28.1 3.1\n",
+ "10 gpt-4o MATH 70.0 40.0\n",
+ "11 gpt-4o SimpleQA 88.0 6.0\n",
+ "12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n",
+ "13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n",
+ "14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n",
+ "15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
+ "16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n",
+ "17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n",
+ "18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n",
+ "19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n",
+ "20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0\n",
+ "21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 3.1 0.0\n",
+ "22 mistralai/Mistral-Nemo-Instruct-2407 MATH 20.0 22.0\n",
+ "23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 0.0"
]
},
"metadata": {},
@@ -817,12 +993,12 @@
},
{
"cell_type": "code",
- "execution_count": 36,
+ "execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -995,7 +1171,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "compare-agents",
+ "display_name": "test",
"language": "python",
"name": "python3"
},
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index e9c3d9d14..832ac8efc 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -884,7 +884,6 @@ def __init__(
system_prompt: Optional[str] = None,
grammar: Optional[Dict[str, str]] = None,
additional_authorized_imports: Optional[List[str]] = None,
- allow_all_imports: bool = False,
planning_interval: Optional[int] = None,
use_e2b_executor: bool = False,
**kwargs,
@@ -899,18 +898,29 @@ def __init__(
planning_interval=planning_interval,
**kwargs,
)
-
- if ( allow_all_imports and
- ( not(additional_authorized_imports is None) and (len(additional_authorized_imports)) > 0)):
- raise Exception(
- f"You passed both allow_all_imports and additional_authorized_imports. Please choose one."
- )
-
- if allow_all_imports: additional_authorized_imports=['*']
-
self.additional_authorized_imports = (
additional_authorized_imports if additional_authorized_imports else []
)
+ self.authorized_imports = list(
+ set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
+ )
+ if "{{authorized_imports}}" not in self.system_prompt:
+ raise AgentError(
+ "Tag '{{authorized_imports}}' should be provided in the prompt."
+ )
+ self.system_prompt = self.system_prompt.replace(
+ "{{authorized_imports}}",
+ "You can import from any package you want."
+ if "*" in self.authorized_imports
+ else str(self.authorized_imports),
+ )
+
+ if "*" in self.additional_authorized_imports:
+ self.logger.log(
+ "Caution: you set an authorization for all imports, meaning your agent can decide to import any package it deems necessary. This might raise issues if the package is not installed in your environment.",
+ 0,
+ )
+
if use_e2b_executor and len(self.managed_agents) > 0:
raise Exception(
f"You passed both {use_e2b_executor=} and some managed agents. Managed agents is not yet supported with remote code execution."
@@ -919,25 +929,15 @@ def __init__(
all_tools = {**self.tools, **self.managed_agents}
if use_e2b_executor:
self.python_executor = E2BExecutor(
- self.additional_authorized_imports, list(all_tools.values())
+ self.additional_authorized_imports,
+ list(all_tools.values()),
+ self.logger,
)
else:
self.python_executor = LocalPythonInterpreter(
- self.additional_authorized_imports, all_tools
+ self.additional_authorized_imports,
+ all_tools,
)
- if allow_all_imports:
- self.authorized_imports = 'all imports without restriction'
- else:
- self.authorized_imports = list(
- set(BASE_BUILTIN_MODULES) | set(self.additional_authorized_imports)
- )
- if "{{authorized_imports}}" not in self.system_prompt:
- raise AgentError(
- "Tag '{{authorized_imports}}' should be provided in the prompt."
- )
- self.system_prompt = self.system_prompt.replace(
- "{{authorized_imports}}", str(self.authorized_imports)
- )
def step(self, log_entry: ActionStep) -> Union[None, Any]:
"""
diff --git a/src/smolagents/e2b_executor.py b/src/smolagents/e2b_executor.py
index 68f557940..e8cc89347 100644
--- a/src/smolagents/e2b_executor.py
+++ b/src/smolagents/e2b_executor.py
@@ -26,13 +26,13 @@
from .tool_validation import validate_tool_attributes
from .tools import Tool
-from .utils import BASE_BUILTIN_MODULES, console, instance_to_source
+from .utils import BASE_BUILTIN_MODULES, instance_to_source
load_dotenv()
class E2BExecutor:
- def __init__(self, additional_imports: List[str], tools: List[Tool]):
+ def __init__(self, additional_imports: List[str], tools: List[Tool], logger):
self.custom_tools = {}
self.sbx = Sandbox() # "qywp2ctmu2q7jzprcf4j")
# TODO: validate installing agents package or not
@@ -42,6 +42,7 @@ def __init__(self, additional_imports: List[str], tools: List[Tool]):
# timeout=300
# )
# print("Installation of agents package finished.")
+ self.logger = logger
additional_imports = additional_imports + ["pickle5"]
if len(additional_imports) > 0:
execution = self.sbx.commands.run(
@@ -50,7 +51,7 @@ def __init__(self, additional_imports: List[str], tools: List[Tool]):
if execution.error:
raise Exception(f"Error installing dependencies: {execution.error}")
else:
- console.print(f"Installation of {additional_imports} succeeded!")
+ logger.log(f"Installation of {additional_imports} succeeded!", 0)
tool_codes = []
for tool in tools:
@@ -74,7 +75,7 @@ def forward(self, *args, **kwargs):
tool_definition_code += "\n\n".join(tool_codes)
tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
- console.print(tool_definition_execution.logs)
+ self.logger.log(tool_definition_execution.logs)
def run_code_raise_errors(self, code: str):
execution = self.sbx.run_code(
@@ -109,7 +110,7 @@ def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
"""
execution = self.run_code_raise_errors(remote_unloading_code)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
- console.print(execution_logs)
+ self.logger.log(execution_logs, 1)
execution = self.run_code_raise_errors(code_action)
execution_logs = "\n".join([str(log) for log in execution.logs.stdout])
diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py
index 514bd1f22..45ae8a250 100644
--- a/src/smolagents/gradio_ui.py
+++ b/src/smolagents/gradio_ui.py
@@ -85,7 +85,7 @@ def stream_to_gradio(
class GradioUI:
"""A one-line interface to launch your agent in Gradio"""
- def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None=None):
+ def __init__(self, agent: MultiStepAgent, file_upload_folder: str | None = None):
self.agent = agent
self.file_upload_folder = file_upload_folder
if self.file_upload_folder is not None:
@@ -100,7 +100,15 @@ def interact_with_agent(self, prompt, messages):
yield messages
yield messages
- def upload_file(self, file, allowed_file_types=["application/pdf", "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "text/plain"]):
+ def upload_file(
+ self,
+ file,
+ allowed_file_types=[
+ "application/pdf",
+ "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+ "text/plain",
+ ],
+ ):
"""
Handle file uploads, default allowed types are pdf, docx, and .txt
"""
@@ -110,18 +118,19 @@ def upload_file(self, file, allowed_file_types=["application/pdf", "application/
return "No file uploaded"
# Check if file is in allowed filetypes
- name = os.path.basename(file.name)
try:
mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e:
return f"Error: {e}"
-
+
if mime_type not in allowed_file_types:
return "File type disallowed"
-
+
# Sanitize file name
original_name = os.path.basename(file.name)
- sanitized_name = re.sub(r'[^\w\-.]', '_', original_name) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
+ sanitized_name = re.sub(
+ r"[^\w\-.]", "_", original_name
+ ) # Replace any non-alphanumeric, non-dash, or non-dot characters with underscores
type_to_ext = {}
for ext, t in mimetypes.types_map.items():
@@ -134,7 +143,9 @@ def upload_file(self, file, allowed_file_types=["application/pdf", "application/
sanitized_name = "".join(sanitized_name)
# Save the uploaded file to the specified folder
- file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
+ file_path = os.path.join(
+ self.file_upload_folder, os.path.basename(sanitized_name)
+ )
shutil.copy(file.name, file_path)
return f"File uploaded successfully to {self.file_upload_folder}"
@@ -155,9 +166,7 @@ def launch(self):
upload_file = gr.File(label="Upload a file")
upload_status = gr.Textbox(label="Upload Status", interactive=False)
- upload_file.change(
- self.upload_file, [upload_file], [upload_status]
- )
+ upload_file.change(self.upload_file, [upload_file], [upload_status])
text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit(
lambda s: (s, ""), [text_input], [stored_message, text_input]
diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py
index a70b53766..3c545cae6 100644
--- a/src/smolagents/local_python_executor.py
+++ b/src/smolagents/local_python_executor.py
@@ -159,8 +159,16 @@ def fix_final_answer_code(code: str) -> str:
return code
-def evaluate_unaryop(expression, state, static_tools, custom_tools):
- operand = evaluate_ast(expression.operand, state, static_tools, custom_tools)
+def evaluate_unaryop(
+ expression: ast.UnaryOp,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
+ operand = evaluate_ast(
+ expression.operand, state, static_tools, custom_tools, authorized_imports
+ )
if isinstance(expression.op, ast.USub):
return -operand
elif isinstance(expression.op, ast.UAdd):
@@ -175,27 +183,47 @@ def evaluate_unaryop(expression, state, static_tools, custom_tools):
)
-def evaluate_lambda(lambda_expression, state, static_tools, custom_tools):
+def evaluate_lambda(
+ lambda_expression: ast.Lambda,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Callable:
args = [arg.arg for arg in lambda_expression.args.args]
- def lambda_func(*values):
+ def lambda_func(*values: Any) -> Any:
new_state = state.copy()
for arg, value in zip(args, values):
new_state[arg] = value
return evaluate_ast(
- lambda_expression.body, new_state, static_tools, custom_tools
+ lambda_expression.body,
+ new_state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
)
return lambda_func
-def evaluate_while(while_loop, state, static_tools, custom_tools):
+def evaluate_while(
+ while_loop: ast.While,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> None:
max_iterations = 1000
iterations = 0
- while evaluate_ast(while_loop.test, state, static_tools, custom_tools):
+ while evaluate_ast(
+ while_loop.test, state, static_tools, custom_tools, authorized_imports
+ ):
for node in while_loop.body:
try:
- evaluate_ast(node, state, static_tools, custom_tools)
+ evaluate_ast(
+ node, state, static_tools, custom_tools, authorized_imports
+ )
except BreakException:
return None
except ContinueException:
@@ -208,12 +236,18 @@ def evaluate_while(while_loop, state, static_tools, custom_tools):
return None
-def create_function(func_def, state, static_tools, custom_tools):
- def new_func(*args, **kwargs):
+def create_function(
+ func_def: ast.FunctionDef,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Callable:
+ def new_func(*args: Any, **kwargs: Any) -> Any:
func_state = state.copy()
arg_names = [arg.arg for arg in func_def.args.args]
default_values = [
- evaluate_ast(d, state, static_tools, custom_tools)
+ evaluate_ast(d, state, static_tools, custom_tools, authorized_imports)
for d in func_def.args.defaults
]
@@ -224,7 +258,7 @@ def new_func(*args, **kwargs):
for name, value in zip(arg_names, args):
func_state[name] = value
- # # Set keyword arguments
+ # Set keyword arguments
for name, value in kwargs.items():
func_state[name] = value
@@ -251,7 +285,9 @@ def new_func(*args, **kwargs):
result = None
try:
for stmt in func_def.body:
- result = evaluate_ast(stmt, func_state, static_tools, custom_tools)
+ result = evaluate_ast(
+ stmt, func_state, static_tools, custom_tools, authorized_imports
+ )
except ReturnException as e:
result = e.value
@@ -263,24 +299,29 @@ def new_func(*args, **kwargs):
return new_func
-def create_class(class_name, class_bases, class_body):
- class_dict = {}
- for key, value in class_body.items():
- class_dict[key] = value
- return type(class_name, tuple(class_bases), class_dict)
-
-
-def evaluate_function_def(func_def, state, static_tools, custom_tools):
+def evaluate_function_def(
+ func_def: ast.FunctionDef,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Callable:
custom_tools[func_def.name] = create_function(
- func_def, state, static_tools, custom_tools
+ func_def, state, static_tools, custom_tools, authorized_imports
)
return custom_tools[func_def.name]
-def evaluate_class_def(class_def, state, static_tools, custom_tools):
+def evaluate_class_def(
+ class_def: ast.ClassDef,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> type:
class_name = class_def.name
bases = [
- evaluate_ast(base, state, static_tools, custom_tools)
+ evaluate_ast(base, state, static_tools, custom_tools, authorized_imports)
for base in class_def.bases
]
class_dict = {}
@@ -288,17 +329,25 @@ def evaluate_class_def(class_def, state, static_tools, custom_tools):
for stmt in class_def.body:
if isinstance(stmt, ast.FunctionDef):
class_dict[stmt.name] = evaluate_function_def(
- stmt, state, static_tools, custom_tools
+ stmt, state, static_tools, custom_tools, authorized_imports
)
elif isinstance(stmt, ast.Assign):
for target in stmt.targets:
if isinstance(target, ast.Name):
class_dict[target.id] = evaluate_ast(
- stmt.value, state, static_tools, custom_tools
+ stmt.value,
+ state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
)
elif isinstance(target, ast.Attribute):
class_dict[target.attr] = evaluate_ast(
- stmt.value, state, static_tools, custom_tools
+ stmt.value,
+ state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
)
else:
raise InterpreterError(
@@ -310,16 +359,28 @@ def evaluate_class_def(class_def, state, static_tools, custom_tools):
return new_class
-def evaluate_augassign(expression, state, static_tools, custom_tools):
- def get_current_value(target):
+def evaluate_augassign(
+ expression: ast.AugAssign,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
+ def get_current_value(target: ast.AST) -> Any:
if isinstance(target, ast.Name):
return state.get(target.id, 0)
elif isinstance(target, ast.Subscript):
- obj = evaluate_ast(target.value, state, static_tools, custom_tools)
- key = evaluate_ast(target.slice, state, static_tools, custom_tools)
+ obj = evaluate_ast(
+ target.value, state, static_tools, custom_tools, authorized_imports
+ )
+ key = evaluate_ast(
+ target.slice, state, static_tools, custom_tools, authorized_imports
+ )
return obj[key]
elif isinstance(target, ast.Attribute):
- obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+ obj = evaluate_ast(
+ target.value, state, static_tools, custom_tools, authorized_imports
+ )
return getattr(obj, target.attr)
elif isinstance(target, ast.Tuple):
return tuple(get_current_value(elt) for elt in target.elts)
@@ -331,7 +392,9 @@ def get_current_value(target):
)
current_value = get_current_value(expression.target)
- value_to_add = evaluate_ast(expression.value, state, static_tools, custom_tools)
+ value_to_add = evaluate_ast(
+ expression.value, state, static_tools, custom_tools, authorized_imports
+ )
if isinstance(expression.op, ast.Add):
if isinstance(current_value, list):
@@ -370,28 +433,55 @@ def get_current_value(target):
)
# Update the state
- set_value(expression.target, updated_value, state, static_tools, custom_tools)
+ set_value(
+ expression.target,
+ updated_value,
+ state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
+ )
return updated_value
-def evaluate_boolop(node, state, static_tools, custom_tools):
+def evaluate_boolop(
+ node: ast.BoolOp,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> bool:
if isinstance(node.op, ast.And):
for value in node.values:
- if not evaluate_ast(value, state, static_tools, custom_tools):
+ if not evaluate_ast(
+ value, state, static_tools, custom_tools, authorized_imports
+ ):
return False
return True
elif isinstance(node.op, ast.Or):
for value in node.values:
- if evaluate_ast(value, state, static_tools, custom_tools):
+ if evaluate_ast(
+ value, state, static_tools, custom_tools, authorized_imports
+ ):
return True
return False
-def evaluate_binop(binop, state, static_tools, custom_tools):
+def evaluate_binop(
+ binop: ast.BinOp,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
# Recursively evaluate the left and right operands
- left_val = evaluate_ast(binop.left, state, static_tools, custom_tools)
- right_val = evaluate_ast(binop.right, state, static_tools, custom_tools)
+ left_val = evaluate_ast(
+ binop.left, state, static_tools, custom_tools, authorized_imports
+ )
+ right_val = evaluate_ast(
+ binop.right, state, static_tools, custom_tools, authorized_imports
+ )
# Determine the operation based on the type of the operator in the BinOp
if isinstance(binop.op, ast.Add):
@@ -424,11 +514,19 @@ def evaluate_binop(binop, state, static_tools, custom_tools):
)
-def evaluate_assign(assign, state, static_tools, custom_tools):
- result = evaluate_ast(assign.value, state, static_tools, custom_tools)
+def evaluate_assign(
+ assign: ast.Assign,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
+ result = evaluate_ast(
+ assign.value, state, static_tools, custom_tools, authorized_imports
+ )
if len(assign.targets) == 1:
target = assign.targets[0]
- set_value(target, result, state, static_tools, custom_tools)
+ set_value(target, result, state, static_tools, custom_tools, authorized_imports)
else:
if len(assign.targets) != len(result):
raise InterpreterError(
@@ -441,11 +539,18 @@ def evaluate_assign(assign, state, static_tools, custom_tools):
else:
expanded_values.append(result)
for tgt, val in zip(assign.targets, expanded_values):
- set_value(tgt, val, state, static_tools, custom_tools)
+ set_value(tgt, val, state, static_tools, custom_tools, authorized_imports)
return result
-def set_value(target, value, state, static_tools, custom_tools):
+def set_value(
+ target: ast.AST,
+ value: Any,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> None:
if isinstance(target, ast.Name):
if target.id in static_tools:
raise InterpreterError(
@@ -461,21 +566,37 @@ def set_value(target, value, state, static_tools, custom_tools):
if len(target.elts) != len(value):
raise InterpreterError("Cannot unpack tuple of wrong size")
for i, elem in enumerate(target.elts):
- set_value(elem, value[i], state, static_tools, custom_tools)
+ set_value(
+ elem, value[i], state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(target, ast.Subscript):
- obj = evaluate_ast(target.value, state, static_tools, custom_tools)
- key = evaluate_ast(target.slice, state, static_tools, custom_tools)
+ obj = evaluate_ast(
+ target.value, state, static_tools, custom_tools, authorized_imports
+ )
+ key = evaluate_ast(
+ target.slice, state, static_tools, custom_tools, authorized_imports
+ )
obj[key] = value
elif isinstance(target, ast.Attribute):
- obj = evaluate_ast(target.value, state, static_tools, custom_tools)
+ obj = evaluate_ast(
+ target.value, state, static_tools, custom_tools, authorized_imports
+ )
setattr(obj, target.attr, value)
-def evaluate_call(call, state, static_tools, custom_tools):
+def evaluate_call(
+ call: ast.Call,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
if not (isinstance(call.func, ast.Attribute) or isinstance(call.func, ast.Name)):
raise InterpreterError(f"This is not a correct function: {call.func}).")
if isinstance(call.func, ast.Attribute):
- obj = evaluate_ast(call.func.value, state, static_tools, custom_tools)
+ obj = evaluate_ast(
+ call.func.value, state, static_tools, custom_tools, authorized_imports
+ )
func_name = call.func.attr
if not hasattr(obj, func_name):
raise InterpreterError(f"Object {obj} has no attribute {func_name}")
@@ -499,22 +620,20 @@ def evaluate_call(call, state, static_tools, custom_tools):
args = []
for arg in call.args:
if isinstance(arg, ast.Starred):
- args.extend(evaluate_ast(arg.value, state, static_tools, custom_tools))
- else:
- args.append(evaluate_ast(arg, state, static_tools, custom_tools))
-
- args = []
- for arg in call.args:
- if isinstance(arg, ast.Starred):
- unpacked = evaluate_ast(arg.value, state, static_tools, custom_tools)
- if not hasattr(unpacked, "__iter__") or isinstance(unpacked, (str, bytes)):
- raise InterpreterError(f"Cannot unpack non-iterable value {unpacked}")
- args.extend(unpacked)
+ args.extend(
+ evaluate_ast(
+ arg.value, state, static_tools, custom_tools, authorized_imports
+ )
+ )
else:
- args.append(evaluate_ast(arg, state, static_tools, custom_tools))
+ args.append(
+ evaluate_ast(arg, state, static_tools, custom_tools, authorized_imports)
+ )
kwargs = {
- keyword.arg: evaluate_ast(keyword.value, state, static_tools, custom_tools)
+ keyword.arg: evaluate_ast(
+ keyword.value, state, static_tools, custom_tools, authorized_imports
+ )
for keyword in call.keywords
}
@@ -545,9 +664,19 @@ def evaluate_call(call, state, static_tools, custom_tools):
return func(*args, **kwargs)
-def evaluate_subscript(subscript, state, static_tools, custom_tools):
- index = evaluate_ast(subscript.slice, state, static_tools, custom_tools)
- value = evaluate_ast(subscript.value, state, static_tools, custom_tools)
+def evaluate_subscript(
+ subscript: ast.Subscript,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
+ index = evaluate_ast(
+ subscript.slice, state, static_tools, custom_tools, authorized_imports
+ )
+ value = evaluate_ast(
+ subscript.value, state, static_tools, custom_tools, authorized_imports
+ )
if isinstance(value, str) and isinstance(index, str):
raise InterpreterError(
@@ -583,7 +712,13 @@ def evaluate_subscript(subscript, state, static_tools, custom_tools):
raise InterpreterError(f"Could not index {value} with '{index}'.")
-def evaluate_name(name, state, static_tools, custom_tools):
+def evaluate_name(
+ name: ast.Name,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
if name.id in state:
return state[name.id]
elif name.id in static_tools:
@@ -596,10 +731,18 @@ def evaluate_name(name, state, static_tools, custom_tools):
raise InterpreterError(f"The variable `{name.id}` is not defined.")
-def evaluate_condition(condition, state, static_tools, custom_tools):
- left = evaluate_ast(condition.left, state, static_tools, custom_tools)
+def evaluate_condition(
+ condition: ast.Compare,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> bool:
+ left = evaluate_ast(
+ condition.left, state, static_tools, custom_tools, authorized_imports
+ )
comparators = [
- evaluate_ast(c, state, static_tools, custom_tools)
+ evaluate_ast(c, state, static_tools, custom_tools, authorized_imports)
for c in condition.comparators
]
ops = [type(op) for op in condition.ops]
@@ -640,30 +783,59 @@ def evaluate_condition(condition, state, static_tools, custom_tools):
return result if isinstance(result, (bool, pd.Series)) else result.all()
-def evaluate_if(if_statement, state, static_tools, custom_tools):
+def evaluate_if(
+ if_statement: ast.If,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
result = None
- test_result = evaluate_ast(if_statement.test, state, static_tools, custom_tools)
+ test_result = evaluate_ast(
+ if_statement.test, state, static_tools, custom_tools, authorized_imports
+ )
if test_result:
for line in if_statement.body:
- line_result = evaluate_ast(line, state, static_tools, custom_tools)
+ line_result = evaluate_ast(
+ line, state, static_tools, custom_tools, authorized_imports
+ )
if line_result is not None:
result = line_result
else:
for line in if_statement.orelse:
- line_result = evaluate_ast(line, state, static_tools, custom_tools)
+ line_result = evaluate_ast(
+ line, state, static_tools, custom_tools, authorized_imports
+ )
if line_result is not None:
result = line_result
return result
-def evaluate_for(for_loop, state, static_tools, custom_tools):
+def evaluate_for(
+ for_loop: ast.For,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Any:
result = None
- iterator = evaluate_ast(for_loop.iter, state, static_tools, custom_tools)
+ iterator = evaluate_ast(
+ for_loop.iter, state, static_tools, custom_tools, authorized_imports
+ )
for counter in iterator:
- set_value(for_loop.target, counter, state, static_tools, custom_tools)
+ set_value(
+ for_loop.target,
+ counter,
+ state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
+ )
for node in for_loop.body:
try:
- line_result = evaluate_ast(node, state, static_tools, custom_tools)
+ line_result = evaluate_ast(
+ node, state, static_tools, custom_tools, authorized_imports
+ )
if line_result is not None:
result = line_result
except BreakException:
@@ -676,15 +848,33 @@ def evaluate_for(for_loop, state, static_tools, custom_tools):
return result
-def evaluate_listcomp(listcomp, state, static_tools, custom_tools):
- def inner_evaluate(generators, index, current_state):
+def evaluate_listcomp(
+ listcomp: ast.ListComp,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> List[Any]:
+ def inner_evaluate(
+ generators: List[ast.comprehension], index: int, current_state: Dict[str, Any]
+ ) -> List[Any]:
if index >= len(generators):
return [
- evaluate_ast(listcomp.elt, current_state, static_tools, custom_tools)
+ evaluate_ast(
+ listcomp.elt,
+ current_state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
+ )
]
generator = generators[index]
iter_value = evaluate_ast(
- generator.iter, current_state, static_tools, custom_tools
+ generator.iter,
+ current_state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
)
result = []
for value in iter_value:
@@ -695,7 +885,9 @@ def inner_evaluate(generators, index, current_state):
else:
new_state[generator.target.id] = value
if all(
- evaluate_ast(if_clause, new_state, static_tools, custom_tools)
+ evaluate_ast(
+ if_clause, new_state, static_tools, custom_tools, authorized_imports
+ )
for if_clause in generator.ifs
):
result.extend(inner_evaluate(generators, index + 1, new_state))
@@ -704,41 +896,66 @@ def inner_evaluate(generators, index, current_state):
return inner_evaluate(listcomp.generators, 0, state)
-def evaluate_try(try_node, state, static_tools, custom_tools):
+def evaluate_try(
+ try_node: ast.Try,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> None:
try:
for stmt in try_node.body:
- evaluate_ast(stmt, state, static_tools, custom_tools)
+ evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
except Exception as e:
matched = False
for handler in try_node.handlers:
if handler.type is None or isinstance(
- e, evaluate_ast(handler.type, state, static_tools, custom_tools)
+ e,
+ evaluate_ast(
+ handler.type, state, static_tools, custom_tools, authorized_imports
+ ),
):
matched = True
if handler.name:
state[handler.name] = e
for stmt in handler.body:
- evaluate_ast(stmt, state, static_tools, custom_tools)
+ evaluate_ast(
+ stmt, state, static_tools, custom_tools, authorized_imports
+ )
break
if not matched:
raise e
else:
if try_node.orelse:
for stmt in try_node.orelse:
- evaluate_ast(stmt, state, static_tools, custom_tools)
+ evaluate_ast(
+ stmt, state, static_tools, custom_tools, authorized_imports
+ )
finally:
if try_node.finalbody:
for stmt in try_node.finalbody:
- evaluate_ast(stmt, state, static_tools, custom_tools)
+ evaluate_ast(
+ stmt, state, static_tools, custom_tools, authorized_imports
+ )
-def evaluate_raise(raise_node, state, static_tools, custom_tools):
+def evaluate_raise(
+ raise_node: ast.Raise,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> None:
if raise_node.exc is not None:
- exc = evaluate_ast(raise_node.exc, state, static_tools, custom_tools)
+ exc = evaluate_ast(
+ raise_node.exc, state, static_tools, custom_tools, authorized_imports
+ )
else:
exc = None
if raise_node.cause is not None:
- cause = evaluate_ast(raise_node.cause, state, static_tools, custom_tools)
+ cause = evaluate_ast(
+ raise_node.cause, state, static_tools, custom_tools, authorized_imports
+ )
else:
cause = None
if exc is not None:
@@ -750,11 +967,21 @@ def evaluate_raise(raise_node, state, static_tools, custom_tools):
raise InterpreterError("Re-raise is not supported without an active exception")
-def evaluate_assert(assert_node, state, static_tools, custom_tools):
- test_result = evaluate_ast(assert_node.test, state, static_tools, custom_tools)
+def evaluate_assert(
+ assert_node: ast.Assert,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> None:
+ test_result = evaluate_ast(
+ assert_node.test, state, static_tools, custom_tools, authorized_imports
+ )
if not test_result:
if assert_node.msg:
- msg = evaluate_ast(assert_node.msg, state, static_tools, custom_tools)
+ msg = evaluate_ast(
+ assert_node.msg, state, static_tools, custom_tools, authorized_imports
+ )
raise AssertionError(msg)
else:
# Include the failing condition in the assertion message
@@ -762,11 +989,17 @@ def evaluate_assert(assert_node, state, static_tools, custom_tools):
raise AssertionError(f"Assertion failed: {test_code}")
-def evaluate_with(with_node, state, static_tools, custom_tools):
+def evaluate_with(
+ with_node: ast.With,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> None:
contexts = []
for item in with_node.items:
context_expr = evaluate_ast(
- item.context_expr, state, static_tools, custom_tools
+ item.context_expr, state, static_tools, custom_tools, authorized_imports
)
if item.optional_vars:
state[item.optional_vars.id] = context_expr.__enter__()
@@ -777,7 +1010,7 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
try:
for stmt in with_node.body:
- evaluate_ast(stmt, state, static_tools, custom_tools)
+ evaluate_ast(stmt, state, static_tools, custom_tools, authorized_imports)
except Exception as e:
for context in reversed(contexts):
context.__exit__(type(e), e, e.__traceback__)
@@ -789,15 +1022,14 @@ def evaluate_with(with_node, state, static_tools, custom_tools):
def import_modules(expression, state, authorized_imports):
def check_module_authorized(module_name):
- if '*' in authorized_imports:
- return True
- else:
- module_path = module_name.split(".")
- module_subpaths = [
- ".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
- ]
- return any(subpath in authorized_imports for subpath in module_subpaths)
-
+ if "*" in authorized_imports:
+ return True
+ else:
+ module_path = module_name.split(".")
+ module_subpaths = [
+ ".".join(module_path[:i]) for i in range(1, len(module_path) + 1)
+ ]
+ return any(subpath in authorized_imports for subpath in module_subpaths)
if isinstance(expression, ast.Import):
for alias in expression.names:
@@ -821,20 +1053,47 @@ def check_module_authorized(module_name):
return None
-def evaluate_dictcomp(dictcomp, state, static_tools, custom_tools):
+def evaluate_dictcomp(
+ dictcomp: ast.DictComp,
+ state: Dict[str, Any],
+ static_tools: Dict[str, Callable],
+ custom_tools: Dict[str, Callable],
+ authorized_imports: List[str],
+) -> Dict[Any, Any]:
result = {}
for gen in dictcomp.generators:
- iter_value = evaluate_ast(gen.iter, state, static_tools, custom_tools)
+ iter_value = evaluate_ast(
+ gen.iter, state, static_tools, custom_tools, authorized_imports
+ )
for value in iter_value:
new_state = state.copy()
- set_value(gen.target, value, new_state, static_tools, custom_tools)
+ set_value(
+ gen.target,
+ value,
+ new_state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
+ )
if all(
- evaluate_ast(if_clause, new_state, static_tools, custom_tools)
+ evaluate_ast(
+ if_clause, new_state, static_tools, custom_tools, authorized_imports
+ )
for if_clause in gen.ifs
):
- key = evaluate_ast(dictcomp.key, new_state, static_tools, custom_tools)
+ key = evaluate_ast(
+ dictcomp.key,
+ new_state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
+ )
val = evaluate_ast(
- dictcomp.value, new_state, static_tools, custom_tools
+ dictcomp.value,
+ new_state,
+ static_tools,
+ custom_tools,
+ authorized_imports,
)
result[key] = val
return result
@@ -865,7 +1124,7 @@ def evaluate_ast(
Functions that may be called during the evaluation. These static_tools can be overwritten.
authorized_imports (`List[str]`):
The list of modules that can be imported by the code. By default, only a few safe modules are allowed.
- Add more at your own risk!
+ If it contains "*", it will authorize any import. Use this at your own risk!
"""
global OPERATIONS_COUNT
if OPERATIONS_COUNT >= MAX_OPERATIONS:
@@ -876,131 +1135,202 @@ def evaluate_ast(
if isinstance(expression, ast.Assign):
# Assignment -> we evaluate the assignment which should update the state
# We return the variable assigned as it may be used to determine the final result.
- return evaluate_assign(expression, state, static_tools, custom_tools)
+ return evaluate_assign(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.AugAssign):
- return evaluate_augassign(expression, state, static_tools, custom_tools)
+ return evaluate_augassign(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Call):
# Function call -> we return the value of the function call
- return evaluate_call(expression, state, static_tools, custom_tools)
+ return evaluate_call(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Constant):
# Constant -> just return the value
return expression.value
elif isinstance(expression, ast.Tuple):
return tuple(
- evaluate_ast(elt, state, static_tools, custom_tools)
+ evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
)
elif isinstance(expression, (ast.ListComp, ast.GeneratorExp)):
- return evaluate_listcomp(expression, state, static_tools, custom_tools)
+ return evaluate_listcomp(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.UnaryOp):
- return evaluate_unaryop(expression, state, static_tools, custom_tools)
+ return evaluate_unaryop(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Starred):
- return evaluate_ast(expression.value, state, static_tools, custom_tools)
+ return evaluate_ast(
+ expression.value, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.BoolOp):
# Boolean operation -> evaluate the operation
- return evaluate_boolop(expression, state, static_tools, custom_tools)
+ return evaluate_boolop(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Break):
raise BreakException()
elif isinstance(expression, ast.Continue):
raise ContinueException()
elif isinstance(expression, ast.BinOp):
# Binary operation -> execute operation
- return evaluate_binop(expression, state, static_tools, custom_tools)
+ return evaluate_binop(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Compare):
# Comparison -> evaluate the comparison
- return evaluate_condition(expression, state, static_tools, custom_tools)
+ return evaluate_condition(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Lambda):
- return evaluate_lambda(expression, state, static_tools, custom_tools)
+ return evaluate_lambda(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.FunctionDef):
- return evaluate_function_def(expression, state, static_tools, custom_tools)
+ return evaluate_function_def(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Dict):
# Dict -> evaluate all keys and values
keys = [
- evaluate_ast(k, state, static_tools, custom_tools) for k in expression.keys
+ evaluate_ast(k, state, static_tools, custom_tools, authorized_imports)
+ for k in expression.keys
]
values = [
- evaluate_ast(v, state, static_tools, custom_tools)
+ evaluate_ast(v, state, static_tools, custom_tools, authorized_imports)
for v in expression.values
]
return dict(zip(keys, values))
elif isinstance(expression, ast.Expr):
# Expression -> evaluate the content
- return evaluate_ast(expression.value, state, static_tools, custom_tools)
+ return evaluate_ast(
+ expression.value, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.For):
# For loop -> execute the loop
- return evaluate_for(expression, state, static_tools, custom_tools)
+ return evaluate_for(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.FormattedValue):
# Formatted value (part of f-string) -> evaluate the content and return
- return evaluate_ast(expression.value, state, static_tools, custom_tools)
+ return evaluate_ast(
+ expression.value, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.If):
# If -> execute the right branch
- return evaluate_if(expression, state, static_tools, custom_tools)
+ return evaluate_if(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif hasattr(ast, "Index") and isinstance(expression, ast.Index):
- return evaluate_ast(expression.value, state, static_tools, custom_tools)
+ return evaluate_ast(
+ expression.value, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.JoinedStr):
return "".join(
[
- str(evaluate_ast(v, state, static_tools, custom_tools))
+ str(
+ evaluate_ast(
+ v, state, static_tools, custom_tools, authorized_imports
+ )
+ )
for v in expression.values
]
)
elif isinstance(expression, ast.List):
# List -> evaluate all elements
return [
- evaluate_ast(elt, state, static_tools, custom_tools)
+ evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
]
elif isinstance(expression, ast.Name):
# Name -> pick up the value in the state
- return evaluate_name(expression, state, static_tools, custom_tools)
+ return evaluate_name(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Subscript):
# Subscript -> return the value of the indexing
- return evaluate_subscript(expression, state, static_tools, custom_tools)
+ return evaluate_subscript(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.IfExp):
- test_val = evaluate_ast(expression.test, state, static_tools, custom_tools)
+ test_val = evaluate_ast(
+ expression.test, state, static_tools, custom_tools, authorized_imports
+ )
if test_val:
- return evaluate_ast(expression.body, state, static_tools, custom_tools)
+ return evaluate_ast(
+ expression.body, state, static_tools, custom_tools, authorized_imports
+ )
else:
- return evaluate_ast(expression.orelse, state, static_tools, custom_tools)
+ return evaluate_ast(
+ expression.orelse, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Attribute):
- value = evaluate_ast(expression.value, state, static_tools, custom_tools)
+ value = evaluate_ast(
+ expression.value, state, static_tools, custom_tools, authorized_imports
+ )
return getattr(value, expression.attr)
elif isinstance(expression, ast.Slice):
return slice(
- evaluate_ast(expression.lower, state, static_tools, custom_tools)
+ evaluate_ast(
+ expression.lower, state, static_tools, custom_tools, authorized_imports
+ )
if expression.lower is not None
else None,
- evaluate_ast(expression.upper, state, static_tools, custom_tools)
+ evaluate_ast(
+ expression.upper, state, static_tools, custom_tools, authorized_imports
+ )
if expression.upper is not None
else None,
- evaluate_ast(expression.step, state, static_tools, custom_tools)
+ evaluate_ast(
+ expression.step, state, static_tools, custom_tools, authorized_imports
+ )
if expression.step is not None
else None,
)
elif isinstance(expression, ast.DictComp):
- return evaluate_dictcomp(expression, state, static_tools, custom_tools)
+ return evaluate_dictcomp(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.While):
- return evaluate_while(expression, state, static_tools, custom_tools)
+ return evaluate_while(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, (ast.Import, ast.ImportFrom)):
return import_modules(expression, state, authorized_imports)
elif isinstance(expression, ast.ClassDef):
- return evaluate_class_def(expression, state, static_tools, custom_tools)
+ return evaluate_class_def(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Try):
- return evaluate_try(expression, state, static_tools, custom_tools)
+ return evaluate_try(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Raise):
- return evaluate_raise(expression, state, static_tools, custom_tools)
+ return evaluate_raise(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Assert):
- return evaluate_assert(expression, state, static_tools, custom_tools)
+ return evaluate_assert(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.With):
- return evaluate_with(expression, state, static_tools, custom_tools)
+ return evaluate_with(
+ expression, state, static_tools, custom_tools, authorized_imports
+ )
elif isinstance(expression, ast.Set):
return {
- evaluate_ast(elt, state, static_tools, custom_tools)
+ evaluate_ast(elt, state, static_tools, custom_tools, authorized_imports)
for elt in expression.elts
}
elif isinstance(expression, ast.Return):
raise ReturnException(
- evaluate_ast(expression.value, state, static_tools, custom_tools)
+ evaluate_ast(
+ expression.value, state, static_tools, custom_tools, authorized_imports
+ )
if expression.value
else None
)
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index cc9aedc42..a8901e0ef 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -361,7 +361,7 @@ def __call__(
)
prompt_tensor = prompt_tensor.to(self.model.device)
count_prompt_tokens = prompt_tensor["input_ids"].shape[1]
-
+
out = self.model.generate(
**prompt_tensor,
max_new_tokens=max_tokens,
diff --git a/tests/test_agents.py b/tests/test_agents.py
index f51ce9fe9..0a90d2b84 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -313,9 +313,11 @@ def test_fake_code_agent(self):
assert isinstance(output, float)
assert output == 7.2904
assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
- assert agent.logs[3].tool_call == ToolCall(
- name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
- )
+ assert agent.logs[3].tool_calls == [
+ ToolCall(
+ name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
+ )
+ ]
def test_additional_args_added_to_task(self):
agent = CodeAgent(tools=[], model=fake_code_model)
From ad180410789af353b0a13e601b8598813fde2ffa Mon Sep 17 00:00:00 2001
From: tanhuajie <68807603+tanhuajie@users.noreply.github.com>
Date: Tue, 14 Jan 2025 00:24:18 +0800
Subject: [PATCH 2/9] Fix tool_calls parsing error in ToolCallingAgent (#160)
---
src/smolagents/agents.py | 15 ++++++++++++---
1 file changed, 12 insertions(+), 3 deletions(-)
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index 832ac8efc..5741ce9d7 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
+import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -806,9 +807,17 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)
- tool_calls = model_message.tool_calls[0]
- tool_arguments = tool_calls.function.arguments
- tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
+
+ # Extract tool call from model output
+ if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
+ tool_calls = model_message.tool_calls[0]
+ tool_arguments = tool_calls.function.arguments
+ tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
+ else:
+ start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
+ tool_calls = json.loads(model_message.content[start:end])
+ tool_arguments = tool_calls["tool_arguments"]
+ tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
except Exception as e:
raise AgentGenerationError(
From 1f96560c925a686eb901bff342526ca933c2c462 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral
<8515462+albertvillanova@users.noreply.github.com>
Date: Mon, 13 Jan 2025 17:26:32 +0100
Subject: [PATCH 3/9] Fix minor issues in building_good_agents docs (#170)
* Fix doc inter-link to intro_agents in building_good_agents, make text italic, minor typos
---
docs/source/en/tutorials/building_good_agents.md | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/docs/source/en/tutorials/building_good_agents.md b/docs/source/en/tutorials/building_good_agents.md
index 84f77eaf8..f2d37a20e 100644
--- a/docs/source/en/tutorials/building_good_agents.md
+++ b/docs/source/en/tutorials/building_good_agents.md
@@ -30,7 +30,7 @@ Giving an LLM some agency in your workflow introduces some risk of errors.
Well-programmed agentic systems have good error logging and retry mechanisms anyway, so the LLM engine has a chance to self-correct their mistake. But to reduce the risk of LLM error to the maximum, you should simplify your workflow!
-Let's revisit the example from [intro_agents]: a bot that answers user queries for a surf trip company.
+Let's revisit the example from the [intro to agents](../conceptual_guides/intro_agents): a bot that answers user queries for a surf trip company.
Instead of letting the agent do 2 different calls for "travel distance API" and "weather API" each time they are asked about a new surf spot, you could just make one unified tool "return_spot_information", a function that calls both APIs at once and returns their concatenated outputs to the user.
This will reduce costs, latency, and error risk!
@@ -43,7 +43,7 @@ This leads to a few takeaways:
### Improve the information flow to the LLM engine
-Remember that your LLM engine is like a ~intelligent~ robot, tapped into a room with the only communication with the outside world being notes passed under a door.
+Remember that your LLM engine is like an *intelligent* robot, trapped in a room whose only communication with the outside world is notes passed under a door.
It won't know of anything that happened if you don't explicitly put that into its prompt.
@@ -88,7 +88,7 @@ def get_weather_api(location: str, date_time: str) -> str:
Why is it bad?
- there's no precision of the format that should be used for `date_time`
- there's no detail on how location should be specified.
-- there's no logging mechanism tying to explicit failure cases like location not being in a proper format, or date_time not being properly formatted.
+- there's no logging mechanism that makes failure cases explicit, such as location not being in a proper format, or date_time not being properly formatted.
- the output format is hard to understand
If the tool call fails, the error trace logged in memory can help the LLM reverse engineer the tool to fix the errors. But why leave it with so much heavy lifting to do?
From 1d846072eb5a33e429dc0c89dc68976b6114f5e2 Mon Sep 17 00:00:00 2001
From: Aymeric
Date: Mon, 13 Jan 2025 19:46:36 +0100
Subject: [PATCH 4/9] Improve GradioUI file upload system
---
examples/gradio_upload.py | 2 +-
src/smolagents/agents.py | 27 ++++++++++++------
src/smolagents/gradio_ui.py | 49 ++++++++++++++++++++++----------
src/smolagents/models.py | 28 ++++++------------
tests/test_agents.py | 3 +-
tests/test_python_interpreter.py | 1 +
6 files changed, 66 insertions(+), 44 deletions(-)
diff --git a/examples/gradio_upload.py b/examples/gradio_upload.py
index 1e5b4641d..4b8425d83 100644
--- a/examples/gradio_upload.py
+++ b/examples/gradio_upload.py
@@ -5,7 +5,7 @@
)
agent = CodeAgent(
- tools=[], model=HfApiModel(), max_steps=4, verbose=True
+ tools=[], model=HfApiModel(), max_steps=4, verbosity_level=0
)
GradioUI(agent, file_upload_folder='./data').launch()
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index 5741ce9d7..1c81727b8 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -396,7 +396,7 @@ def provide_final_answer(self, task) -> str:
}
]
try:
- return self.model(self.input_messages)
+ return self.model(self.input_messages).content
except Exception as e:
return f"Error in generating final LLM output:\n{e}"
@@ -666,7 +666,9 @@ def planning_step(self, task, is_first_step: bool, step: int):
Now begin!""",
}
- answer_facts = self.model([message_prompt_facts, message_prompt_task])
+ answer_facts = self.model(
+ [message_prompt_facts, message_prompt_task]
+ ).content
message_system_prompt_plan = {
"role": MessageRole.SYSTEM,
@@ -688,7 +690,7 @@ def planning_step(self, task, is_first_step: bool, step: int):
answer_plan = self.model(
[message_system_prompt_plan, message_user_prompt_plan],
stop_sequences=[""],
- )
+ ).content
final_plan_redaction = f"""Here is the plan of action that I will follow to solve the task:
```
@@ -722,7 +724,7 @@ def planning_step(self, task, is_first_step: bool, step: int):
}
facts_update = self.model(
[facts_update_system_prompt] + agent_memory + [facts_update_message]
- )
+ ).content
# Redact updated plan
plan_update_message = {
@@ -807,17 +809,26 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)
-
+
# Extract tool call from model output
- if type(model_message.tool_calls) is list and len(model_message.tool_calls) > 0:
+ if (
+ type(model_message.tool_calls) is list
+ and len(model_message.tool_calls) > 0
+ ):
tool_calls = model_message.tool_calls[0]
tool_arguments = tool_calls.function.arguments
tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
else:
- start, end = model_message.content.find('{'), model_message.content.rfind('}') + 1
+ start, end = (
+ model_message.content.find("{"),
+ model_message.content.rfind("}") + 1,
+ )
tool_calls = json.loads(model_message.content[start:end])
tool_arguments = tool_calls["tool_arguments"]
- tool_name, tool_call_id = tool_calls["tool_name"], f"call_{len(self.logs)}"
+ tool_name, tool_call_id = (
+ tool_calls["tool_name"],
+ f"call_{len(self.logs)}",
+ )
except Exception as e:
raise AgentGenerationError(
diff --git a/src/smolagents/gradio_ui.py b/src/smolagents/gradio_ui.py
index 45ae8a250..42a4183dc 100644
--- a/src/smolagents/gradio_ui.py
+++ b/src/smolagents/gradio_ui.py
@@ -27,14 +27,15 @@ def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
"""Extract ChatMessage objects from agent steps"""
if isinstance(step_log, ActionStep):
yield gr.ChatMessage(role="assistant", content=step_log.llm_output or "")
- if step_log.tool_call is not None:
- used_code = step_log.tool_call.name == "code interpreter"
- content = step_log.tool_call.arguments
+ if step_log.tool_calls is not None:
+ first_tool_call = step_log.tool_calls[0]
+ used_code = first_tool_call.name == "code interpreter"
+ content = first_tool_call.arguments
if used_code:
content = f"```py\n{content}\n```"
yield gr.ChatMessage(
role="assistant",
- metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
+ metadata={"title": f"🛠️ Used tool {first_tool_call.name}"},
content=str(content),
)
if step_log.observations is not None:
@@ -103,6 +104,7 @@ def interact_with_agent(self, prompt, messages):
def upload_file(
self,
file,
+ file_uploads_log,
allowed_file_types=[
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -110,14 +112,12 @@ def upload_file(
],
):
"""
- Handle file uploads, default allowed types are pdf, docx, and .txt
+ Handle file uploads, default allowed types are .pdf, .docx, and .txt
"""
- # Check if file is uploaded
if file is None:
return "No file uploaded"
- # Check if file is in allowed filetypes
try:
mime_type, _ = mimetypes.guess_type(file.name)
except Exception as e:
@@ -148,11 +148,23 @@ def upload_file(
)
shutil.copy(file.name, file_path)
- return f"File uploaded successfully to {self.file_upload_folder}"
+ return gr.Textbox(
+ f"File uploaded: {file_path}", visible=True
+ ), file_uploads_log + [file_path]
+
+    def log_user_message(self, text_input, file_uploads_log):
+        """
+        Build the message stored for the agent from the chat textbox content.
+
+        Args:
+            text_input: Text typed by the user in the chat textbox.
+            file_uploads_log: List of the file paths uploaded so far.
+
+        Returns:
+            A `(message, "")` tuple: the user's text, followed by a note listing the
+            uploaded files when there are any, and an empty string that resets the textbox.
+        """
+        # The condition must apply to the note only: `a + b if cond else ""` parses as
+        # `(a + b) if cond else ""`, which dropped the user's text when no file was uploaded.
+        files_note = (
+            f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
+            if len(file_uploads_log) > 0
+            else ""
+        )
+        return text_input + files_note, ""
def launch(self):
with gr.Blocks() as demo:
- stored_message = gr.State([])
+ stored_messages = gr.State([])
+ file_uploads_log = gr.State([])
chatbot = gr.Chatbot(
label="Agent",
type="messages",
@@ -163,14 +175,21 @@ def launch(self):
)
# If an upload folder is provided, enable the upload feature
if self.file_upload_folder is not None:
- upload_file = gr.File(label="Upload a file")
- upload_status = gr.Textbox(label="Upload Status", interactive=False)
-
- upload_file.change(self.upload_file, [upload_file], [upload_status])
+ upload_file = gr.File(label="Upload a file", height=1)
+ upload_status = gr.Textbox(
+ label="Upload Status", interactive=False, visible=False
+ )
+ upload_file.change(
+ self.upload_file,
+ [upload_file, file_uploads_log],
+ [upload_status, file_uploads_log],
+ )
text_input = gr.Textbox(lines=1, label="Chat Message")
text_input.submit(
- lambda s: (s, ""), [text_input], [stored_message, text_input]
- ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])
+ self.log_user_message,
+ [text_input, file_uploads_log],
+ [stored_messages, text_input],
+ ).then(self.interact_with_agent, [stored_messages, chatbot], [chatbot])
demo.launch()
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index a8901e0ef..a57550ad6 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -36,6 +36,8 @@
StoppingCriteriaList,
is_torch_available,
)
+from transformers.utils.import_utils import _is_package_available
+
import openai
from .tools import Tool
@@ -52,13 +54,9 @@
"value": "Thought: .+?\\nCode:\\n```(?:py|python)?\\n(?:.|\\s)+?\\n```",
}
-try:
+if _is_package_available("litellm"):
import litellm
- is_litellm_available = True
-except ImportError:
- is_litellm_available = False
-
class MessageRole(str, Enum):
USER = "user"
@@ -159,7 +157,7 @@ def __call__(
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
- ) -> str:
+ ) -> ChatCompletionOutputMessage:
"""Process the input messages and return the model's response.
Parameters:
@@ -174,15 +172,7 @@ def __call__(
Returns:
`str`: The text content of the model's response.
"""
- if not isinstance(messages, List):
- raise ValueError(
- "Messages should be a list of dictionaries with 'role' and 'content' keys."
- )
- if stop_sequences is None:
- stop_sequences = []
- response = self.generate(messages, stop_sequences, grammar, max_tokens)
-
- return remove_stop_sequences(response, stop_sequences)
+ pass # To be implemented in child classes!
class HfApiModel(Model):
@@ -238,7 +228,7 @@ def __call__(
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> str:
+ ) -> ChatCompletionOutputMessage:
"""
Gets an LLM output message for the given list of input messages.
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
@@ -407,7 +397,7 @@ def __init__(
api_key=None,
**kwargs,
):
- if not is_litellm_available:
+ if not _is_package_available("litellm"):
raise ImportError(
"litellm not found. Install it with `pip install litellm`"
)
@@ -426,7 +416,7 @@ def __call__(
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> str:
+ ) -> ChatCompletionOutputMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
@@ -497,7 +487,7 @@ def __call__(
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> str:
+ ) -> ChatCompletionOutputMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 0a90d2b84..38538cebc 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -367,9 +367,10 @@ def test_fails_max_steps(self):
model=fake_code_model_no_return, # use this callable because it never ends
max_steps=5,
)
- agent.run("What is 2 multiplied by 3.6452?")
+ answer = agent.run("What is 2 multiplied by 3.6452?")
assert len(agent.logs) == 8
assert type(agent.logs[-1].error) is AgentMaxStepsError
+ assert isinstance(answer, str)
def test_tool_descriptions_get_baked_in_system_prompt(self):
tool = PythonInterpreterTool()
diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py
index 59440665b..8c7aacc9c 100644
--- a/tests/test_python_interpreter.py
+++ b/tests/test_python_interpreter.py
@@ -486,6 +486,7 @@ def test_additional_imports(self):
code = "import numpy.random as rd"
evaluate_python_code(code, authorized_imports=["numpy.random"], state={})
evaluate_python_code(code, authorized_imports=["numpy"], state={})
+ evaluate_python_code(code, authorized_imports=["*"], state={})
with pytest.raises(InterpreterError):
evaluate_python_code(code, authorized_imports=["random"], state={})
From c04e8de8250467dec047690b1db9b7dd01b2650b Mon Sep 17 00:00:00 2001
From: Ilya Gusev
Date: Tue, 14 Jan 2025 09:58:45 +0100
Subject: [PATCH 5/9] Bugfix: Fix plan_update message display (#179)
---
src/smolagents/agents.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index 1c81727b8..b3d0c5a78 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -748,7 +748,7 @@ def planning_step(self, task, is_first_step: bool, step: int):
plan_update = self.model(
[plan_update_message] + agent_memory + [plan_update_message_user],
stop_sequences=[""],
- )
+ ).content
# Log final facts and plan
final_plan_redaction = PLAN_UPDATE_FINAL_PLAN_REDACTION.format(
From 12a2e6f4b4eadf94a57034b073672e8b115ac89c Mon Sep 17 00:00:00 2001
From: Deng Tongwei <74892366+6643789wsx@users.noreply.github.com>
Date: Tue, 14 Jan 2025 17:00:08 +0800
Subject: [PATCH 6/9] feat: Add multi-GPU support for TransformersModel (#139)
Add multi-GPU support for TransformersModel
---
src/smolagents/models.py | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index a57550ad6..70ef5d196 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -287,16 +287,14 @@ def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None)
logger.info(f"Using device: {self.device}")
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_id)
- self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device)
+ self.model = AutoModelForCausalLM.from_pretrained(model_id, device_map=self.device)
except Exception as e:
logger.warning(
f"Failed to load tokenizer and model for {model_id=}: {e}. Loading default tokenizer and model instead from {default_model_id=}."
)
self.model_id = default_model_id
self.tokenizer = AutoTokenizer.from_pretrained(default_model_id)
- self.model = AutoModelForCausalLM.from_pretrained(default_model_id).to(
- self.device
- )
+ # Load the default checkpoint so the model matches the default tokenizer loaded just above.
+ self.model = AutoModelForCausalLM.from_pretrained(default_model_id, device_map=self.device)
def make_stopping_criteria(self, stop_sequences: List[str]) -> StoppingCriteriaList:
class StopOnStrings(StoppingCriteria):
From 5f323735511f54168b688cdb0dee10ab5bdcd909 Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Tue, 14 Jan 2025 14:57:11 +0100
Subject: [PATCH 7/9] Make default tools more robust (#186)
---
.github/workflows/tests.yml | 3 +
examples/benchmark.ipynb | 301 +++++++++----------------------
src/smolagents/agents.py | 29 +--
src/smolagents/default_tools.py | 24 ++-
src/smolagents/models.py | 60 ++++--
src/smolagents/tools.py | 10 +
tests/test_agents.py | 78 ++++----
tests/test_default_tools.py | 83 +++++++++
tests/test_monitoring.py | 18 +-
tests/test_python_interpreter.py | 46 +----
10 files changed, 296 insertions(+), 356 deletions(-)
create mode 100644 tests/test_default_tools.py
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index a595bede2..c720ec0f5 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -36,6 +36,9 @@ jobs:
- name: Agent tests
run: |
uv run pytest -sv ./tests/test_agents.py
+ - name: Default tools tests
+ run: |
+ uv run pytest -sv ./tests/test_default_tools.py
- name: Final answer tests
run: |
uv run pytest -sv ./tests/test_final_answer.py
diff --git a/examples/benchmark.ipynb b/examples/benchmark.ipynb
index 1009f2807..8b49b0aa2 100644
--- a/examples/benchmark.ipynb
+++ b/examples/benchmark.ipynb
@@ -21,7 +21,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -29,8 +29,7 @@
"output_type": "stream",
"text": [
"/Users/aymeric/venv/test/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "Generating train split: 100%|██████████| 132/132 [00:00<00:00, 17393.36 examples/s]\n"
+ " from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
@@ -173,7 +172,7 @@
"[132 rows x 4 columns]"
]
},
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -196,19 +195,9 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 7,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/aymeric/venv/test/lib/python3.12/site-packages/pydantic/_internal/_config.py:345: UserWarning: Valid config keys have changed in V2:\n",
- "* 'fields' has been removed\n",
- " warnings.warn(message, UserWarning)\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import time\n",
"import json\n",
@@ -408,100 +397,9 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'meta-llama/Llama-3.3-70B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 27061.35it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 34618.15it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'Qwen/Qwen2.5-72B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 33008.29it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 36292.90it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'Qwen/Qwen2.5-Coder-32B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 29165.47it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 30378.50it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'meta-llama/Llama-3.2-3B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 33453.06it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 34763.79it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'meta-llama/Llama-3.1-8B-Instruct'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 35246.25it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 28551.81it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'mistralai/Mistral-Nemo-Instruct-2407'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 32441.59it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 35542.67it/s]\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"open_model_ids = [\n",
" \"meta-llama/Llama-3.3-70B-Instruct\",\n",
@@ -554,42 +452,9 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'gpt-4o'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 36136.55it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 33451.04it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 39146.44it/s]\n"
- ]
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Evaluating 'anthropic/claude-3-5-sonnet-latest'...\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 132/132 [00:00<00:00, 31512.79it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 33576.82it/s]\n",
- "100%|██████████| 132/132 [00:00<00:00, 36075.33it/s]\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"from smolagents import LiteLLMModel\n",
"\n",
@@ -614,7 +479,7 @@
" agent = CodeAgent(\n",
" tools=[GoogleSearchTool(), VisitWebpageTool()],\n",
" model=LiteLLMModel(model_id),\n",
- " additional_authorized_imports=[\"numpy\"],\n",
+ " additional_authorized_imports=[\"numpy\", \"sympy\"],\n",
" max_steps=10,\n",
" )\n",
" file_name = f\"output/{model_id.replace('/', '_')}-{action_type}-26-dec-2024.jsonl\"\n",
@@ -631,34 +496,39 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# import glob\n",
"# import json\n",
+ "\n",
"# jsonl_files = glob.glob(f\"output/*.jsonl\")\n",
"\n",
"# for file_path in jsonl_files:\n",
- "# print(file_path)\n",
- "# # Read all lines and filter out SimpleQA sources\n",
- "# filtered_lines = []\n",
- "# removed = 0\n",
- "# with open(file_path, 'r', encoding='utf-8') as f:\n",
- "# for line in f:\n",
- "# try:\n",
- "# data = json.loads(line.strip())\n",
- "# if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
- "# removed +=1\n",
- "# else:\n",
- "# filtered_lines.append(line)\n",
- "# except json.JSONDecodeError:\n",
- "# print(\"Invalid line:\", line)\n",
- "# continue # Skip invalid JSON lines\n",
- "# print(f\"Removed {removed} lines.\")\n",
- "# # Write filtered content back to the same file\n",
- "# with open(file_path, 'w', encoding='utf-8') as f:\n",
- "# f.writelines(filtered_lines)"
+ "# if \"-Nemo-\" in file_path and \"-vanilla-\" in file_path:\n",
+ "# print(file_path)\n",
+ "# # Read all lines and filter out SimpleQA sources\n",
+ "# filtered_lines = []\n",
+ "# removed = 0\n",
+ "# with open(file_path, \"r\", encoding=\"utf-8\") as f:\n",
+ "# for line in f:\n",
+ "# try:\n",
+ "# data = json.loads(line.strip())\n",
+ "# data[\"answer\"] = data[\"answer\"][\"content\"]\n",
+ "# # if not any([question in data[\"question\"] for question in eval_ds[\"question\"]]):\n",
+ "# # removed +=1\n",
+ "# # else:\n",
+ "# filtered_lines.append(json.dumps(data) + \"\\n\")\n",
+ "# except json.JSONDecodeError:\n",
+ "# print(\"Invalid line:\", line)\n",
+ "# continue # Skip invalid JSON lines\n",
+ "# print(f\"Removed {removed} lines.\")\n",
+ "# # Write filtered content back to the same file\n",
+ "# with open(\n",
+ "# str(file_path).replace(\"-vanilla-\", \"-vanilla2-\"), \"w\", encoding=\"utf-8\"\n",
+ "# ) as f:\n",
+ "# f.writelines(filtered_lines)"
]
},
{
@@ -670,14 +540,14 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_6037/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
+ "/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/ipykernel_27085/1901135017.py:164: UserWarning: Answer lists have different lengths, returning False.\n",
" warnings.warn(\n"
]
}
@@ -731,7 +601,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@@ -752,7 +622,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -794,28 +664,28 @@
" 1 | \n",
" Qwen/Qwen2.5-72B-Instruct | \n",
" MATH | \n",
- " 74.0 | \n",
+ " 76.0 | \n",
" 30.0 | \n",
" \n",
" \n",
" | 2 | \n",
" Qwen/Qwen2.5-72B-Instruct | \n",
" SimpleQA | \n",
- " 70.0 | \n",
+ " 88.0 | \n",
" 10.0 | \n",
"
\n",
" \n",
" | 3 | \n",
" Qwen/Qwen2.5-Coder-32B-Instruct | \n",
" GAIA | \n",
- " 18.8 | \n",
+ " 25.0 | \n",
" 3.1 | \n",
"
\n",
" \n",
" | 4 | \n",
" Qwen/Qwen2.5-Coder-32B-Instruct | \n",
" MATH | \n",
- " 76.0 | \n",
+ " 86.0 | \n",
" 60.0 | \n",
"
\n",
" \n",
@@ -829,63 +699,63 @@
" | 6 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
" GAIA | \n",
- " 40.6 | \n",
+ " NaN | \n",
" 3.1 | \n",
"
\n",
" \n",
" | 7 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
" MATH | \n",
- " 67.0 | \n",
+ " NaN | \n",
" 50.0 | \n",
"
\n",
" \n",
" | 8 | \n",
" anthropic/claude-3-5-sonnet-latest | \n",
" SimpleQA | \n",
- " 90.0 | \n",
+ " NaN | \n",
" 34.0 | \n",
"
\n",
" \n",
" | 9 | \n",
" gpt-4o | \n",
" GAIA | \n",
- " 28.1 | \n",
+ " 25.6 | \n",
" 3.1 | \n",
"
\n",
" \n",
" | 10 | \n",
" gpt-4o | \n",
" MATH | \n",
- " 70.0 | \n",
+ " 58.0 | \n",
" 40.0 | \n",
"
\n",
" \n",
" | 11 | \n",
" gpt-4o | \n",
" SimpleQA | \n",
- " 88.0 | \n",
+ " 86.0 | \n",
" 6.0 | \n",
"
\n",
" \n",
" | 12 | \n",
" meta-llama/Llama-3.1-8B-Instruct | \n",
" GAIA | \n",
- " 0.0 | \n",
+ " 3.1 | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 13 | \n",
" meta-llama/Llama-3.1-8B-Instruct | \n",
" MATH | \n",
- " 42.0 | \n",
+ " 14.0 | \n",
" 18.0 | \n",
"
\n",
" \n",
" | 14 | \n",
" meta-llama/Llama-3.1-8B-Instruct | \n",
" SimpleQA | \n",
- " 54.0 | \n",
+ " 2.0 | \n",
" 6.0 | \n",
"
\n",
" \n",
@@ -899,49 +769,49 @@
" | 16 | \n",
" meta-llama/Llama-3.2-3B-Instruct | \n",
" MATH | \n",
- " 32.0 | \n",
+ " 40.0 | \n",
" 12.0 | \n",
"
\n",
" \n",
" | 17 | \n",
" meta-llama/Llama-3.2-3B-Instruct | \n",
" SimpleQA | \n",
- " 4.0 | \n",
+ " 20.0 | \n",
" 0.0 | \n",
"
\n",
" \n",
" | 18 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
" GAIA | \n",
- " 34.4 | \n",
+ " 31.2 | \n",
" 3.1 | \n",
"
\n",
" \n",
" | 19 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
" MATH | \n",
- " 82.0 | \n",
+ " 72.0 | \n",
" 40.0 | \n",
"
\n",
" \n",
" | 20 | \n",
" meta-llama/Llama-3.3-70B-Instruct | \n",
" SimpleQA | \n",
- " 84.0 | \n",
+ " 78.0 | \n",
" 12.0 | \n",
"
\n",
" \n",
" | 21 | \n",
" mistralai/Mistral-Nemo-Instruct-2407 | \n",
" GAIA | \n",
- " 3.1 | \n",
" 0.0 | \n",
+ " 3.1 | \n",
"
\n",
" \n",
" | 22 | \n",
" mistralai/Mistral-Nemo-Instruct-2407 | \n",
" MATH | \n",
- " 20.0 | \n",
+ " 30.0 | \n",
" 22.0 | \n",
"
\n",
" \n",
@@ -949,7 +819,7 @@
" | mistralai/Mistral-Nemo-Instruct-2407 | \n",
" SimpleQA | \n",
" 30.0 | \n",
- " 0.0 | \n",
+ " 6.0 | \n",
"
\n",
" \n",
"\n",
@@ -958,29 +828,29 @@
"text/plain": [
"action_type model_id source code vanilla\n",
"0 Qwen/Qwen2.5-72B-Instruct GAIA 28.1 6.2\n",
- "1 Qwen/Qwen2.5-72B-Instruct MATH 74.0 30.0\n",
- "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 70.0 10.0\n",
- "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 18.8 3.1\n",
- "4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 76.0 60.0\n",
+ "1 Qwen/Qwen2.5-72B-Instruct MATH 76.0 30.0\n",
+ "2 Qwen/Qwen2.5-72B-Instruct SimpleQA 88.0 10.0\n",
+ "3 Qwen/Qwen2.5-Coder-32B-Instruct GAIA 25.0 3.1\n",
+ "4 Qwen/Qwen2.5-Coder-32B-Instruct MATH 86.0 60.0\n",
"5 Qwen/Qwen2.5-Coder-32B-Instruct SimpleQA 86.0 8.0\n",
- "6 anthropic/claude-3-5-sonnet-latest GAIA 40.6 3.1\n",
- "7 anthropic/claude-3-5-sonnet-latest MATH 67.0 50.0\n",
- "8 anthropic/claude-3-5-sonnet-latest SimpleQA 90.0 34.0\n",
- "9 gpt-4o GAIA 28.1 3.1\n",
- "10 gpt-4o MATH 70.0 40.0\n",
- "11 gpt-4o SimpleQA 88.0 6.0\n",
- "12 meta-llama/Llama-3.1-8B-Instruct GAIA 0.0 0.0\n",
- "13 meta-llama/Llama-3.1-8B-Instruct MATH 42.0 18.0\n",
- "14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 54.0 6.0\n",
+ "6 anthropic/claude-3-5-sonnet-latest GAIA NaN 3.1\n",
+ "7 anthropic/claude-3-5-sonnet-latest MATH NaN 50.0\n",
+ "8 anthropic/claude-3-5-sonnet-latest SimpleQA NaN 34.0\n",
+ "9 gpt-4o GAIA 25.6 3.1\n",
+ "10 gpt-4o MATH 58.0 40.0\n",
+ "11 gpt-4o SimpleQA 86.0 6.0\n",
+ "12 meta-llama/Llama-3.1-8B-Instruct GAIA 3.1 0.0\n",
+ "13 meta-llama/Llama-3.1-8B-Instruct MATH 14.0 18.0\n",
+ "14 meta-llama/Llama-3.1-8B-Instruct SimpleQA 2.0 6.0\n",
"15 meta-llama/Llama-3.2-3B-Instruct GAIA 3.1 0.0\n",
- "16 meta-llama/Llama-3.2-3B-Instruct MATH 32.0 12.0\n",
- "17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 4.0 0.0\n",
- "18 meta-llama/Llama-3.3-70B-Instruct GAIA 34.4 3.1\n",
- "19 meta-llama/Llama-3.3-70B-Instruct MATH 82.0 40.0\n",
- "20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 84.0 12.0\n",
- "21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 3.1 0.0\n",
- "22 mistralai/Mistral-Nemo-Instruct-2407 MATH 20.0 22.0\n",
- "23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 0.0"
+ "16 meta-llama/Llama-3.2-3B-Instruct MATH 40.0 12.0\n",
+ "17 meta-llama/Llama-3.2-3B-Instruct SimpleQA 20.0 0.0\n",
+ "18 meta-llama/Llama-3.3-70B-Instruct GAIA 31.2 3.1\n",
+ "19 meta-llama/Llama-3.3-70B-Instruct MATH 72.0 40.0\n",
+ "20 meta-llama/Llama-3.3-70B-Instruct SimpleQA 78.0 12.0\n",
+ "21 mistralai/Mistral-Nemo-Instruct-2407 GAIA 0.0 3.1\n",
+ "22 mistralai/Mistral-Nemo-Instruct-2407 MATH 30.0 22.0\n",
+ "23 mistralai/Mistral-Nemo-Instruct-2407 SimpleQA 30.0 6.0"
]
},
"metadata": {},
@@ -1005,6 +875,15 @@
},
"metadata": {},
"output_type": "display_data"
+ },
+ {
+ "ename": "",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
+ "\u001b[1;31mView Jupyter log for further details."
+ ]
}
],
"source": [
diff --git a/src/smolagents/agents.py b/src/smolagents/agents.py
index b3d0c5a78..cfa8a6ff4 100644
--- a/src/smolagents/agents.py
+++ b/src/smolagents/agents.py
@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
-import json
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -809,26 +808,9 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
tools_to_call_from=list(self.tools.values()),
stop_sequences=["Observation:"],
)
-
- # Extract tool call from model output
- if (
- type(model_message.tool_calls) is list
- and len(model_message.tool_calls) > 0
- ):
- tool_calls = model_message.tool_calls[0]
- tool_arguments = tool_calls.function.arguments
- tool_name, tool_call_id = tool_calls.function.name, tool_calls.id
- else:
- start, end = (
- model_message.content.find("{"),
- model_message.content.rfind("}") + 1,
- )
- tool_calls = json.loads(model_message.content[start:end])
- tool_arguments = tool_calls["tool_arguments"]
- tool_name, tool_call_id = (
- tool_calls["tool_name"],
- f"call_{len(self.logs)}",
- )
+ tool_call = model_message.tool_calls[0]
+ tool_name, tool_call_id = tool_call.function.name, tool_call.id
+ tool_arguments = tool_call.function.arguments
except Exception as e:
raise AgentGenerationError(
@@ -887,7 +869,10 @@ def step(self, log_entry: ActionStep) -> Union[None, Any]:
updated_information = f"Stored '{observation_name}' in memory."
else:
updated_information = str(observation).strip()
- self.logger.log(f"Observations: {updated_information}", level=LogLevel.INFO)
+ self.logger.log(
+ f"Observations: {updated_information.replace('[', '|')}", # escape potential rich-tag-like components
+ level=LogLevel.INFO,
+ )
log_entry.observations = updated_information
return None
diff --git a/src/smolagents/default_tools.py b/src/smolagents/default_tools.py
index 75fe8d017..59f6820f2 100644
--- a/src/smolagents/default_tools.py
+++ b/src/smolagents/default_tools.py
@@ -31,6 +31,7 @@
)
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
from .types import AgentAudio
+from .utils import truncate_content
if is_torch_available():
from transformers.models.whisper import (
@@ -112,18 +113,15 @@ def __init__(self, *args, authorized_imports=None, **kwargs):
def forward(self, code: str) -> str:
state = {}
- try:
- output = str(
- self.python_evaluator(
- code,
- state=state,
- static_tools=self.base_python_tools,
- authorized_imports=self.authorized_imports,
- )[0] # The second element is boolean is_final_answer
- )
- return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
- except Exception as e:
- return f"Error: {str(e)}"
+ output = str(
+ self.python_evaluator(
+ code,
+ state=state,
+ static_tools=self.base_python_tools,
+ authorized_imports=self.authorized_imports,
+ )[0] # The second element is boolean is_final_answer
+ )
+ return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
class FinalAnswerTool(Tool):
@@ -295,7 +293,7 @@ def forward(self, url: str) -> str:
# Remove multiple line breaks
markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)
- return markdown_content
+ return truncate_content(markdown_content)
except RequestException as e:
return f"Error fetching the webpage: {str(e)}"
diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index 70ef5d196..f25ced9c6 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -14,20 +14,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from dataclasses import dataclass
import json
import logging
import os
import random
from copy import deepcopy
from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union, Any
-from huggingface_hub import (
- InferenceClient,
- ChatCompletionOutputMessage,
- ChatCompletionOutputToolCall,
- ChatCompletionOutputFunctionDefinition,
-)
+from huggingface_hub import InferenceClient
from transformers import (
AutoModelForCausalLM,
@@ -58,6 +54,27 @@
import litellm
+@dataclass
+class ChatMessageToolCallDefinition:
+ """Function part of a tool call: the tool to run and the arguments to run it with."""
+
+ # Arguments for the tool. Typed `Any`: the tests in this patch pass both a dict
+ # ({"code": "2*3.6452"}) and a plain string ("image.png").
+ arguments: Any
+ # Name of the tool to call.
+ name: str
+ # Optional; defaults to None.
+ description: Optional[str] = None
+
+
+@dataclass
+class ChatMessageToolCall:
+ """A single tool call carried by a `ChatMessage`."""
+
+ # Which tool is called and with what arguments.
+ function: ChatMessageToolCallDefinition
+ # Identifier of the call (e.g. "call_0" in the tests, a random 5-digit string in this module).
+ id: str
+ # Kind of call; every construction site in this patch passes "function".
+ type: str
+
+
+@dataclass
+class ChatMessage:
+ """Message returned by the models' `__call__` methods."""
+
+ # Role of the author; the construction sites in this patch use "assistant".
+ role: str
+ # Text of the message, if any.
+ content: Optional[str] = None
+ # Tool calls requested by the model, if any.
+ tool_calls: Optional[List[ChatMessageToolCall]] = None
+
+
class MessageRole(str, Enum):
USER = "user"
ASSISTANT = "assistant"
@@ -140,6 +157,17 @@ def get_clean_message_list(
return final_message_list
+def parse_dictionary(possible_dictionary: str) -> Union[Dict, str]:
+    """Try to decode a JSON object embedded in a string.
+
+    The text spanning the first "{" to the last "}" is passed to `json.loads`.
+    If anything goes wrong (no braces, invalid JSON, input that is not a string),
+    the input is returned unchanged.
+    """
+    try:
+        first_brace = possible_dictionary.find("{")
+        last_brace = possible_dictionary.rfind("}")
+        candidate = possible_dictionary[first_brace : last_brace + 1]
+        return json.loads(candidate)
+    except Exception:
+        return possible_dictionary
+
+
class Model:
def __init__(self):
self.last_input_token_count = None
@@ -157,7 +185,7 @@ def __call__(
stop_sequences: Optional[List[str]] = None,
grammar: Optional[str] = None,
max_tokens: int = 1500,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
"""Process the input messages and return the model's response.
Parameters:
@@ -228,7 +256,7 @@ def __call__(
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
"""
Gets an LLM output message for the given list of input messages.
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
@@ -329,7 +357,7 @@ def __call__(
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
@@ -365,21 +393,21 @@ def __call__(
if stop_sequences is not None:
output = remove_stop_sequences(output, stop_sequences)
if tools_to_call_from is None:
- return ChatCompletionOutputMessage(role="assistant", content=output)
+ return ChatMessage(role="assistant", content=output)
else:
if "Action:" in output:
output = output.split("Action:", 1)[1].strip()
parsed_output = json.loads(output)
tool_name = parsed_output.get("tool_name")
tool_arguments = parsed_output.get("tool_arguments")
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="".join(random.choices("0123456789", k=5)),
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name=tool_name, arguments=tool_arguments
),
)
@@ -414,7 +442,7 @@ def __call__(
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
@@ -485,7 +513,7 @@ def __call__(
grammar: Optional[str] = None,
max_tokens: int = 1500,
tools_to_call_from: Optional[List[Tool]] = None,
- ) -> ChatCompletionOutputMessage:
+ ) -> ChatMessage:
messages = get_clean_message_list(
messages, role_conversions=tool_role_conversions
)
diff --git a/src/smolagents/tools.py b/src/smolagents/tools.py
index d5ec6b0fe..04a203d92 100644
--- a/src/smolagents/tools.py
+++ b/src/smolagents/tools.py
@@ -221,6 +221,16 @@ def forward(self, *args, **kwargs):
def __call__(self, *args, sanitize_inputs_outputs: bool = False, **kwargs):
if not self.is_initialized:
self.setup()
+
+ # Handle the case where the arguments are passed as a single dictionary
+ if len(args) == 1 and len(kwargs) == 0 and isinstance(args[0], dict):
+ potential_kwargs = args[0]
+
+ # If the dictionary keys match our input parameters, convert it to kwargs
+ if all(key in self.inputs for key in potential_kwargs):
+ args = ()
+ kwargs = potential_kwargs
+
if sanitize_inputs_outputs:
args, kwargs = handle_agent_input_types(*args, **kwargs)
outputs = self.forward(*args, **kwargs)
diff --git a/tests/test_agents.py b/tests/test_agents.py
index 38538cebc..1cd0a6750 100644
--- a/tests/test_agents.py
+++ b/tests/test_agents.py
@@ -30,10 +30,10 @@
from smolagents.default_tools import PythonInterpreterTool
from smolagents.tools import tool
from smolagents.types import AgentImage, AgentText
-from huggingface_hub import (
- ChatCompletionOutputMessage,
- ChatCompletionOutputToolCall,
- ChatCompletionOutputFunctionDefinition,
+from smolagents.models import (
+ ChatMessage,
+ ChatMessageToolCall,
+ ChatMessageToolCallDefinition,
)
@@ -47,28 +47,28 @@ def __call__(
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="python_interpreter", arguments={"code": "2*3.6452"}
),
)
],
)
else:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_1",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "7.2904"}
),
)
@@ -81,14 +81,14 @@ def __call__(
self, messages, tools_to_call_from=None, stop_sequences=None, grammar=None
):
if len(messages) < 3:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="fake_image_generation_tool",
arguments={"prompt": "An image of a cat"},
),
@@ -96,14 +96,14 @@ def __call__(
],
)
else:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_1",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer", arguments="image.png"
),
)
@@ -114,7 +114,7 @@ def __call__(
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -125,7 +125,7 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
""",
)
else: # We're at step 2
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -140,7 +140,7 @@ def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
def fake_code_model_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -154,7 +154,7 @@ def fake_code_model_error(messages, stop_sequences=None) -> str:
""",
)
else: # We're at step 2
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -169,7 +169,7 @@ def fake_code_model_error(messages, stop_sequences=None) -> str:
def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -183,7 +183,7 @@ def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
""",
)
else: # We're at step 2
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -196,7 +196,7 @@ def fake_code_model_syntax_error(messages, stop_sequences=None) -> str:
def fake_code_model_import(messages, stop_sequences=None) -> str:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can answer the question
@@ -212,7 +212,7 @@ def fake_code_model_import(messages, stop_sequences=None) -> str:
def fake_code_functiondef(messages, stop_sequences=None) -> str:
prompt = str(messages)
if "special_marker" not in prompt:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: Let's define the function. special_marker
@@ -226,7 +226,7 @@ def moving_average(x, w):
""",
)
else: # We're at step 2
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I can now answer the initial question
@@ -241,7 +241,7 @@ def moving_average(x, w):
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -255,7 +255,7 @@ def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) ->
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: I should multiply 2 by 3.6452. special_marker
@@ -454,14 +454,14 @@ def __call__(
):
if tools_to_call_from is not None:
if len(messages) < 3:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="search_agent",
arguments="Who is the current US president?",
),
@@ -470,14 +470,14 @@ def __call__(
)
else:
assert "Report on the current US president" in str(messages)
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer", arguments="Final report."
),
)
@@ -485,7 +485,7 @@ def __call__(
)
else:
if len(messages) < 3:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: Let's call our search agent.
@@ -497,7 +497,7 @@ def __call__(
)
else:
assert "Report on the current US president" in str(messages)
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Thought: Let's return the report.
@@ -518,14 +518,14 @@ def __call__(
stop_sequences=None,
grammar=None,
):
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="call_0",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer",
arguments="Report on the current US president",
),
@@ -568,7 +568,7 @@ def __call__(
def test_code_nontrivial_final_answer_works(self):
def fake_code_model_final_answer(messages, stop_sequences=None, grammar=None):
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""Code:
```py
diff --git a/tests/test_default_tools.py b/tests/test_default_tools.py
new file mode 100644
index 000000000..d966b84a9
--- /dev/null
+++ b/tests/test_default_tools.py
@@ -0,0 +1,83 @@
+# coding=utf-8
+# Copyright 2024 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+import pytest
+
+from smolagents.default_tools import PythonInterpreterTool, VisitWebpageTool
+from smolagents.types import AGENT_TYPE_MAPPING
+
+from .test_tools import ToolTesterMixin
+
+
+class DefaultToolTests(unittest.TestCase):
+ def test_visit_webpage(self):
+ arguments = {
+ "url": "https://en.wikipedia.org/wiki/United_States_Secretary_of_Homeland_Security"
+ }
+ result = VisitWebpageTool()(arguments)
+ assert isinstance(result, str)
+ assert (
+ "* [About Wikipedia](/wiki/Wikipedia:About)" in result
+ ) # Proper wikipedia pages have an About
+
+
+class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
+ def setUp(self):
+ self.tool = PythonInterpreterTool(authorized_imports=["numpy"])
+ self.tool.setup()
+
+ def test_exact_match_arg(self):
+ result = self.tool("(2 / 2) * 4")
+ self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
+
+ def test_exact_match_kwarg(self):
+ result = self.tool(code="(2 / 2) * 4")
+ self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
+
+ def test_agent_type_output(self):
+ inputs = ["2 * 2"]
+ output = self.tool(*inputs, sanitize_inputs_outputs=True)
+ output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, output_type))
+
+ def test_agent_types_inputs(self):
+ inputs = ["2 * 2"]
+ _inputs = []
+
+ for _input, expected_input in zip(inputs, self.tool.inputs.values()):
+ input_type = expected_input["type"]
+ if isinstance(input_type, list):
+ _inputs.append(
+ [
+ AGENT_TYPE_MAPPING[_input_type](_input)
+ for _input_type in input_type
+ ]
+ )
+ else:
+ _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
+
+ # Should not raise an error
+ output = self.tool(*inputs, sanitize_inputs_outputs=True)
+ output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
+ self.assertTrue(isinstance(output, output_type))
+
+ def test_imports_work(self):
+ result = self.tool("import numpy as np")
+ assert "import from numpy is not allowed" not in result.lower()
+
+ def test_unauthorized_imports_fail(self):
+ with pytest.raises(Exception) as e:
+ self.tool("import sympy as sp")
+ assert "sympy" in str(e).lower()
diff --git a/tests/test_monitoring.py b/tests/test_monitoring.py
index 11594e7ff..e55afb43d 100644
--- a/tests/test_monitoring.py
+++ b/tests/test_monitoring.py
@@ -23,9 +23,9 @@
stream_to_gradio,
)
-from huggingface_hub import (
- ChatCompletionOutputMessage,
- ChatCompletionOutputToolCall,
- ChatCompletionOutputFunctionDefinition,
+from smolagents.models import (
+ ChatMessage,
+ ChatMessageToolCall,
+ ChatMessageToolCallDefinition,
)
@@ -36,21 +36,21 @@ def __init__(self):
def __call__(self, prompt, tools_to_call_from=None, **kwargs):
if tools_to_call_from is not None:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="",
tool_calls=[
- ChatCompletionOutputToolCall(
+ ChatMessageToolCall(
id="fake_id",
type="function",
- function=ChatCompletionOutputFunctionDefinition(
+ function=ChatMessageToolCallDefinition(
name="final_answer", arguments={"answer": "image"}
),
)
],
)
else:
- return ChatCompletionOutputMessage(
+ return ChatMessage(
role="assistant",
content="""
Code:
@@ -91,9 +91,7 @@ def __init__(self):
self.last_output_token_count = 20
def __call__(self, prompt, **kwargs):
- return ChatCompletionOutputMessage(
- role="assistant", content="Malformed answer"
- )
+ return ChatMessage(role="assistant", content="Malformed answer")
agent = CodeAgent(
tools=[],
diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py
index 8c7aacc9c..75a146e00 100644
--- a/tests/test_python_interpreter.py
+++ b/tests/test_python_interpreter.py
@@ -18,15 +18,12 @@
import numpy as np
import pytest
-from smolagents.default_tools import BASE_PYTHON_TOOLS, PythonInterpreterTool
+from smolagents.default_tools import BASE_PYTHON_TOOLS
from smolagents.local_python_executor import (
InterpreterError,
evaluate_python_code,
fix_final_answer_code,
)
-from smolagents.types import AGENT_TYPE_MAPPING
-
-from .test_tools import ToolTesterMixin
# Fake function we will use as tool
@@ -34,47 +31,6 @@ def add_two(x):
return x + 2
-class PythonInterpreterToolTester(unittest.TestCase, ToolTesterMixin):
- def setUp(self):
- self.tool = PythonInterpreterTool(authorized_imports=["sqlite3"])
- self.tool.setup()
-
- def test_exact_match_arg(self):
- result = self.tool("(2 / 2) * 4")
- self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
-
- def test_exact_match_kwarg(self):
- result = self.tool(code="(2 / 2) * 4")
- self.assertEqual(result, "Stdout:\n\nOutput: 4.0")
-
- def test_agent_type_output(self):
- inputs = ["2 * 2"]
- output = self.tool(*inputs, sanitize_inputs_outputs=True)
- output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
- self.assertTrue(isinstance(output, output_type))
-
- def test_agent_types_inputs(self):
- inputs = ["2 * 2"]
- _inputs = []
-
- for _input, expected_input in zip(inputs, self.tool.inputs.values()):
- input_type = expected_input["type"]
- if isinstance(input_type, list):
- _inputs.append(
- [
- AGENT_TYPE_MAPPING[_input_type](_input)
- for _input_type in input_type
- ]
- )
- else:
- _inputs.append(AGENT_TYPE_MAPPING[input_type](_input))
-
- # Should not raise an error
- output = self.tool(*inputs, sanitize_inputs_outputs=True)
- output_type = AGENT_TYPE_MAPPING[self.tool.output_type]
- self.assertTrue(isinstance(output, output_type))
-
-
class PythonInterpreterTester(unittest.TestCase):
def test_evaluate_assign(self):
code = "x = 3"
From 77f656c80d9995a3b28d1ecd1a12372766822994 Mon Sep 17 00:00:00 2001
From: Aggelos Kyriakoulis
Date: Tue, 14 Jan 2025 18:21:38 +0200
Subject: [PATCH 8/9] Implemented support for ast.Pass in the interpreter.
(#189)
---
src/smolagents/local_python_executor.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py
index 3c545cae6..7f1c67e94 100644
--- a/src/smolagents/local_python_executor.py
+++ b/src/smolagents/local_python_executor.py
@@ -1334,6 +1334,8 @@ def evaluate_ast(
if expression.value
else None
)
+ elif isinstance(expression, ast.Pass):
+ return None
else:
# For now we refuse anything else. Let's add things as we need them.
raise InterpreterError(f"{expression.__class__.__name__} is not supported.")
From ce1cd6d9066d2150ec6b5d72200f49e6171901f2 Mon Sep 17 00:00:00 2001
From: Aymeric Roucher <69208727+aymeric-roucher@users.noreply.github.com>
Date: Tue, 14 Jan 2025 19:27:07 +0100
Subject: [PATCH 9/9] Support pandas' iloc indexer (#191)
---
src/smolagents/local_python_executor.py | 3 +++
tests/test_python_interpreter.py | 16 ++++++++++++++++
2 files changed, 19 insertions(+)
diff --git a/src/smolagents/local_python_executor.py b/src/smolagents/local_python_executor.py
index 7f1c67e94..e54e0594c 100644
--- a/src/smolagents/local_python_executor.py
+++ b/src/smolagents/local_python_executor.py
@@ -685,6 +685,9 @@ def evaluate_subscript(
if isinstance(value, pd.core.indexing._LocIndexer):
parent_object = value.obj
return parent_object.loc[index]
+ if isinstance(value, pd.core.indexing._iLocIndexer):
+ parent_object = value.obj
+ return parent_object.iloc[index]
if isinstance(value, (pd.DataFrame, pd.Series, np.ndarray)):
return value[index]
elif isinstance(value, pd.core.groupby.generic.DataFrameGroupBy):
diff --git a/tests/test_python_interpreter.py b/tests/test_python_interpreter.py
index 75a146e00..58f250cfc 100644
--- a/tests/test_python_interpreter.py
+++ b/tests/test_python_interpreter.py
@@ -808,6 +808,7 @@ def test_pandas(self):
)
assert np.array_equal(result.values[0], [104, 1])
+ # Test groupby
code = """import pandas as pd
data = pd.DataFrame.from_dict([
{"Pclass": 1, "Survived": 1},
@@ -821,6 +822,21 @@ def test_pandas(self):
)
assert result.values[1] == 0.5
+ # Test loc and iloc
+ code = """import pandas as pd
+data = pd.DataFrame.from_dict([
+ {"Pclass": 1, "Survived": 1},
+ {"Pclass": 2, "Survived": 0},
+ {"Pclass": 2, "Survived": 1}
+])
+survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
+survival_rate_biased = data.loc[data['Survived']==1]['Survived'].mean()
+survival_rate_sorted = data.sort_values(by='Survived', ascending=False).iloc[0]
+"""
+ result, _ = evaluate_python_code(
+ code, {}, state={}, authorized_imports=["pandas"]
+ )
+
def test_starred(self):
code = """
from math import radians, sin, cos, sqrt, atan2