Skip to content

Commit b3e0216

Browse files
authored
Support AgentScope Workflow Function (#327)
1 parent 72a8b1e commit b3e0216

File tree

3 files changed

+129
-1
lines changed

3 files changed

+129
-1
lines changed

tests/explorer/workflow_test.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,18 @@
66
from typing import Dict, Optional
77
from unittest.mock import MagicMock
88

9+
import ray
910
from parameterized import parameterized, parameterized_class
1011
from torch import Tensor
1112

1213
from tests.common.vllm_test import CHAT_TEMPLATE
1314
from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config
14-
from trinity.common.experience import EID
15+
from trinity.common.experience import EID, Experience
1516
from trinity.common.models import create_inference_models
1617
from trinity.common.models.model import ModelWrapper
1718
from trinity.common.rewards import RMGalleryFn
1819
from trinity.common.workflows import (
20+
WORKFLOWS,
1921
MathBoxedWorkflow,
2022
MathEvalWorkflow,
2123
MathRMWorkflow,
@@ -489,3 +491,44 @@ def test_multi_turn_workflow(self):
489491
else:
490492
answer = workflow.run()
491493
self.assertEqual(len(answer), 2)
494+
495+
def tearDown(self):
496+
ray.shutdown(_exiting_interpreter=True)
497+
498+
499+
class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
500+
@unittest.skip("Waiting for agentscope>=0.1.6")
501+
async def test_adapter(self):
502+
try:
503+
from agentscope.model import TrinityChatModel
504+
except ImportError:
505+
self.skipTest("agentscope >= 0.1.6 is not installed")
506+
507+
async def as_workflow_func(task, model) -> float:
508+
self.assertIsInstance(task, dict)
509+
self.assertIsInstance(model, TrinityChatModel)
510+
return task["reward"]
511+
512+
model = MagicMock()
513+
openai_client = MagicMock()
514+
openai_client.model_path = "Qwen/Qwen3-8B"
515+
model.get_openai_async_client.return_value = openai_client
516+
model.extract_experience_from_history.return_value = [
517+
Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, logprobs=Tensor([0.1, 0.2])),
518+
Experience(tokens=Tensor([3, 4, 5]), prompt_length=2, logprobs=Tensor([0.3])),
519+
]
520+
521+
as_adapter_cls = WORKFLOWS.get("agentscope_workflow_adapter")
522+
as_adapter = as_adapter_cls(
523+
task=Task(
524+
raw_task={"reward": 0.1},
525+
workflow_args={"workflow_func": as_workflow_func},
526+
),
527+
model=model,
528+
)
529+
result = await as_adapter.run_async()
530+
self.assertEqual(len(result), 2)
531+
self.assertEqual(result[0].reward, 0.1)
532+
self.assertEqual(result[0].prompt_length, 1)
533+
self.assertEqual(result[1].reward, 0.1)
534+
self.assertEqual(result[1].prompt_length, 2)

trinity/common/workflows/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from trinity.common.workflows.agentscope.react.react_workflow import (
44
AgentScopeReActWorkflow,
55
)
6+
from trinity.common.workflows.agentscope_workflow import AgentScopeWorkflowAdapter
67
from trinity.common.workflows.customized_math_workflows import (
78
AsyncMathBoxedWorkflow,
89
MathBoxedWorkflow,
@@ -92,4 +93,5 @@
9293
"AsyncSimpleMMWorkflow",
9394
"SimpleMMWorkflow",
9495
"RubricJudgeWorkflow",
96+
"AgentScopeWorkflowAdapter",
9597
]
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
from typing import Awaitable, Callable, Dict, List, Optional
2+
3+
import openai
4+
5+
from trinity.common.experience import Experience
6+
from trinity.common.models.model import ModelWrapper
7+
from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
8+
9+
10+
@WORKFLOWS.register_module("agentscope_workflow_adapter")
11+
class AgentScopeWorkflowAdapter(Workflow):
12+
"""Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow."""
13+
14+
def __init__(
15+
self,
16+
*,
17+
task: Task,
18+
model: ModelWrapper,
19+
auxiliary_models: Optional[List[openai.OpenAI]] = None,
20+
):
21+
"""Initialize the adapter with the task and model."""
22+
try:
23+
from agentscope.model import TrinityChatModel
24+
except ImportError:
25+
raise ImportError(
26+
"This workflow requires agentscope >= 0.1.6, please install "
27+
"it via `pip install agentscope>=0.1.6`",
28+
)
29+
30+
super().__init__(
31+
task=task,
32+
model=model,
33+
auxiliary_models=auxiliary_models,
34+
)
35+
self.workflow_func: Callable[
36+
[Dict, TrinityChatModel], Awaitable[float]
37+
] = task.workflow_args.get("workflow_func", None)
38+
39+
if self.workflow_func is None:
40+
raise ValueError(
41+
"The 'workflow_func' is not provided.",
42+
)
43+
44+
self.chat_model: TrinityChatModel = TrinityChatModel(
45+
model.get_openai_async_client(),
46+
)
47+
48+
@property
49+
def asynchronous(self) -> bool:
50+
"""This workflow runs asynchronously."""
51+
return True
52+
53+
@property
54+
def repeatable(self) -> bool:
55+
"""This workflow is not repeatable."""
56+
return False
57+
58+
@property
59+
def resetable(self) -> bool:
60+
"""This workflow cannot be reset."""
61+
return False
62+
63+
def construct_experiences(
64+
self,
65+
reward: float,
66+
) -> List[Experience]:
67+
"""Construct experiences from the agent's interaction history.
68+
69+
Args:
70+
reward (float): The reward value to assign to each experience.
71+
72+
Returns:
73+
List: A list of Experience objects.
74+
"""
75+
exps = self.model.extract_experience_from_history()
76+
for exp in exps:
77+
exp.reward = reward
78+
return exps
79+
80+
async def run_async(self) -> List[Experience]:
81+
"""Run the workflow asynchronously and return experiences."""
82+
reward = await self.workflow_func(self.task.raw_task, self.chat_model) # type: ignore [arg-type]
83+
return self.construct_experiences(reward)

0 commit comments

Comments
 (0)