Skip to content

Commit c3cf720

Browse files
committed
update requirements
1 parent 29f7af8 commit c3cf720

File tree

10 files changed

+94
-12
lines changed

10 files changed

+94
-12
lines changed

agents/agents/agents/templates/templates.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,8 @@ def _convert_single_message_to_hf_format(self, message: Dict) -> Dict:
523523
pass
524524

525525
def convert_to_hf_format_messages(self, messages: List[Dict]) -> List[Dict]:
526+
if messages is None:
527+
return None
526528
role_label, content_label = self._detect_labels(messages)
527529
hf_messages = []
528530
for message in messages:
@@ -628,6 +630,16 @@ def get_template(name: str) -> Template:
628630
)
629631
)
630632

633+
register_template(
634+
Template(
635+
name="deepseek-prover",
636+
system_template="{system_message}\n",
637+
system_message="You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.",
638+
user_template="### Instruction:\n{content}\n",
639+
assistant_template="### Response:\n{content}\n<|EOT|>\n",
640+
stop_words=["<|EOT|>"],
641+
)
642+
)
631643

632644
# register_conv_template(
633645
# Template(

agents/agents/envs/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from .alfworld_env import ALFWorldEnv
33
from .webshop_text_env import WebAgentTextEnv
44
from .scienceworld_env import ScienceWorldEnv
5+
from .manager.enroot import clear_enroot_containers
56

67
__all__ = [
78
"PythonSandboxEnv",
89
"ALFWorldEnv",
910
"WebAgentTextEnv",
1011
"ScienceWorldEnv",
12+
"clear_enroot_containers",
1113
]

agents/agents/envs/manager/enroot.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,10 @@ def api(self):
348348
def close(self):
349349
pass
350350

351-
352351
def from_env() -> EnrootClient:
353352
return EnrootClient()
353+
354+
355+
def clear_enroot_containers() -> None:
356+
_run_enroot(["remove", "--force", "$(enroot list)"], capture=False)
357+
print("Cleared all enroot containers")

agents/agents/envs/manager/env_manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ async def release(cls,
6060
cls._acquired_envs.pop(id)
6161
await cls._pools[key].release(env, finished=finished)
6262
else:
63-
warnings.warn(f"Environment {id} not found during release. Skipped it.")
63+
# This should be generally safe to skip
64+
# warnings.warn(f"Environment {id} not found during release. Skipped it.")
65+
pass
6466

6567
@classmethod
6668
async def reset(cls, env: BaseEnv, env_args: dict | None = None):

agents/agents/tools/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .src.search.async_dense_retriever import asyncdense_retrieve
1212
# from .src.search.http_retriever import http_retrieve
1313
from .src.webshop.tools import webshop_browser
14-
from .src.react.tools import answer
14+
from .src.react.tools import answer_qa, answer_math
1515
from .src.search.async_dense_retriever import asyncdense_retrieve
1616
from .src.scienceworld.tools import scienceworld_explorer
1717

@@ -25,7 +25,8 @@
2525
"alfworld_reset",
2626
"alfworld_get_admissible_commands",
2727
"google_search_serper",
28-
"answer",
28+
"answer_qa",
29+
"answer_math",
2930
"hallucination_tool",
3031
"invalid_input_tool",
3132
"submit_tool_call",
@@ -49,7 +50,8 @@
4950
"alfworld_get_task_objective": alfworld_get_task_objective,
5051
"alfworld_get_admissible_commands": alfworld_get_admissible_commands,
5152
"google_search": google_search_serper,
52-
"answer": answer,
53+
"answer_qa": answer_qa,
54+
"answer_math": answer_math,
5355
"hallucination_tool": hallucination_tool,
5456
"invalid_input_tool": invalid_input_tool,
5557
"dense_retrieve": dense_retrieve
Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,24 @@
11
from ...tool_base import tool
22

33
@tool(name="answer", description="Give the final answer. The answer should be put inside the \\boxed{} tag.", status="finish")
4-
def answer(answer: str):
4+
def answer_math(answer: str):
55
"""
66
A helper tool to give the final answer. The answer should be put inside the \\boxed{} tag.
77
Args:
88
answer (str): The final answer to the question.
99
Returns:
1010
str: The final answer to the question.
1111
"""
12+
return str(answer)
13+
14+
15+
@tool(name="answer", description="Give the final answer. The answer should be a simple, short, and direct.", status="finish")
16+
def answer_qa(answer: str):
17+
"""
18+
A helper tool to give the final answer. The answer should be a simple, short, and direct.
19+
Args:
20+
answer (str): The final answer to the question.
21+
Returns:
22+
str: The final answer to the question.
23+
"""
1224
return str(answer)

agents/agents/tools/src/search/async_dense_retriever.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from .faiss_indexer import Indexer
2828
from ...tool_base import tool
29-
from ....__init__ import AGENT_DATA_DIR
29+
from ....__init__ import AGENT_CACHE_DIR
3030
import builtins
3131
import numpy as np
3232
import importlib
@@ -575,7 +575,7 @@ def _ensure_corpus_loaded(self):
575575
max_length=4096,
576576
)
577577
async def asyncdense_retrieve(query: str):
578-
global GLOBAL_RETRIEVER, AGENT_DATA_DIR
578+
global GLOBAL_RETRIEVER, AGENT_CACHE_DIR
579579

580580
if not query.startswith("query:"):
581581
query = "query: " + query
@@ -584,10 +584,10 @@ async def asyncdense_retrieve(query: str):
584584
if GLOBAL_RETRIEVER is None:
585585
GLOBAL_RETRIEVER = DenseRetriever(
586586
corpus_file=os.path.join(
587-
AGENT_DATA_DIR, "search", "wiki-18.jsonl"
587+
AGENT_CACHE_DIR, "data", "search", "wiki-18.jsonl"
588588
),
589589
index_file=os.path.join(
590-
AGENT_DATA_DIR, "search", "e5_Flat.index"
590+
AGENT_CACHE_DIR, "data", "search", "e5_Flat.index"
591591
),
592592
)
593593

agents/agents/tools/utils/data.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import gzip
2+
import os
3+
from huggingface_hub import hf_hub_download
4+
import shutil
5+
from ... import AGENT_CACHE_DIR
6+
7+
def download_tool_data(tool_name: str):
8+
"""
9+
This is used to download tool-related data.
10+
"""
11+
global AGENT_CACHE_DIR
12+
if tool_name == "asyncdense_retrieve":
13+
data_dir = os.path.join(AGENT_CACHE_DIR, "data", "search")
14+
corpus_file = os.path.join(data_dir, "wiki-18.jsonl")
15+
index_file = os.path.join(data_dir, "e5_Flat.index")
16+
if not os.path.exists(corpus_file):
17+
if not os.path.exists(os.path.join(data_dir, "wiki-18.jsonl.gz")):
18+
repo_id = "PeterJinGo/wiki-18-corpus"
19+
hf_hub_download(
20+
repo_id=repo_id,
21+
filename="wiki-18.jsonl.gz",
22+
repo_type="dataset",
23+
local_dir=data_dir,
24+
)
25+
# Unzip the file
26+
print(f"Unzipping {os.path.join(data_dir, 'wiki-18.jsonl.gz')}")
27+
gz_path = os.path.join(data_dir, "wiki-18.jsonl.gz")
28+
if os.path.exists(gz_path):
29+
with gzip.open(gz_path, 'rb') as f_in, open(corpus_file, 'wb') as f_out:
30+
shutil.copyfileobj(f_in, f_out)
31+
32+
if not os.path.exists(index_file):
33+
if not os.path.exists(os.path.join(data_dir, "part_aa")):
34+
repo_id = "PeterJinGo/wiki-18-e5-index"
35+
for file in ["part_aa", "part_ab"]:
36+
hf_hub_download(
37+
repo_id=repo_id,
38+
filename=file, # e.g., "e5_Flat.index"
39+
repo_type="dataset",
40+
local_dir=data_dir,
41+
)
42+
print(f"Concatenating {os.path.join(data_dir, 'part_*')} > {os.path.join(data_dir, 'e5_Flat.index')}")
43+
os.system(f"cat {os.path.join(data_dir, 'part_*')} > {os.path.join(data_dir, 'e5_Flat.index')}")
44+
45+
if __name__ == "__main__":
46+
download_tool_data("asyncdense_retrieve")

agents/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ redis
66
docker
77
openai
88
faiss-cpu
9-
vllm==0.9.1
9+
vllm==0.9.2
1010
termcolor
1111
tenacity
1212
nest-asyncio
1313
pytest
1414
pytest-asyncio
1515
bs4
1616
qwen_vl_utils
17+
onnxruntime

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ include = ["agents", "verl"]
88
[project]
99
name = "AgentFly"
1010
version = "0.0.1"
11-
description = "A simple Python project"
11+
description = "Agent reinforcement learning framework."
1212
readme = "README.md"
1313
requires-python = ">=3.10,<3.11"
1414
license = { text = "Apache-2.0" }
@@ -27,6 +27,7 @@ dependencies = [
2727
"tenacity",
2828
"bs4",
2929
"qwen_vl_utils",
30+
"onnxruntime",
3031
]
3132

3233
[project.optional-dependencies]

0 commit comments

Comments
 (0)