Skip to content

Commit d67b73e

Browse files
authored
Merge pull request #9 from Agent-One-Lab/agents
Code Refactoring
2 parents f0fcc23 + 5572de4 commit d67b73e

File tree

270 files changed

+2329
-1012
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

270 files changed

+2329
-1012
lines changed

.github/workflows/cpu_tests.yml

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,19 @@ on:
99
jobs:
1010
test-envs:
1111
runs-on: ubuntu-latest
12-
timeout-minutes: 15
12+
timeout-minutes: 20
1313

1414
strategy:
1515
matrix:
16+
python-version: ["3.10"]
1617
test-file:
17-
- tests/unit/envs/ --ignore tests/unit/envs/test_webshop_text_env.py --ignore tests/unit/envs/test_alfworld_env.py
18-
- tests/unit/envs/test_alfworld_env.py
18+
- agentfly/tests/unit/envs/ --ignore agentfly/tests/unit/envs/test_webshop_text_env.py --ignore agentfly/tests/unit/envs/test_alfworld_env.py
19+
- agentfly/tests/unit/envs/test_alfworld_env.py
1920
# - tests/unit/envs/test_webshop_text_env.py # TODO: add minimal variant of the webshop docker image
20-
- tests/unit/rewards/ --ignore tests/unit/rewards/test_env_id.py --ignore tests/unit/rewards/test_webshop_reward.py
21-
- tests/unit/tools/ --ignore tests/unit/tools/test_webshop_tool.py --ignore tests/unit/tools/test_scienceworld_tool.py --ignore tests/unit/tools/test_code_tool.py
22-
- tests/unit/tools/test_scienceworld_tool.py
23-
- tests/unit/tools/test_code_tool.py
21+
- agentfly/tests/unit/rewards/ --ignore agentfly/tests/unit/rewards/test_env_id.py --ignore agentfly/tests/unit/rewards/test_webshop_reward.py
22+
- agentfly/tests/unit/tools/ --ignore agentfly/tests/unit/tools/test_webshop_tool.py --ignore agentfly/tests/unit/tools/test_scienceworld_tool.py --ignore agentfly/tests/unit/tools/test_code_tool.py
23+
- agentfly/tests/unit/tools/test_scienceworld_tool.py
24+
- agentfly/tests/unit/tools/test_code_tool.py
2425
# - test/unit/agents/ # TODO: recheck this
2526

2627
steps:
@@ -34,6 +35,13 @@ jobs:
3435
with:
3536
python-version: '3.10'
3637

38+
- name: Verify Python
39+
run: |
40+
which python
41+
python --version
42+
which pip
43+
python -m pip --version
44+
3745
- name: Free up disk space
3846
run: |
3947
echo "Before cleanup:"
@@ -50,8 +58,9 @@ jobs:
5058
5159
- name: Install dependencies (main repo)
5260
run: |
53-
pip install -r agents/requirements.txt
61+
pip install -e .
5462
pip install datasets
63+
pip install -e '.[dev]' --no-build-isolation
5564
5665
- name: Cache AgentFly cache
5766
uses: actions/cache@v4
@@ -75,5 +84,4 @@ jobs:
7584
7685
- name: Run unit test (${{ matrix.test-file }})
7786
run: |
78-
cd agents
79-
python -m pytest ${{ matrix.test-file }}
87+
pytest ${{ matrix.test-file }}

.gitignore

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ tests/e2e/toy_examples/deepspeed/synchronous/output.txt
122122

123123
# data
124124
*.parquet
125-
agents/agents/data/*
125+
agentfly/agents/data/*
126126

127127
# local logs
128128
logs
@@ -131,10 +131,12 @@ outputs
131131
*.out
132132

133133
# Notebooks
134-
agents/tests/*.ipynb
135-
agents/tests/*.jpg
136-
agents/tests/*.jpeg
137-
agents/tests/*.png
138-
agents/agents/*.ipynb
139-
agents/temp/
134+
agentfly/tests/*.ipynb
135+
agentfly/tests/*.jpg
136+
agentfly/tests/*.jpeg
137+
agentfly/tests/*.png
138+
agentfly/agents/*.ipynb
139+
agentfly/temp/
140+
agentfly/data/
141+
*.ipynb
140142

README.md

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,30 +48,28 @@ bash install.sh # Assume conda with python3.10.x
4848
```
4949
**Option 2**: Customized Installation
5050

51-
Clone and initialize the project:
52-
```bash
53-
git clone https://github.com/Agent-One-Lab/AgentFly
54-
cd AgentFly
55-
git submodule init
56-
git submodule update
57-
```
58-
Basic python packages installation:
59-
```bash
60-
pip install -e .
61-
pip install -e '.[verl]' --no-build-isolation
62-
```
63-
Optionally, some tools actually require some additional dependencies:
64-
65-
Some of our tools & environments are managed by *enroot* backend. To use them, please install [enroot](https://github.com/NVIDIA/enroot/blob/master/doc/installation.md) (sudo required). Such tools include code_interpreter, retrieval, webshop, alfworld, sciencworld.
66-
67-
Search requires redis to cache results, an optional way to install with conda:
68-
```bash
69-
conda install conda-forge::redis-server==7.4.0
70-
```
51+
Please refer to [installation.md](docs/start/installation.md) for custmoized installation.
7152

7253
## Quick Start
73-
```
54+
```python
55+
# Really small example to build an agent and run
56+
from agentfly.agents import HFAgent
57+
from agentfly.tools import calculate
58+
messages = [{"role": "user", "content": "What is the result of 1 + 1?"}]
59+
agent = HFAgent(
60+
model_name_or_path="Qwen/Qwen2.5-3B-Instruct",
61+
tools=[calculate],
62+
template="qwen2.5",
63+
backend="async_vllm",
64+
)
65+
await agent.run(
66+
messages=messages,
67+
max_turns=3,
68+
num_chains=1
69+
)
7470

71+
trajectories = agent.trajectories
72+
print(trajectories)
7573
```
7674

7775
## Features
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
from .specialized.code_agent import CodeAgent
33
from .specialized.think_agent import ThinkAgent
44
from .specialized.gui_agent import GUIAgent
5-
6-
__all__ = ["ReactAgent", "CodeAgent", "ThinkAgent", "GUIAgent"]
5+
from .specialized.hf_agent import HFAgent
6+
from .templates.utils import process_vision_info, tokenize_conversation, tokenize_conversations
Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
from collections import defaultdict
33
import json
4-
4+
from .utils.messages import MessagesList
55
from .templates.templates import get_template
66
from ..__init__ import AGENT_DATA_DIR
77
from .llm_backend import AsyncVLLMBackend, AsyncVerlBackend, ClientBackend, TransformersBackend, VLLMBackend
@@ -11,7 +11,7 @@
1111
import torch
1212
from .templates.utils import tokenize_conversations
1313
from .templates.vision_processor import is_vision_template
14-
from .chain.chain_base import ChainGeneration
14+
from .chain.chain_base import ChainRollout
1515
import os
1616
import transformers
1717
import warnings
@@ -27,7 +27,7 @@
2727

2828
Logger = logging.getLogger(__name__)
2929

30-
class BaseAgent(ChainGeneration, ABC):
30+
class BaseAgent(ChainRollout, ABC):
3131
"""
3232
Base class for all agents. All agent should subclass this class. A customized agent can implement the following methods:
3333
@@ -43,7 +43,7 @@ def __init__(
4343
system_prompt: str = None,
4444
tools: List = None,
4545
max_length: int=8192,
46-
backend: str = "transformers",
46+
backend: str = "async_vllm",
4747
backend_config: Any = None,
4848
reward_fn: Callable = None,
4949
log_file: str = "agent",
@@ -156,6 +156,47 @@ def _init_llm_engine(self, model_name_or_path: str, backend: str):
156156

157157
return llm_engine
158158

159+
def _preprocess_messages(self, messages: List[Dict]):
160+
"""
161+
Do some necessary preprocessings to the messages, such as adding the sytem prompt
162+
Args:
163+
messages: List of messages to preprocess.
164+
165+
Returns:
166+
List of preprocessed messages.
167+
"""
168+
messages_list = MessagesList.from_data(messages)
169+
for messages in messages_list:
170+
if self.system_prompt:
171+
messages.set_system_prompt(self.system_prompt, enforce=False)
172+
173+
return messages_list.to_list()
174+
175+
async def run(self,
176+
messages: Union[List[dict], np.ndarray, Dict],
177+
max_turns: int,
178+
generation_config: Optional[Dict[str, Any]] = None,
179+
**kwargs,
180+
):
181+
"""
182+
This is the main interface for running the agent. It is a wrapper of different
183+
rollout methods, which must be asynchronous. Currently, we only support chain-based rollout.
184+
Args:
185+
messages: List of messages to generate responses for.
186+
max_turns: The maximum number of turns to generate.
187+
generation_config: The generation configuration.
188+
**kwargs: Additional keyword arguments for generation.
189+
190+
"""
191+
processed_messages = self._preprocess_messages(messages)
192+
193+
return await self.run_async(
194+
processed_messages,
195+
max_turns=max_turns,
196+
generation_config=generation_config,
197+
**kwargs,
198+
)
199+
159200
def set_llm_engine(self, llm_engine: Any, tokenizer: Any, processor: Any):
160201
assert self.backend == "async_verl", "Only async verl backend is supported for now"
161202

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .react.react_agent import ReactAgent
88
from .specialized.code_agent import CodeAgent
99
from .specialized.gui_agent import GUIAgent
10+
from .specialized.hf_agent import HFAgent
1011
from ..rewards.reward_base import get_reward_from_name
1112

1213

@@ -167,4 +168,5 @@ def from_pretrained(
167168
AutoAgent.register_agent("code", CodeAgent)
168169
AutoAgent.register_agent("openai", OpenAIAgent)
169170
AutoAgent.register_agent("think", ThinkAgent)
170-
AutoAgent.register_agent("gui", GUIAgent)
171+
AutoAgent.register_agent("gui", GUIAgent)
172+
AutoAgent.register_agent("hf", HFAgent)

0 commit comments

Comments
 (0)