Skip to content

Commit 39dd1d4

Browse files
authored
Enhance AgentScope Workflow Adapter (agentscope-ai#465)
1 parent a348d96 commit 39dd1d4

File tree

18 files changed

+336
-80
lines changed

18 files changed

+336
-80
lines changed

docs/sphinx_doc/source/tutorial/example_tinker_backend.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,4 +208,4 @@ synchronizer:
208208

209209
Since Llama-3.2-3B is a base (non-instruct-tuned) model, it has limited ability to follow formatting instructions. Additionally, we trained for only **one epoch**. As a result, both backends achieved final rewards just slightly above 0.1. Nonetheless, the training curves show a clear upward trend in reward, indicating successful learning. The results are visualized below:
210210

211-
![Training Rewards on GSM8K](../../docs/sphinx_doc/assets/tinker-gsm8k.png)
211+
![Training Rewards on GSM8K](../../assets/tinker-gsm8k.png)

docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,4 +207,4 @@ synchronizer:
207207

208208
由于 Llama-3.2-3B 是基础(非指令微调)模型,其格式化指令跟随能力有限,且本实验仅训练了**一个 epoch**。因此,两种后端的最终 reward 都略高于 0.1。但训练曲线显示 reward 呈明显上升趋势,表明模型已成功学习。结果可视化如下:
209209

210-
![GSM8K 训练奖励曲线](../../docs/sphinx_doc/assets/tinker-gsm8k.png)
210+
![GSM8K 训练奖励曲线](../../assets/tinker-gsm8k.png)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ megatron = [
8282
"mbridge>=0.13.0",
8383
]
8484
tinker = [
85-
"tinker", # tinker requires python>=3.11
85+
"tinker; python_version >= '3.11'",
8686
]
8787

8888
doc = [

scripts/docker/Dockerfile.uv

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
#
88
# Note:
99
# 1. This Dockerfile uses 'uv' to create a virtual environment for better package management. If you want a simpler setup without 'uv', please refer to `scripts/docker/Dockerfile`.
10-
# 2. Make sure to use `uv pip` to install packages within the virtual environment.
10+
# 2. The uv virtual environment is created at `/opt/venv`, use `source /opt/venv/bin/activate` to activate it.
11+
# 3. Make sure to use `uv pip` to install packages within the virtual environment.
1112

1213
FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04
1314

tests/buffer/file_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,11 @@ def setUp(self):
109109
name="test_buffer", storage_type=StorageType.FILE.value
110110
)
111111
self.config.check_and_update()
112-
ray.init(ignore_reinit_error=True, runtime_env={"env_vars": self.config.get_envs()})
112+
ray.init(
113+
ignore_reinit_error=True,
114+
runtime_env={"env_vars": self.config.get_envs()},
115+
namespace="trinity_unittest",
116+
)
113117
os.makedirs(self.config.buffer.cache_dir, exist_ok=True)
114118
file_path = self.config.buffer.trainer_input.experience_buffer.path
115119
if os.path.exists(file_path):

tests/common/vllm_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,6 @@ async def test_generate(self):
135135
await prepare_engines(self.engines, self.auxiliary_engines)
136136
await self.model_wrapper.prepare()
137137
self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)
138-
self.assertEqual(await self.model_wrapper.model_path_async, self.config.model.model_path)
139138
prompts = ["Hello, world!", "Hello, my name is"]
140139
n = self.config.algorithm.repeat_times
141140
if self.use_async:

tests/explorer/scheduler_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ async def run_async(self) -> List[Experience]:
171171

172172
@ray.remote
173173
class DummyModel(InferenceModel):
174+
def __init__(self):
175+
from trinity.common.config import InferenceModelConfig
176+
177+
super().__init__(InferenceModelConfig(model_path="dummy_model"))
178+
174179
def sync_model(self, model_version, update_weight_args_list):
175180
return True
176181

tests/explorer/workflow_test.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from tests.common.vllm_test import CHAT_TEMPLATE
1717
from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config
18+
from trinity.common.config import InferenceModelConfig
1819
from trinity.common.experience import EID, Experience
1920
from trinity.common.models import create_inference_models
2021
from trinity.common.models.model import ModelWrapper
@@ -551,7 +552,7 @@ async def monitor_routine():
551552

552553

553554
class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
554-
async def test_adapter(self):
555+
async def test_adapter_v0(self):
555556
try:
556557
from agentscope.model import TrinityChatModel
557558
except ImportError:
@@ -586,6 +587,58 @@ async def as_workflow_func(task, model) -> float:
586587
self.assertEqual(result[1].reward, 0.1)
587588
self.assertEqual(result[1].prompt_length, 2)
588589

590+
async def test_adapter_v1(self):
591+
try:
592+
from agentscope.model import ChatModelBase
593+
from agentscope.tuner import JudgeOutput, WorkflowOutput
594+
except ImportError:
595+
self.skipTest("agentscope >= 1.0.12 is not installed")
596+
597+
async def as_workflow_func(task, model) -> WorkflowOutput:
598+
self.assertIsInstance(task, dict)
599+
self.assertIsInstance(model, ChatModelBase)
600+
return WorkflowOutput(
601+
reward=task["reward"],
602+
response=task["reward"],
603+
metrics={"workflow_metric_1": 0.0},
604+
)
605+
606+
async def as_judge_func(task, response) -> JudgeOutput:
607+
self.assertIsInstance(task, dict)
608+
self.assertIsInstance(response, float)
609+
return JudgeOutput(
610+
reward=response,
611+
metrics={"judge_metric_1": 1.0},
612+
)
613+
614+
model = MagicMock()
615+
openai_client = MagicMock()
616+
openai_client.model_path = "Qwen/Qwen3-8B"
617+
model.get_openai_async_client.return_value = openai_client
618+
model.extract_experience_from_history.return_value = [
619+
Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, logprobs=Tensor([0.1, 0.2])),
620+
]
621+
622+
as_adapter_cls = WORKFLOWS.get("agentscope_workflow_adapter_v1")
623+
as_adapter = as_adapter_cls(
624+
task=Task(
625+
raw_task={"reward": 0.2},
626+
workflow_args={
627+
"workflow_func": as_workflow_func,
628+
"judge_func": as_judge_func,
629+
},
630+
),
631+
model=model,
632+
)
633+
result = await as_adapter.run_async()
634+
self.assertEqual(len(result), 1)
635+
self.assertEqual(result[0].reward, 0.2)
636+
self.assertEqual(result[0].prompt_length, 1)
637+
metrics = result[-1].metrics
638+
self.assertEqual(len(metrics), 2)
639+
self.assertEqual(metrics["workflow_metric_1"], 0.0)
640+
self.assertEqual(metrics["judge_metric_1"], 1.0)
641+
589642

590643
class DummyModelWrapper:
591644
def __init__(self, model, engine_type="vllm", **kwargs):
@@ -678,9 +731,13 @@ async def mock_get_api_server_url_remote():
678731
async def mock_get_model_version_remote():
679732
return 1
680733

734+
async def mock_get_model_config_remote():
735+
return InferenceModelConfig(model_path="dummy_model")
736+
681737
model = MagicMock()
682738
model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote)
683739
model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote)
740+
model.get_model_config.remote = MagicMock(side_effect=mock_get_model_config_remote)
684741

685742
runner = WorkflowRunner(
686743
config,

tests/trainer/trainer_test.py

Lines changed: 150 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import unittest
1010
from copy import deepcopy
1111
from datetime import datetime
12+
from typing import Dict
1213
from unittest import mock
1314

1415
import ray
@@ -1134,10 +1135,10 @@ async def test_serve_with_trainer(self): # noqa: C901
11341135
+ metrics["rollout/model_1/total_request_count"],
11351136
metrics["rollout/total_experience_count"],
11361137
)
1137-
# at least updated to version 2
1138+
# at least updated to version 1
11381139
await asyncio.sleep(5) # wait for model version update
1139-
self.assertGreaterEqual(metrics["rollout/model_0/model_version"], 2)
1140-
self.assertGreaterEqual(metrics["rollout/model_1/model_version"], 2)
1140+
self.assertGreaterEqual(metrics["rollout/model_0/model_version"], 1)
1141+
self.assertGreaterEqual(metrics["rollout/model_1/model_version"], 1)
11411142
# check final checkpoint
11421143
_, step_num = get_checkpoint_dir_with_step_num(
11431144
checkpoint_root_path=serve_config.checkpoint_job_dir,
@@ -1433,3 +1434,149 @@ def test_trainer(self):
14331434
def tearDown(self):
14341435
# remove dir only when the test passed
14351436
shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True)
1437+
1438+
1439+
@unittest.skip("Require agentscope >= 1.0.12")
1440+
class AgentScopeTunerTest(unittest.IsolatedAsyncioTestCase):
1441+
def setUp(self) -> None:
1442+
ray.init(ignore_reinit_error=True)
1443+
1444+
def tearDown(self) -> None:
1445+
ray.shutdown(_exiting_interpreter=True)
1446+
1447+
def test_agentscope_tuner(self):
1448+
try:
1449+
from agentscope.agent import ReActAgent
1450+
from agentscope.formatter import OpenAIChatFormatter
1451+
from agentscope.message import Msg
1452+
from agentscope.model import ChatModelBase
1453+
from agentscope.tuner import (
1454+
Algorithm,
1455+
Dataset,
1456+
JudgeOutput,
1457+
TunerChatModel,
1458+
WorkflowOutput,
1459+
tune,
1460+
)
1461+
except ImportError:
1462+
self.skipTest("agentscope >= 1.0.12 is not installed")
1463+
1464+
async def workflow_func(
1465+
task: Dict,
1466+
model: ChatModelBase,
1467+
auxiliary_models: Dict[str, ChatModelBase],
1468+
) -> WorkflowOutput:
1469+
assert isinstance(model, ChatModelBase)
1470+
assert "judge_model" in auxiliary_models
1471+
assert isinstance(auxiliary_models["judge_model"], ChatModelBase)
1472+
agent = ReActAgent(
1473+
name="test_agent",
1474+
model=model,
1475+
sys_prompt="You are a helpful assistant.",
1476+
formatter=OpenAIChatFormatter(),
1477+
)
1478+
st = time.time()
1479+
response = await agent.reply(Msg("user", task["question"], role="user"))
1480+
et = time.time()
1481+
return WorkflowOutput(response=response, metrics={"workflow_time": et - st})
1482+
1483+
async def judge_func(
1484+
task: Dict, response: Msg, auxiliary_models: Dict[str, ChatModelBase]
1485+
) -> JudgeOutput:
1486+
assert "judge_model" in auxiliary_models
1487+
judge_model = auxiliary_models["judge_model"]
1488+
assert isinstance(judge_model, ChatModelBase)
1489+
agent = ReActAgent(
1490+
name="judge_agent",
1491+
model=judge_model,
1492+
sys_prompt="You are a judge to evaluate the correctness of answers.",
1493+
formatter=OpenAIChatFormatter(),
1494+
)
1495+
workflow_text_response = response.get_text_content()
1496+
st = time.time()
1497+
judge_response = await agent.reply(
1498+
Msg(
1499+
"user",
1500+
f"Question: {task['question']}\nAnswer: {workflow_text_response}\nIs the answer correct? Reply with 'Yes' or 'No'.",
1501+
role="user",
1502+
)
1503+
)
1504+
et = time.time()
1505+
judge_response = judge_response.get_text_content()
1506+
if judge_response is not None and "yes" in judge_response.lower():
1507+
is_correct = True
1508+
else:
1509+
is_correct = False
1510+
return JudgeOutput(
1511+
reward=float(is_correct),
1512+
metrics={"judge_time": et - st},
1513+
)
1514+
1515+
gsm8k_dataset = get_unittest_dataset_config("gsm8k")
1516+
1517+
dataset = Dataset(
1518+
path=gsm8k_dataset.path,
1519+
split="train",
1520+
total_steps=2,
1521+
)
1522+
eval_dataset = Dataset(
1523+
path=gsm8k_dataset.path,
1524+
split="test",
1525+
)
1526+
1527+
model = TunerChatModel(
1528+
model_path=get_model_path(),
1529+
max_model_len=4096,
1530+
max_tokens=2048,
1531+
inference_engine_num=2,
1532+
)
1533+
1534+
auxiliary_models = {
1535+
"judge_model": TunerChatModel(
1536+
model_path=get_model_path(),
1537+
max_model_len=8192,
1538+
max_tokens=2048,
1539+
inference_engine_num=2,
1540+
)
1541+
}
1542+
1543+
algorithm = Algorithm(
1544+
algorithm_type="multi_step_grpo",
1545+
batch_size=4,
1546+
group_size=4,
1547+
eval_interval_steps=2,
1548+
save_interval_steps=2,
1549+
)
1550+
1551+
tune(
1552+
workflow_func=workflow_func,
1553+
judge_func=judge_func,
1554+
train_dataset=dataset,
1555+
eval_dataset=eval_dataset,
1556+
model=model,
1557+
auxiliary_models=auxiliary_models,
1558+
algorithm=algorithm,
1559+
)
1560+
# check checkpoint dir in `./checkpoints/AgentScope/Experiment-<timestamp>`
1561+
self.assertTrue(os.path.exists("./checkpoints/AgentScope"))
1562+
exp_dirs = os.listdir("./checkpoints/AgentScope")
1563+
self.assertGreaterEqual(len(exp_dirs), 1)
1564+
latest_exp_dir = sorted(exp_dirs)[-1]
1565+
exp_dir_path = os.path.join("./checkpoints/AgentScope", latest_exp_dir)
1566+
_, step_num = get_checkpoint_dir_with_step_num(
1567+
checkpoint_root_path=exp_dir_path,
1568+
trainer_type="verl",
1569+
)
1570+
self.assertEqual(step_num, 2)
1571+
# check tensorboard
1572+
parser = TensorBoardParser(os.path.join(exp_dir_path, "monitor", "tensorboard"))
1573+
rollout_metrics = parser.metric_list("rollout")
1574+
self.assertIn("rollout/workflow_time/mean", rollout_metrics)
1575+
self.assertIn("rollout/judge_time/mean", rollout_metrics)
1576+
self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2)
1577+
eval_metrics = parser.metric_list("eval")
1578+
self.assertGreater(len(eval_metrics), 0)
1579+
self.assertEqual(parser.metric_max_step(eval_metrics[0]), 2)
1580+
actor_metrics = parser.metric_list("actor")
1581+
self.assertGreater(len(actor_metrics), 0)
1582+
self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2)

trinity/common/config.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,8 @@ class TasksetConfig:
261261
total_epochs: int = 1 # automatically set
262262
# ! DO NOT SET, automatically set from buffer.total_steps
263263
total_steps: Optional[int] = None # automatically set
264+
# ! DO NOT SET, automatically set form ray_namespace
265+
ray_namespace: Optional[str] = None
264266

265267
def to_storage_config(self) -> StorageConfig:
266268
storage_config = StorageConfig(
@@ -285,6 +287,7 @@ def to_storage_config(self) -> StorageConfig:
285287
batch_size=self.batch_size,
286288
total_epochs=self.total_epochs,
287289
total_steps=self.total_steps,
290+
ray_namespace=self.ray_namespace,
288291
)
289292
return storage_config
290293

@@ -324,6 +327,8 @@ class ExperienceBufferConfig:
324327
total_epochs: int = 1 # automatically set
325328
# ! DO NOT SET, automatically set from buffer.total_steps
326329
total_steps: Optional[int] = None # automatically set
330+
# ! DO NOT SET, automatically set form ray_namespace
331+
ray_namespace: Optional[str] = None
327332

328333
def to_storage_config(self) -> StorageConfig:
329334
storage_config = StorageConfig(
@@ -345,6 +350,7 @@ def to_storage_config(self) -> StorageConfig:
345350
tokenizer_path=self.tokenizer_path,
346351
total_epochs=self.total_epochs,
347352
total_steps=self.total_steps,
353+
ray_namespace=self.ray_namespace,
348354
)
349355
return storage_config
350356

@@ -546,6 +552,7 @@ class InferenceModelConfig:
546552

547553
# ! DO NOT SET
548554
bundle_indices: str = ""
555+
ray_namespace: Optional[str] = None
549556

550557
# ! DO NOT SET, automatically set from model.lora_configs
551558
enable_lora: bool = False

0 commit comments

Comments
 (0)