Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion trinity/common/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@
# tool_call
"tool_call_workflow": "trinity.common.workflows.customized_toolcall_workflows.ToolCallWorkflow",
# agentscope
"agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
"agentscope_workflow_adapter": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapter",
"agentscope_workflow_adapter_v1": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapterV1",
"agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
"agentscope_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow.AgentScopeReactMathWorkflow",
"as_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow",
"agentscopev0_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev0_react_workflow.AgentScopeV0ReactMathWorkflow",
Expand Down
107 changes: 105 additions & 2 deletions trinity/common/workflows/agentscope_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def __init__(
from agentscope.model import TrinityChatModel
except ImportError:
raise ImportError(
"This workflow requires agentscope >= 0.1.6, please install "
"it via `pip install agentscope>=0.1.6`",
"This workflow requires agentscope >= 1.0.7, please install "
"it via `pip install agentscope>=1.0.7`",
)

super().__init__(
Expand Down Expand Up @@ -72,3 +72,106 @@ async def run_async(self) -> List[Experience]:
"""Run the workflow asynchronously and return experiences."""
reward = await self.workflow_func(self.task.raw_task, self.chat_model) # type: ignore [arg-type]
return self.construct_experiences(reward)


class AgentScopeWorkflowAdapterV1(Workflow):
    """Adapter that wraps agentscope workflow and judge functions into a Trinity Workflow.

    The task's ``workflow_args`` must provide ``workflow_func`` and may provide
    ``judge_func``. Both are awaited in :meth:`run_async`; the reward comes from
    the judge output when a judge is configured, otherwise from the workflow
    output itself.
    """

    is_async: bool = True

    def __init__(
        self,
        *,
        task: Task,
        model: ModelWrapper,
        auxiliary_models: Optional[List[ModelWrapper]] = None,
    ):
        """Initialize the adapter with the task and model.

        Args:
            task (Task): The task; its ``workflow_args`` are read for
                ``workflow_func`` (required) and ``judge_func`` (optional).
            model (ModelWrapper): The model whose async OpenAI client backs
                the main chat model.
            auxiliary_models (Optional[List[ModelWrapper]]): Extra models
                forwarded to the base class.

        Raises:
            ImportError: If ``agentscope.model.TrinityChatModel`` cannot be imported.
            ValueError: If ``workflow_func`` is missing from ``task.workflow_args``.
        """
        try:
            from agentscope.model import TrinityChatModel
        except ImportError as err:
            # Chain the original failure so the real import problem stays visible.
            raise ImportError(
                "This workflow requires agentscope >= 1.0.11, please install "
                "it via `pip install agentscope>=1.0.11`",
            ) from err

        super().__init__(
            task=task,
            model=model,
            auxiliary_models=auxiliary_models,
        )
        self.workflow_func = task.workflow_args.get("workflow_func", None)
        self.judge_func = task.workflow_args.get("judge_func", None)

        if self.workflow_func is None:
            raise ValueError(
                "The 'workflow_func' is not provided.",
            )

        rollout_args = self.task.rollout_args
        self.chat_model: TrinityChatModel = TrinityChatModel(
            model.get_openai_async_client(),
            generate_kwargs={
                "temperature": rollout_args.temperature,
                "top_p": rollout_args.top_p,
                # Fall back to 4096 when max_tokens is unset (None or 0).
                "max_tokens": rollout_args.max_tokens or 4096,
                "logprobs": True,
                "top_logprobs": rollout_args.logprobs,
            },
        )
        # One chat model per auxiliary model exposed by the base class.
        self.auxiliary_chat_models = [
            TrinityChatModel(
                aux_model,
            )
            for aux_model in (self.auxiliary_models or [])
        ]

    def construct_experiences(
        self,
        reward: float,
        metrics: Dict,
    ) -> List[Experience]:
        """Construct experiences from the agent's interaction history.

        Args:
            reward (float): The reward value to assign to each experience.
            metrics (Dict): Metrics to attach to the last experience; ignored
                when empty.

        Returns:
            List: A list of Experience objects (possibly empty).
        """
        exps = self.model.extract_experience_from_history()
        for exp in exps:
            exp.reward = reward
        # Only attach metrics to the last experience. Guard against an empty
        # history so a run that produced no experiences does not raise IndexError.
        if metrics and exps:
            exps[-1].metrics = metrics
        return exps

    async def run_async(self) -> List[Experience]:
        """Run the workflow asynchronously and return experiences.

        Returns:
            List[Experience]: Experiences extracted from the model history,
            each carrying the computed reward.

        Raises:
            AssertionError: If a judge is configured but the workflow output has
                no response, or if no judge is configured and the workflow
                output has no reward.
        """
        try:
            # These names are only used in local annotations below.
            from agentscope.tuner import JudgeOutput, WorkflowOutput
        except ImportError:
            self.logger.error(
                "Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.11 is installed."
            )

        metrics = {}
        workflow_output: WorkflowOutput = await self.workflow_func(
            self.task.raw_task, self.chat_model, self.auxiliary_chat_models
        )  # type: ignore [arg-type]
        metrics.update(workflow_output.metrics or {})

        if self.judge_func is not None:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`, while keeping the original exception type and message.
            if workflow_output.response is None:
                raise AssertionError("Workflow must provide response for judging.")
            judge_output: JudgeOutput = await self.judge_func(
                self.task.raw_task, workflow_output.response, self.auxiliary_chat_models
            )  # type: ignore [arg-type]
            reward = judge_output.reward
            metrics.update(judge_output.metrics or {})
        else:
            if workflow_output.reward is None:
                raise AssertionError("Either workflow or judge must provide reward.")
            reward = workflow_output.reward
        return self.construct_experiences(reward, metrics)