Commit 38ba481

Explorer API collects feedback from agent applications (#295)
1 parent 4150cc8 commit 38ba481

27 files changed: 943 additions, 372 deletions

benchmark/bench.py

Lines changed: 2 additions & 0 deletions
@@ -144,6 +144,8 @@ def check_taskset_path(dataset_name: str, taskset_path: str) -> str:
             raise AttributeError(f"{script_filename} is missing 'DEFAULT_DATA_PATH'")
         taskset_path = module.DEFAULT_DATA_PATH
     taskset_path = os.path.realpath(taskset_path)
+    if os.path.exists(taskset_path):
+        return taskset_path

     # For frozenlake, check if train.parquet and test.parquet already exist
     if dataset_name == "frozenlake":

tests/buffer/experience_storage_test.py

Lines changed: 4 additions & 3 deletions
@@ -10,7 +10,7 @@
 from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import ExperienceBufferConfig
+from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import EID, Experience

@@ -95,6 +95,7 @@ async def test_sql_experience_buffer(self):
             max_read_timeout=3,
             path=f"sqlite:///{DB_PATH}",
             batch_size=self.train_batch_size,
+            replay_buffer=ReplayBufferConfig(enable=True),
         )
         config = config.to_storage_config()
         writer = SQLWriter(config)
@@ -118,7 +119,7 @@ async def test_sql_experience_buffer(self):
         exps = reader.read()
         self.assertEqual(len(exps), self.train_batch_size)
         for exp in exps:
-            self.assertEqual(exp.eid.task, cnt)
+            self.assertEqual(exp.eid.task, str(cnt))
             cnt -= 1

         # experience buffer support experience reuse
@@ -127,7 +128,7 @@ async def test_sql_experience_buffer(self):
         exps = reader.read()
         self.assertEqual(len(exps), self.train_batch_size)
         for exp in exps:
-            self.assertEqual(exp.eid.task, cnt)
+            self.assertEqual(exp.eid.task, str(cnt))
             cnt -= 1
         self.assertEqual(await writer.release(), 0)

tests/buffer/sql_test.py

Lines changed: 26 additions & 3 deletions
@@ -2,20 +2,31 @@

 import ray
 import torch
+from parameterized import parameterized

 from tests.tools import RayUnittestBaseAysnc
 from trinity.buffer import get_buffer_reader
 from trinity.buffer.reader.sql_reader import SQLReader
 from trinity.buffer.writer.sql_writer import SQLWriter
-from trinity.common.config import ExperienceBufferConfig, TasksetConfig
+from trinity.common.config import (
+    ExperienceBufferConfig,
+    ReplayBufferConfig,
+    TasksetConfig,
+)
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience

 db_path = os.path.join(os.path.dirname(__file__), "test.db")


 class TestSQLBuffer(RayUnittestBaseAysnc):
-    async def test_sql_exp_buffer_read_write(self) -> None:
+    @parameterized.expand(
+        [
+            (True,),
+            (False,),
+        ]
+    )
+    async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None:
         total_num = 8
         put_batch_size = 2
         read_batch_size = 4
@@ -25,7 +36,10 @@ async def test_sql_exp_buffer_read_write(self) -> None:
             path=f"sqlite:///{db_path}",
             storage_type=StorageType.SQL.value,
             batch_size=read_batch_size,
+            max_read_timeout=3,
         )
+        if enable_replay:
+            config.replay_buffer = ReplayBufferConfig(enable=True)
         sql_writer = SQLWriter(config.to_storage_config())
         sql_reader = SQLReader(config.to_storage_config())
         exps = [
@@ -53,13 +67,22 @@ async def test_sql_exp_buffer_read_write(self) -> None:
                     reward=float(i),
                     logprobs=torch.tensor([0.1]),
                     action_mask=torch.tensor([j % 2 for j in range(i + 1)]),
-                    info={"model_version": i},
+                    info={"model_version": i + put_batch_size},
                 )
                 for i in range(1, put_batch_size * 2 + 1)
             ]
         )
         exps = sql_reader.read(batch_size=put_batch_size * 2)
         self.assertEqual(len(exps), put_batch_size * 2)
+        for exp in exps:
+            self.assertTrue(exp.info["model_version"] > put_batch_size)
+        if enable_replay:
+            # support replay, so we can read all again
+            exps = sql_reader.read(batch_size=(put_batch_size * 2 + total_num))
+            self.assertEqual(len(exps), (put_batch_size * 2 + total_num))
+            # if read more than available, will wait until timeout
+            with self.assertRaises(StopIteration):
+                exps = sql_reader.read(batch_size=(put_batch_size * 3 + total_num))
         db_wrapper = ray.get_actor("sql-test_buffer")
         self.assertIsNotNone(db_wrapper)
         self.assertEqual(await sql_writer.release(), 0)
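
The parameterized test above drives the new replay path. As a rough illustration, here is a minimal sketch (not part of the commit) of configuring a SQL experience buffer with replay enabled, using only the config fields and reader/writer calls that appear in the diffs; the buffer name and database path are placeholders.

# Minimal sketch, assuming the ExperienceBufferConfig / ReplayBufferConfig API shown above.
from trinity.buffer.reader.sql_reader import SQLReader
from trinity.buffer.writer.sql_writer import SQLWriter
from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig
from trinity.common.constants import StorageType

config = ExperienceBufferConfig(
    name="replay_demo_buffer",          # placeholder name
    path="sqlite:///replay_demo.db",    # placeholder database path
    storage_type=StorageType.SQL.value,
    batch_size=4,
    max_read_timeout=3,                 # seconds to wait before a read gives up
    replay_buffer=ReplayBufferConfig(enable=True),
)
writer = SQLWriter(config.to_storage_config())
reader = SQLReader(config.to_storage_config())

# With replay enabled, experiences already returned by reader.read() remain
# available, so a later read can fetch the full history again; asking for more
# than was ever written blocks until max_read_timeout and then raises
# StopIteration, as the test asserts.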

tests/explorer/explorer_test.py

Lines changed: 24 additions & 23 deletions
@@ -5,11 +5,9 @@
 import os
 import random
 import shutil
-import unittest
 from datetime import datetime

 import httpx
-import openai
 import ray

 from tests.tools import (
@@ -27,6 +25,7 @@
 from trinity.common.config import ExperienceBufferConfig, InferenceModelConfig
 from trinity.common.constants import StorageType
 from trinity.explorer.explorer import Explorer
+from trinity.explorer.proxy.client import TrinityClient
 from trinity.manager.state_manager import StateManager


@@ -158,8 +157,9 @@ def run_serve(config):
     run_stage(config)


-def run_agent(base_url, model_path: str):
-    client = openai.Client(base_url=base_url, api_key="testkey")
+def run_agent(proxy_url, model_path: str):
+    proxy_client = TrinityClient(proxy_url=proxy_url)
+    openai_client = proxy_client.get_openai_client()
     contents = [
         "Hello, how are you?",
         "What is the capital of China?",
@@ -172,10 +172,11 @@ def run_agent(base_url, model_path: str):
         "What is the best way to learn programming?",
         "Describe the process of photosynthesis.",
     ]
-    response = client.chat.completions.create(
+    response = openai_client.chat.completions.create(
         model=model_path,
         messages=[{"role": "user", "content": random.choice(contents)}],
     )
+    proxy_client.feedback(reward=2.0, msg_ids=[response.id])
     return response.choices[0].message.content


@@ -191,7 +192,7 @@ def setUp(self):
         self.config.explorer.rollout_model.engine_num = 4
         self.config.explorer.rollout_model.enable_openai_api = True
         self.config.checkpoint_root_dir = get_checkpoint_path()
-        self.config.explorer.api_port = 8010
+        self.config.explorer.proxy_port = 8010
         self.config.explorer.service_status_check_interval = 30
         self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="experience_buffer",
@@ -201,7 +202,6 @@ def setUp(self):
         if multiprocessing.get_start_method(allow_none=True) != "spawn":
             multiprocessing.set_start_method("spawn", force=True)

-    @unittest.skip("Require improvement for agent mode")
     async def test_serve(self):  # noqa: C901
         serve_process = multiprocessing.Process(target=run_serve, args=(self.config,))
         serve_process.start()
@@ -238,7 +238,7 @@ async def test_serve(self):  # noqa: C901
         apps = []
         for i in range(task_num):
             app_process = multiprocessing.Process(
-                target=run_agent, args=(server_url + "/v1", self.config.model.model_path)
+                target=run_agent, args=(server_url, self.config.model.model_path)
             )
             apps.append(app_process)
             app_process.start()
@@ -248,22 +248,20 @@ async def test_serve(self):  # noqa: C901
             self.assertFalse(app.is_alive())

         finish_step = None
-
+        proxy_client = TrinityClient(proxy_url=server_url)
         for i in range(20):
-            async with httpx.AsyncClient() as client:
-                response = await client.get(f"{server_url}/metrics")
-                self.assertEqual(response.status_code, 200)
-                metrics = response.json()
-                metrics_keys = list(metrics.keys())
-                self.assertIn("explore_step_num", metrics_keys)
-                self.assertIn("rollout/total_experience_count", metrics_keys)
-                self.assertIn("rollout/model_0/total_request_count", metrics_keys)
-                self.assertIn("rollout/model_3/model_version", metrics_keys)
-                if not finish_step and metrics["rollout/total_experience_count"] == task_num:
-                    finish_step = metrics["explore_step_num"]
-                if finish_step and metrics["explore_step_num"] >= finish_step + 1:
-                    # wait for one more step to ensure all data are written to buffer
-                    break
+            metrics = await proxy_client.get_metrics_async()
+            metrics_keys = list(metrics.keys())
+            self.assertIn("explore_step_num", metrics_keys)
+            self.assertIn("rollout/total_experience_count", metrics_keys)
+            self.assertIn("rollout/model_0/total_request_count", metrics_keys)
+            self.assertIn("rollout/model_3/model_version", metrics_keys)
+            if not finish_step and metrics["rollout/total_experience_count"] == task_num:
+                finish_step = metrics["explore_step_num"]
+            await proxy_client.commit_async()
+            if finish_step and metrics["explore_step_num"] >= finish_step + 1:
+                # wait for one more step to ensure all data are written to buffer
+                break
             await asyncio.sleep(3)

         serve_process.terminate()
@@ -277,6 +275,9 @@ async def test_serve(self):  # noqa: C901
         exps = await buffer_reader.read_async(batch_size=10)
         for exp in exps:
             self.assertTrue(len(exp.tokens) > 0)
+            self.assertTrue(len(exp.logprobs) > 0)
+            self.assertTrue(exp.prompt_length > 0)
+            self.assertTrue(exp.reward == 2.0)
         self.assertEqual(len(exps), task_num)

     def tearDown(self):
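
The updated test replaces the raw openai.Client with the Explorer's proxy client, so an agent application can attach a reward to the responses it just received. Below is a condensed sketch (not part of the commit) of that usage pattern, using only the TrinityClient methods that appear in this diff; the proxy URL, model path, and reward value are placeholders.

# Sketch of an agent application talking to the Explorer proxy, based on run_agent above.
from trinity.explorer.proxy.client import TrinityClient

proxy_client = TrinityClient(proxy_url="http://localhost:8010")  # explorer.proxy_port
openai_client = proxy_client.get_openai_client()

# Ordinary OpenAI-compatible chat call, routed through the Explorer proxy.
response = openai_client.chat.completions.create(
    model="/path/to/model",  # placeholder model path
    messages=[{"role": "user", "content": "What is the capital of China?"}],
)

# Report a scalar reward for the messages produced by this request; the proxy
# attaches it to the experiences it recorded for those message ids.
proxy_client.feedback(reward=2.0, msg_ids=[response.id])

In the monitoring loop, the test likewise polls proxy_client.get_metrics_async() and calls proxy_client.commit_async(), so that feedback-annotated experiences reach the experience buffer before the assertions on reward, logprobs, and prompt_length run.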

tests/explorer/proxy_test.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+import os
+import unittest
+import uuid
+from typing import List
+
+import torch
+
+from trinity.common.experience import EID, Experience
+from trinity.explorer.proxy.recorder import HistoryRecorder
+
+
+def get_dummy_experience(num: int) -> List[Experience]:
+    return [
+        Experience(
+            eid=EID(suffix=uuid.uuid4().hex[:6]),
+            tokens=torch.zeros(5),
+            prompt_length=2,
+            info={
+                "model_version": 0,
+            },
+        )
+        for _ in range(num)
+    ]
+
+
+db_path = os.path.join(os.path.dirname(__file__), "test_recorder.db")
+
+
+class RecorderTest(unittest.TestCase):
+    def setUp(self) -> None:
+        if os.path.exists(db_path):
+            os.remove(db_path)
+
+    def tearDown(self) -> None:
+        if os.path.exists(db_path):
+            os.remove(db_path)
+
+    def test_recorder(self):
+        recorder = HistoryRecorder(
+            # in memory sqlite for testing
+            db_url="sqlite:///" + db_path,
+            table_name="experience",
+        )
+        self.assertIsInstance(recorder, HistoryRecorder)
+        # test record history
+
+        experiences_1 = get_dummy_experience(3)
+        recorder.record_history(experiences_1)
+        # test update reward
+        msg_ids_1 = [exp.eid.suffix for exp in experiences_1]
+        experiences_2 = get_dummy_experience(2)
+        recorder.record_history(experiences_2)
+        updated_experiences = recorder.update_reward(
+            reward=1.0, msg_ids=msg_ids_1, run_id=1, task_id="test_task"
+        )
+        self.assertEqual(len(updated_experiences), 3)
+        for exp in updated_experiences:
+            self.assertEqual(exp.reward, 1.0)
+            self.assertEqual(exp.eid.run, 1)
+            self.assertEqual(str(exp.eid.task), "test_task")
+        # test update reward with non-existing msg_ids
+        updated_experiences_empty = recorder.update_reward(
+            reward=2.0, msg_ids=["non_existing_id"], run_id=1, task_id="test_task"
+        )
+        self.assertEqual(len(updated_experiences_empty), 0)
+        # test record history with empty experiences
+        recorder.record_history([])  # should not raise any exception
+        # test update reward multiple times
+        updated_experiences_2 = recorder.update_reward(
+            reward=3.0,
+            msg_ids=[exp.eid.suffix for exp in experiences_2],
+            run_id=2,
+            task_id="test_task_2",
+        )
+        self.assertEqual(len(updated_experiences_2), 2)
+        for exp in updated_experiences_2:
+            self.assertEqual(exp.reward, 3.0)
+            self.assertEqual(exp.eid.run, 2)
+            self.assertEqual(str(exp.eid.task), "test_task_2")
+        updated_experiences_3 = recorder.update_reward(
+            reward=4.0,
+            msg_ids=[exp.eid.suffix for exp in experiences_2],
+            run_id=3,
+            task_id="test_task_3",
+        )
+        self.assertEqual(len(updated_experiences_3), 0)  # already consumed
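
This new test exercises the HistoryRecorder behind the proxy's feedback path. A minimal sketch (not part of the commit) of the record-then-reward flow it verifies, reusing only the constructor arguments and methods shown above; the database file and task id are placeholders.

# Sketch of the HistoryRecorder flow: record generated experiences, then
# attach a reward to them by message id (the EID suffix).
import uuid

import torch

from trinity.common.experience import EID, Experience
from trinity.explorer.proxy.recorder import HistoryRecorder

recorder = HistoryRecorder(db_url="sqlite:///recorder_demo.db", table_name="experience")

exp = Experience(
    eid=EID(suffix=uuid.uuid4().hex[:6]),
    tokens=torch.zeros(5),
    prompt_length=2,
    info={"model_version": 0},
)
recorder.record_history([exp])

# When an application later reports a reward for this message id, the recorder
# returns the matching experiences with reward, run, and task filled in.
updated = recorder.update_reward(
    reward=1.0, msg_ids=[exp.eid.suffix], run_id=1, task_id="demo_task"
)
assert len(updated) == 1 and updated[0].reward == 1.0

# Repeating the update with the same msg_ids yields nothing: per the test's
# "already consumed" assertion, each recorded experience hands out its reward once.
assert len(recorder.update_reward(
    reward=2.0, msg_ids=[exp.eid.suffix], run_id=2, task_id="demo_task"
)) == 0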

tests/explorer/workflow_test.py

Lines changed: 1 addition & 1 deletion
@@ -555,7 +555,7 @@ async def test_adapter(self):
         try:
             from agentscope.model import TrinityChatModel
         except ImportError:
-            self.skipTest("agentscope >= 0.1.6 is not installed")
+            self.skipTest("agentscope >= 1.0.9 is not installed")

         async def as_workflow_func(task, model) -> float:
             self.assertIsInstance(task, dict)

tests/manager/synchronizer_test.py

Lines changed: 8 additions & 16 deletions
@@ -80,32 +80,23 @@ async def new_finish_explore_step(self, step: int, model_version: int) -> None:

 def run_trainer(config: Config, max_steps: int, intervals: List[int]) -> None:
     ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
-    try:
-        trainer_monkey_patch(config, max_steps, intervals)
-        train(config)
-    finally:
-        ray.shutdown(_exiting_interpreter=True)
+    trainer_monkey_patch(config, max_steps, intervals)
+    train(config)


 def run_explorer(config: Config, max_steps: int, intervals: List[int]) -> None:
     ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
-    try:
-        explorer_monkey_patch(config, max_steps, intervals)
-        explore(config)
-    finally:
-        ray.shutdown(_exiting_interpreter=True)
+    explorer_monkey_patch(config, max_steps, intervals)
+    explore(config)


 def run_both(
     config: Config, max_steps: int, trainer_intervals: List[int], explorer_intervals: List[int]
 ) -> None:
     ray.init(ignore_reinit_error=True, namespace=config.ray_namespace)
-    try:
-        trainer_monkey_patch(config, max_steps, trainer_intervals)
-        explorer_monkey_patch(config, max_steps, explorer_intervals)
-        both(config)
-    finally:
-        ray.shutdown(_exiting_interpreter=True)
+    trainer_monkey_patch(config, max_steps, trainer_intervals)
+    explorer_monkey_patch(config, max_steps, explorer_intervals)
+    both(config)


 class BaseTestSynchronizer(unittest.TestCase):
@@ -115,6 +106,7 @@ def setUp(self):

     def tearDown(self):
         checkpoint_path = get_checkpoint_path()
+        ray.shutdown(_exiting_interpreter=True)
         shutil.rmtree(os.path.join(checkpoint_path, "unittest"))


tests/template/config.yaml

Lines changed: 0 additions & 2 deletions
@@ -38,8 +38,6 @@ explorer:
   rollout_model:
     engine_num: 2
     tensor_parallel_size: 1
-    enable_prefix_caching: false
-    enforce_eager: true
     dtype: bfloat16
     seed: 42
 trainer:
