From 325c700f7e68c7af3f7b0990a087a037cb9108da Mon Sep 17 00:00:00 2001 From: pxc Date: Sun, 4 Jan 2026 17:55:38 +0800 Subject: [PATCH 1/7] fix image link --- docs/sphinx_doc/source/tutorial/example_tinker_backend.md | 2 +- docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_tinker_backend.md b/docs/sphinx_doc/source/tutorial/example_tinker_backend.md index a1a3db5061..bcdb34e96e 100644 --- a/docs/sphinx_doc/source/tutorial/example_tinker_backend.md +++ b/docs/sphinx_doc/source/tutorial/example_tinker_backend.md @@ -211,4 +211,4 @@ synchronizer: Since Llama-3.2-3B is a base (non-instruct-tuned) model, it has limited ability to follow formatting instructions. Additionally, we trained for only **one epoch**. As a result, both backends achieved final rewards just slightly above 0.1. Nonetheless, the training curves show a clear upward trend in reward, indicating successful learning. The results are visualized below: -![Training Rewards on GSM8K](../../docs/sphinx_doc/assets/tinker-gsm8k.png) +![Training Rewards on GSM8K](../../assets/tinker-gsm8k.png) diff --git a/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md b/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md index a56d6eb671..3c360f6029 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md @@ -210,4 +210,4 @@ synchronizer: 由于 Llama-3.2-3B 是基础(非指令微调)模型,其格式化指令跟随能力有限,且本实验仅训练了**一个 epoch**。因此,两种后端的最终 reward 都略高于 0.1。但训练曲线显示 reward 呈明显上升趋势,表明模型已成功学习。结果可视化如下: -![GSM8K 训练奖励曲线](../../docs/sphinx_doc/assets/tinker-gsm8k.png) +![Training Rewards on GSM8K](../../assets/tinker-gsm8k.png) From 629011a18d1d83b6469af3944f3bf9d9a518b679 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 5 Jan 2026 13:50:59 +0800 Subject: [PATCH 2/7] enhance agentscope workflow function --- tests/common/vllm_test.py | 1 - trinity/common/models/model.py | 49 ++++++----- trinity/common/models/tinker_model.py | 8 +- trinity/common/models/vllm_model.py | 10 +-- .../common/workflows/agentscope_workflow.py | 84 +++++++++++++------ 5 files changed, 89 insertions(+), 63 deletions(-) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index b1deb90dff..adb523667e 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -135,7 +135,6 @@ async def test_generate(self): await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path) - self.assertEqual(await self.model_wrapper.model_path_async, self.config.model.model_path) prompts = ["Hello, world!", "Hello, my name is"] n = self.config.algorithm.repeat_times if self.use_async: diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 0bc8606736..5d59e926d0 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -13,6 +13,7 @@ from PIL import Image from torch import Tensor +from trinity.common.config import InferenceModelConfig from trinity.common.constants import RunningStatus from trinity.common.experience import Experience from trinity.utils.log import get_logger @@ -21,6 +22,10 @@ class InferenceModel(ABC): """A model for high performance for rollout inference.""" + def __init__(self, config: InferenceModelConfig) -> None: + self.config = config + self.logger = get_logger(__name__) + async def generate(self, prompt: str, 
**kwargs) -> Sequence[Experience]: """Generate a responses from a prompt in async.""" raise NotImplementedError @@ -66,13 +71,9 @@ def get_api_server_url(self) -> Optional[str]: """Get the API server URL if available.""" return None - def get_model_path(self) -> Optional[str]: - """Get the model path""" - return None - - def get_model_name(self) -> Optional[str]: - """Get the name of the model.""" - return None + def get_model_config(self) -> InferenceModelConfig: + """Get the model configuration.""" + return self.config def _history_recorder(func): @@ -117,7 +118,9 @@ def __init__( engine_type.startswith("vllm") or engine_type == "tinker" ), "Only vLLM and tinker model is supported for now." self.model = model - self._model_name = None + self.config: InferenceModelConfig = None # init during prepare + self._model_name: str = None + self._model_path: str = None self.api_address: str = None self.openai_client: openai.OpenAI = None self.openai_async_client: openai.AsyncOpenAI = None @@ -133,7 +136,14 @@ def __init__( async def prepare(self) -> None: """Prepare the model wrapper.""" - self._model_name = await self.model.get_model_name.remote() + self.config = await self.model.get_model_config.remote() + self._model_name = self.config.name + self._model_path = self.config.model_path + self._generate_kwargs = { + "temperature": self.config.temperature, + "top_p": self.config.top_p, + "max_tokens": self.config.max_response_tokens, + } self.api_address = await self.model.get_api_server_url.remote() if self.api_address is None: self.logger.info("API server is not enabled for inference model.") @@ -284,12 +294,7 @@ async def model_version_async(self) -> int: @property def model_path(self) -> str: """Get the model path.""" - return ray.get(self.model.get_model_path.remote()) - - @property - async def model_path_async(self) -> str: - """Get the model path.""" - return await self.model.get_model_path.remote() + return self._model_path @property def model_name(self) -> Optional[str]: @@ -297,9 +302,9 @@ def model_name(self) -> Optional[str]: return self._model_name @property - async def model_name_async(self) -> Optional[str]: - """Get the name of the model.""" - return self._model_name + def generate_kwargs(self) -> Dict[str, Any]: + """Get the generation kwargs for openai client.""" + return self._generate_kwargs def get_lora_request(self) -> Any: if self.enable_lora: @@ -316,7 +321,7 @@ async def get_lora_request_async(self) -> Any: async def get_message_token_len(self, messages: List[dict]) -> int: return await self.model.get_message_token_len.remote(messages) - def get_openai_client(self) -> openai.OpenAI: + def get_openai_client(self, enable_logprobs: bool = True) -> openai.OpenAI: """Get the openai client. Returns: @@ -338,7 +343,7 @@ def get_openai_client(self) -> openai.OpenAI: ori_create = self.openai_client.chat.completions.create def record_chat_completions(*args, **kwargs): - logprobs = kwargs.pop("logprobs", True) + logprobs = kwargs.pop("logprobs", enable_logprobs) extra_body = kwargs.pop("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: @@ -353,7 +358,7 @@ def record_chat_completions(*args, **kwargs): setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id) return self.openai_client - def get_openai_async_client(self) -> openai.AsyncOpenAI: + def get_openai_async_client(self, enable_logprobs: bool = True) -> openai.AsyncOpenAI: """Get the async openai client. 
Returns: @@ -376,7 +381,7 @@ def get_openai_async_client(self) -> openai.AsyncOpenAI: ori_create = self.openai_async_client.chat.completions.create async def record_chat_completions(*args, **kwargs): - logprobs = kwargs.pop("logprobs", True) + logprobs = kwargs.pop("logprobs", enable_logprobs) extra_body = kwargs.pop("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py index e8b6d492a1..92451526aa 100644 --- a/trinity/common/models/tinker_model.py +++ b/trinity/common/models/tinker_model.py @@ -11,7 +11,6 @@ from trinity.common.models.model import InferenceModel from trinity.common.models.utils import get_action_mask_method from trinity.manager.synchronizer import Synchronizer -from trinity.utils.log import get_logger class TinkerModel(InferenceModel): @@ -19,10 +18,9 @@ def __init__( self, config: InferenceModelConfig, ) -> None: - self.config = config + super().__init__(config) self.model_version = -1 self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace) - self.logger = get_logger(__name__) self.model = None self.tokenizer = None self.chat_template = None @@ -199,7 +197,3 @@ def get_api_server_url(self) -> Optional[str]: """Get the API server URL if available.""" # TODO: tinker will support openai api later return None - - def get_model_path(self) -> Optional[str]: - """Get the model path""" - return self.config.model_path # type: ignore [return-value] diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 5d839a51b2..cab0e0ada5 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -21,7 +21,6 @@ from trinity.common.models.model import InferenceModel from trinity.common.models.utils import get_action_mask_method from trinity.common.models.vllm_patch import get_vllm_version -from trinity.utils.log import get_logger # V0 engine is deprecated since vLLM v0.10.2, related code will be removed in the future. 
@@ -36,10 +35,11 @@ def __init__( self, config: InferenceModelConfig, ) -> None: + super().__init__(config) + import vllm from vllm.sampling_params import RequestOutputKind - self.logger = get_logger(__name__) self.vllm_version = get_vllm_version() self.config = config self.use_v1 = config.use_v1 @@ -715,12 +715,6 @@ async def reset_prefix_cache(self) -> None: def get_model_version(self) -> int: return self.model_version - def get_model_path(self) -> str: - return self.config.model_path # type: ignore [return-value] - - def get_model_name(self) -> Optional[str]: - return self.config.name # type: ignore [return-value] - def get_lora_request(self, lora_path: Optional[str] = None) -> Any: from vllm.lora.request import LoRARequest diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index c95beb56f4..b89d472f1f 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -1,12 +1,18 @@ +import inspect from typing import Awaitable, Callable, Dict, List, Optional from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.workflows.workflow import Task, Workflow +from trinity.utils.annotations import Deprecated +@Deprecated class AgentScopeWorkflowAdapter(Workflow): - """Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow.""" + """Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow. + Only for agentscope versions between 1.0.7 and 1.0.11. + For agentscope >= 1.0.11, please use AgentScopeWorkflowAdapterV1. + """ is_async: bool = True @@ -75,7 +81,9 @@ async def run_async(self) -> List[Experience]: class AgentScopeWorkflowAdapterV1(Workflow): - """A more general adapter to wrap agentscope trainable workflow and judge functions into a Trinity Workflow.""" + """A more general adapter to wrap agentscope trainable workflow and judge functions into a Trinity Workflow. 
+ For + """ is_async: bool = True @@ -88,11 +96,11 @@ def __init__( ): """Initialize the adapter with the task and model.""" try: - from agentscope.model import TrinityChatModel + from agentscope.model import OpenAIChatModel except ImportError: raise ImportError( - "This workflow requires agentscope >= 1.0.11, please install " - "it via `pip install agentscope>=1.0.11`", + "This workflow requires agentscope >= 1.0.12, please install " + "it via `pip install agentscope>=1.0.12`", ) super().__init__( @@ -102,14 +110,16 @@ def __init__( ) self.workflow_func = task.workflow_args.get("workflow_func", None) self.judge_func = task.workflow_args.get("judge_func", None) + self._openai_client = self.model.get_openai_async_client() if self.workflow_func is None: raise ValueError( "The 'workflow_func' is not provided.", ) - self.chat_model: TrinityChatModel = TrinityChatModel( - model.get_openai_async_client(), + self.chat_model: OpenAIChatModel = OpenAIChatModel( + api_key="EMPTY", + model_name=self._openai_client.model_path, generate_kwargs={ "temperature": self.task.rollout_args.temperature, "top_p": self.task.rollout_args.top_p, @@ -118,18 +128,21 @@ def __init__( "top_logprobs": self.task.rollout_args.logprobs, }, ) - - # TODO: customize generate_kwargs for auxiliary models if needed - if self.auxiliary_model_wrappers is not None and self.auxiliary_models is not None: - self.auxiliary_chat_models = { - aux_model_wrapper.model_name - or f"auxiliary_model_{i}": TrinityChatModel(openai_async_client=aux_model) - for i, (aux_model_wrapper, aux_model) in enumerate( - zip(self.auxiliary_model_wrappers, self.auxiliary_models) + self.chat_model.client = self._openai_client + self.auxiliary_chat_models: Dict[str, OpenAIChatModel] = {} + if self.auxiliary_model_wrappers is not None: + for aux_model_wrapper in self.auxiliary_model_wrappers: + aux_model_client = aux_model_wrapper.get_openai_async_client() + aux_chat_model = OpenAIChatModel( + api_key="EMPTY", + model_name=aux_model_client.model_path, + generate_kwargs=aux_model_wrapper.generate_kwargs, ) - } - else: - self.auxiliary_chat_models = {} + aux_chat_model.client = aux_model_client + assert ( + aux_model_wrapper.model_name is not None + ), "Auxiliary model must have a name. This should not happen." + self.auxiliary_chat_models[aux_model_wrapper.model_name] = aux_chat_model def construct_experiences( self, @@ -159,21 +172,42 @@ async def run_async(self) -> List[Experience]: from agentscope.tuner import JudgeOutput, WorkflowOutput except ImportError: raise ImportError( - "Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.11 is installed." + "Fail to import agentscope tuner related types. Please ensure agentscope>=1.0.12 is installed." ) metrics = {} - workflow_output: WorkflowOutput = await self.workflow_func( - self.task.raw_task, self.chat_model, self.auxiliary_chat_models - ) # type: ignore [arg-type] + workflow_sig = inspect.signature(self.workflow_func) + if "auxiliary_models" in workflow_sig.parameters: + workflow_output = await self.workflow_func( + self.task.raw_task, self.chat_model, self.auxiliary_chat_models + ) + else: + workflow_output = await self.workflow_func(self.task.raw_task, self.chat_model) + if not isinstance(workflow_output, WorkflowOutput): + raise ValueError( + "The 'workflow_func' must return a WorkflowOutput object.", + ) metrics.update(workflow_output.metrics or {}) if self.judge_func is not None: assert ( workflow_output.response is not None ), "Workflow must provide response for judging." 
- judge_output: JudgeOutput = await self.judge_func( - self.task.raw_task, workflow_output.response, self.auxiliary_chat_models - ) # type: ignore [arg-type] + judge_sig = inspect.signature(self.judge_func) + if "auxiliary_models" in judge_sig.parameters: + judge_output = await self.judge_func( + self.task.raw_task, + workflow_output.response, + self.auxiliary_chat_models, + ) + else: + judge_output = await self.judge_func( + self.task.raw_task, + workflow_output.response, + ) + if not isinstance(judge_output, JudgeOutput): + raise ValueError( + "The 'judge_func' must return a JudgeOutput object.", + ) reward = judge_output.reward metrics.update(judge_output.metrics or {}) else: From d96ff6b57c334cb24e20b1b3f5013adfacf4bb58 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 5 Jan 2026 14:31:23 +0800 Subject: [PATCH 3/7] add as adapter unittests --- tests/explorer/workflow_test.py | 53 ++++++++++++++++++- trinity/common/models/model.py | 12 +++-- .../common/workflows/agentscope_workflow.py | 7 +-- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 5899dfa54f..1f648eb2de 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -551,7 +551,7 @@ async def monitor_routine(): class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase): - async def test_adapter(self): + async def test_adapter_v0(self): try: from agentscope.model import TrinityChatModel except ImportError: @@ -586,6 +586,57 @@ async def as_workflow_func(task, model) -> float: self.assertEqual(result[1].reward, 0.1) self.assertEqual(result[1].prompt_length, 2) + async def test_adapter_v1(self): + try: + from agentscope.model import ChatModelBase + from agentscope.tuner import JudgeOutput, WorkflowOutput + except ImportError: + self.skipTest("agentscope >= 1.0.12 is not installed") + + async def as_workflow_func(task, model) -> WorkflowOutput: + self.assertIsInstance(task, dict) + self.assertIsInstance(model, ChatModelBase) + return WorkflowOutput( + reward=task["reward"], + metrics={"workflow_metric_1": 0.0}, + ) + + async def as_judge_func(task, workflow_output) -> JudgeOutput: + self.assertIsInstance(task, dict) + self.assertIsInstance(workflow_output, WorkflowOutput) + return JudgeOutput( + reward=workflow_output.reward, + metrics={"judge_metric_1": 1.0}, + ) + + model = MagicMock() + openai_client = MagicMock() + openai_client.model_path = "Qwen/Qwen3-8B" + model.get_openai_async_client.return_value = openai_client + model.extract_experience_from_history.return_value = [ + Experience(tokens=Tensor([0, 1, 2]), prompt_length=1, logprobs=Tensor([0.1, 0.2])), + ] + + as_adapter_cls = WORKFLOWS.get("agentscope_workflow_adapter_v1") + as_adapter = as_adapter_cls( + task=Task( + raw_task={"reward": 0.2}, + workflow_args={ + "workflow_func": as_workflow_func, + "judge_func": as_judge_func, + }, + ), + model=model, + ) + result = await as_adapter.run_async() + self.assertEqual(len(result), 1) + self.assertEqual(result[0].reward, 0.2) + self.assertEqual(result[0].prompt_length, 1) + metrics = result[-1].metrics + self.assertEqual(len(metrics), 2) + self.assertEqual(metrics["workflow_metric_1"], 0.0) + self.assertEqual(metrics["judge_metric_1"], 1.0) + class DummyModelWrapper: def __init__(self, model, engine_type="vllm", **kwargs): diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 5d59e926d0..2f4bff7e11 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ 
-144,6 +144,10 @@ async def prepare(self) -> None: "top_p": self.config.top_p, "max_tokens": self.config.max_response_tokens, } + if self.config.enable_thinking is not None: + self._generate_kwargs["chat_template_kwargs"] = { + "enable_thinking": self.config.enable_thinking + } self.api_address = await self.model.get_api_server_url.remote() if self.api_address is None: self.logger.info("API server is not enabled for inference model.") @@ -321,7 +325,7 @@ async def get_lora_request_async(self) -> Any: async def get_message_token_len(self, messages: List[dict]) -> int: return await self.model.get_message_token_len.remote(messages) - def get_openai_client(self, enable_logprobs: bool = True) -> openai.OpenAI: + def get_openai_client(self) -> openai.OpenAI: """Get the openai client. Returns: @@ -343,7 +347,7 @@ def get_openai_client(self, enable_logprobs: bool = True) -> openai.OpenAI: ori_create = self.openai_client.chat.completions.create def record_chat_completions(*args, **kwargs): - logprobs = kwargs.pop("logprobs", enable_logprobs) + logprobs = kwargs.pop("logprobs", True) extra_body = kwargs.pop("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: @@ -358,7 +362,7 @@ def record_chat_completions(*args, **kwargs): setattr(self.openai_client, "model_path", self.openai_client.models.list().data[0].id) return self.openai_client - def get_openai_async_client(self, enable_logprobs: bool = True) -> openai.AsyncOpenAI: + def get_openai_async_client(self) -> openai.AsyncOpenAI: """Get the async openai client. Returns: @@ -381,7 +385,7 @@ def get_openai_async_client(self, enable_logprobs: bool = True) -> openai.AsyncO ori_create = self.openai_async_client.chat.completions.create async def record_chat_completions(*args, **kwargs): - logprobs = kwargs.pop("logprobs", enable_logprobs) + logprobs = kwargs.pop("logprobs", True) extra_body = kwargs.pop("extra_body", {}) if self.enable_thinking is not None: if "chat_template_kwargs" not in extra_body: diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index b89d472f1f..7524e3c0a6 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -189,20 +189,17 @@ async def run_async(self) -> List[Experience]: ) metrics.update(workflow_output.metrics or {}) if self.judge_func is not None: - assert ( - workflow_output.response is not None - ), "Workflow must provide response for judging." 
judge_sig = inspect.signature(self.judge_func) if "auxiliary_models" in judge_sig.parameters: judge_output = await self.judge_func( self.task.raw_task, - workflow_output.response, + workflow_output, self.auxiliary_chat_models, ) else: judge_output = await self.judge_func( self.task.raw_task, - workflow_output.response, + workflow_output, ) if not isinstance(judge_output, JudgeOutput): raise ValueError( From 5978d37c00d44246775255462fd84951577b6945 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 5 Jan 2026 14:33:35 +0800 Subject: [PATCH 4/7] fix comments --- docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md | 2 +- trinity/common/workflows/agentscope_workflow.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md b/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md index 3c360f6029..abb015fa08 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_tinker_backend.md @@ -210,4 +210,4 @@ synchronizer: 由于 Llama-3.2-3B 是基础(非指令微调)模型,其格式化指令跟随能力有限,且本实验仅训练了**一个 epoch**。因此,两种后端的最终 reward 都略高于 0.1。但训练曲线显示 reward 呈明显上升趋势,表明模型已成功学习。结果可视化如下: -![Training Rewards on GSM8K](../../assets/tinker-gsm8k.png) +![GSM8K 训练奖励曲线](../../assets/tinker-gsm8k.png) diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index 7524e3c0a6..15c671a59f 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -11,7 +11,7 @@ class AgentScopeWorkflowAdapter(Workflow): """Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow. Only for agentscope versions between 1.0.7 and 1.0.11. - For agentscope >= 1.0.11, please use AgentScopeWorkflowAdapterV1. + For agentscope >= 1.0.12, please use AgentScopeWorkflowAdapterV1. """ is_async: bool = True @@ -82,7 +82,7 @@ async def run_async(self) -> List[Experience]: class AgentScopeWorkflowAdapterV1(Workflow): """A more general adapter to wrap agentscope trainable workflow and judge functions into a Trinity Workflow. - For + Only for agentscope versions >= 1.0.12. """ is_async: bool = True From 85344e5202c565ee67793d9f3d723647a0316020 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 5 Jan 2026 17:53:41 +0800 Subject: [PATCH 5/7] add tests for v1 adapter --- pyproject.toml | 2 +- scripts/docker/Dockerfile.uv | 3 +- tests/explorer/workflow_test.py | 7 +- tests/trainer/trainer_test.py | 147 ++++++++++++++++++ trinity/common/config.py | 7 + trinity/common/models/__init__.py | 6 +- trinity/common/models/model.py | 4 +- trinity/common/models/vllm_model.py | 5 +- .../common/workflows/agentscope_workflow.py | 21 ++- trinity/explorer/explorer.py | 4 +- trinity/explorer/scheduler.py | 2 +- trinity/trainer/trainer.py | 2 +- 12 files changed, 187 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 23a872d22b..365053f2ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,7 +82,7 @@ megatron = [ "mbridge>=0.13.0", ] tinker = [ - "tinker", # tinker requires python>=3.11 + "tinker; python_version >= '3.11'", ] doc = [ diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv index 4d08ecd60d..3aafc8c1a0 100644 --- a/scripts/docker/Dockerfile.uv +++ b/scripts/docker/Dockerfile.uv @@ -7,7 +7,8 @@ # # Note: # 1. This Dockerfile uses 'uv' to create a virtual environment for better package management. 
If you want a simpler setup without 'uv', please refer to `scripts/docker/Dockerfile`. -# 2. Make sure to use `uv pip` to install packages within the virtual environment. +# 2. The uv virtual environment is created at `/opt/venv`, use `source /opt/venv/bin/activate` to activate it. +# 3. Make sure to use `uv pip` to install packages within the virtual environment. FROM nvcr.io/nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 1f648eb2de..9156ac0a39 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -598,14 +598,15 @@ async def as_workflow_func(task, model) -> WorkflowOutput: self.assertIsInstance(model, ChatModelBase) return WorkflowOutput( reward=task["reward"], + response=task["reward"], metrics={"workflow_metric_1": 0.0}, ) - async def as_judge_func(task, workflow_output) -> JudgeOutput: + async def as_judge_func(task, response) -> JudgeOutput: self.assertIsInstance(task, dict) - self.assertIsInstance(workflow_output, WorkflowOutput) + self.assertIsInstance(response, float) return JudgeOutput( - reward=workflow_output.reward, + reward=response, metrics={"judge_metric_1": 1.0}, ) diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 2a54b9a494..17556d1fcb 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -9,6 +9,7 @@ import unittest from copy import deepcopy from datetime import datetime +from typing import Dict from unittest import mock import ray @@ -1433,3 +1434,149 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) + + +@unittest.skip("Require agentscope >= 1.0.12") +class AgentScopeTunerTest(unittest.IsolatedAsyncioTestCase): + def setUp(self) -> None: + ray.init(ignore_reinit_error=True) + + def tearDown(self) -> None: + ray.shutdown(_exiting_interpreter=True) + + def test_agentscope_tuner(self): + try: + from agentscope.agent import ReActAgent + from agentscope.formatter import OpenAIChatFormatter + from agentscope.message import Msg + from agentscope.model import ChatModelBase + from agentscope.tuner import ( + Algorithm, + Dataset, + JudgeOutput, + TunerChatModel, + WorkflowOutput, + tune, + ) + except ImportError: + self.skipTest("agentscope >= 1.0.12 is not installed") + + async def workflow_func( + task: Dict, + model: ChatModelBase, + auxiliary_models: Dict[str, ChatModelBase], + ) -> WorkflowOutput: + assert isinstance(model, ChatModelBase) + assert "judge_model" in auxiliary_models + assert isinstance(auxiliary_models["judge_model"], ChatModelBase) + agent = ReActAgent( + name="test_agent", + model=model, + sys_prompt="You are a helpful assistant.", + formatter=OpenAIChatFormatter(), + ) + st = time.time() + response = await agent.reply(Msg("user", task["question"], role="user")) + et = time.time() + return WorkflowOutput(response=response, metrics={"workflow_time": et - st}) + + async def judge_func( + task: Dict, response: Msg, auxiliary_models: Dict[str, ChatModelBase] + ) -> JudgeOutput: + assert "judge_model" in auxiliary_models + judge_model = auxiliary_models["judge_model"] + assert isinstance(judge_model, ChatModelBase) + agent = ReActAgent( + name="judge_agent", + model=judge_model, + sys_prompt="You are a judge to evaluate the correctness of answers.", + formatter=OpenAIChatFormatter(), + ) + workflow_text_response = response.get_text_content() + st = time.time() + judge_response = await 
agent.reply( + Msg( + "user", + f"Question: {task['question']}\nAnswer: {workflow_text_response}\nIs the answer correct? Reply with 'Yes' or 'No'.", + role="user", + ) + ) + et = time.time() + judge_response = judge_response.get_text_content() + if judge_response is not None and "yes" in judge_response.lower(): + is_correct = True + else: + is_correct = False + return JudgeOutput( + reward=float(is_correct), + metrics={"judge_time": et - st}, + ) + + gsm8k_dataset = get_unittest_dataset_config("gsm8k") + + dataset = Dataset( + path=gsm8k_dataset.path, + split="train", + total_steps=2, + ) + eval_dataset = Dataset( + path=gsm8k_dataset.path, + split="test", + ) + + model = TunerChatModel( + model_path=get_model_path(), + max_model_len=4096, + max_tokens=2048, + inference_engine_num=2, + ) + + auxiliary_models = { + "judge_model": TunerChatModel( + model_path=get_model_path(), + max_model_len=8192, + max_tokens=2048, + inference_engine_num=2, + ) + } + + algorithm = Algorithm( + algorithm_type="multi_step_grpo", + batch_size=4, + group_size=4, + eval_interval_steps=2, + save_interval_steps=2, + ) + + tune( + workflow_func=workflow_func, + judge_func=judge_func, + train_dataset=dataset, + eval_dataset=eval_dataset, + model=model, + auxiliary_models=auxiliary_models, + algorithm=algorithm, + ) + # check checkpoint dir in `./checkpoints/AgentScope/Experiment-` + self.assertTrue(os.path.exists("./checkpoints/AgentScope")) + exp_dirs = os.listdir("./checkpoints/AgentScope") + self.assertGreaterEqual(len(exp_dirs), 1) + latest_exp_dir = sorted(exp_dirs)[-1] + exp_dir_path = os.path.join("./checkpoints/AgentScope", latest_exp_dir) + _, step_num = get_checkpoint_dir_with_step_num( + checkpoint_root_path=exp_dir_path, + trainer_type="verl", + ) + self.assertEqual(step_num, 2) + # check tensorboard + parser = TensorBoardParser(os.path.join(exp_dir_path, "monitor", "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertIn("rollout/workflow_time/mean", rollout_metrics) + self.assertIn("rollout/judge_time/mean", rollout_metrics) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) + eval_metrics = parser.metric_list("eval") + self.assertGreater(len(eval_metrics), 0) + self.assertEqual(parser.metric_max_step(eval_metrics[0]), 2) + actor_metrics = parser.metric_list("actor") + self.assertGreater(len(actor_metrics), 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) diff --git a/trinity/common/config.py b/trinity/common/config.py index a9264cc6b3..df5286cb41 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -261,6 +261,8 @@ class TasksetConfig: total_epochs: int = 1 # automatically set # ! DO NOT SET, automatically set from buffer.total_steps total_steps: Optional[int] = None # automatically set + # ! DO NOT SET, automatically set form ray_namespace + ray_namespace: Optional[str] = None def to_storage_config(self) -> StorageConfig: storage_config = StorageConfig( @@ -285,6 +287,7 @@ def to_storage_config(self) -> StorageConfig: batch_size=self.batch_size, total_epochs=self.total_epochs, total_steps=self.total_steps, + ray_namespace=self.ray_namespace, ) return storage_config @@ -324,6 +327,8 @@ class ExperienceBufferConfig: total_epochs: int = 1 # automatically set # ! DO NOT SET, automatically set from buffer.total_steps total_steps: Optional[int] = None # automatically set + # ! 
DO NOT SET, automatically set form ray_namespace + ray_namespace: Optional[str] = None def to_storage_config(self) -> StorageConfig: storage_config = StorageConfig( @@ -345,6 +350,7 @@ def to_storage_config(self) -> StorageConfig: tokenizer_path=self.tokenizer_path, total_epochs=self.total_epochs, total_steps=self.total_steps, + ray_namespace=self.ray_namespace, ) return storage_config @@ -546,6 +552,7 @@ class InferenceModelConfig: # ! DO NOT SET bundle_indices: str = "" + ray_namespace: Optional[str] = None # ! DO NOT SET, automatically set from model.lora_configs enable_lora: bool = False diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 46958faa6c..42674ea147 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -58,7 +58,7 @@ def create_inference_models( from trinity.common.models.tinker_model import TinkerModel engine_cls = TinkerModel - namespace = ray.get_runtime_context().namespace + namespace = config.ray_namespace rollout_engines = [ ray.remote(engine_cls) .options( @@ -111,7 +111,8 @@ def create_inference_models( for bundle_id, node_id in bundle_node_map.items(): node_bundle_map[node_id].append(bundle_id) allocator = _BundleAllocator(node_bundle_map) - namespace = ray.get_runtime_context().namespace + namespace = config.ray_namespace + config.explorer.rollout_model.ray_namespace = namespace # create rollout models # in 'serve' mode, we always enable openai api for rollout model if config.mode == "serve": @@ -147,6 +148,7 @@ def create_inference_models( # create auxiliary models for i, model_config in enumerate(config.explorer.auxiliary_models): engines = [] + model_config.ray_namespace = namespace for j in range(model_config.engine_num): bundles_for_engine = allocator.allocate(model_config.tensor_parallel_size) model_config.enable_openai_api = True diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 2f4bff7e11..ac28276876 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -145,8 +145,8 @@ async def prepare(self) -> None: "max_tokens": self.config.max_response_tokens, } if self.config.enable_thinking is not None: - self._generate_kwargs["chat_template_kwargs"] = { - "enable_thinking": self.config.enable_thinking + self._generate_kwargs["extra_body"] = { + "chat_template_kwargs": {"enable_thinking": self.config.enable_thinking} } self.api_address = await self.model.get_api_server_url.remote() if self.api_address is None: diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index cab0e0ada5..f3921e91f5 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -6,7 +6,6 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple import numpy as np -import ray import torch from packaging.version import parse as parse_version from PIL import Image @@ -41,7 +40,6 @@ def __init__( from vllm.sampling_params import RequestOutputKind self.vllm_version = get_vllm_version() - self.config = config self.use_v1 = config.use_v1 if config.tensor_parallel_size != 1: os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -80,6 +78,7 @@ def __init__( ignore_eos=config.ignore_eos, ) self.enable_thinking = config.enable_thinking + self.ray_namespace = config.ray_namespace self.request_id = 0 max_model_len = config.max_model_len self.enable_lora = config.enable_lora @@ -638,7 +637,7 @@ async def init_process_group( timeout, state_dict_meta, explorer_name, - 
ray.get_runtime_context().namespace, + self.ray_namespace, ), ) diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index 15c671a59f..feccfc5932 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -120,6 +120,7 @@ def __init__( self.chat_model: OpenAIChatModel = OpenAIChatModel( api_key="EMPTY", model_name=self._openai_client.model_path, + stream=False, generate_kwargs={ "temperature": self.task.rollout_args.temperature, "top_p": self.task.rollout_args.top_p, @@ -137,6 +138,7 @@ def __init__( api_key="EMPTY", model_name=aux_model_client.model_path, generate_kwargs=aux_model_wrapper.generate_kwargs, + stream=False, ) aux_chat_model.client = aux_model_client assert ( @@ -179,10 +181,15 @@ async def run_async(self) -> List[Experience]: workflow_sig = inspect.signature(self.workflow_func) if "auxiliary_models" in workflow_sig.parameters: workflow_output = await self.workflow_func( - self.task.raw_task, self.chat_model, self.auxiliary_chat_models + task=self.task.raw_task, + model=self.chat_model, + auxiliary_models=self.auxiliary_chat_models, ) else: - workflow_output = await self.workflow_func(self.task.raw_task, self.chat_model) + workflow_output = await self.workflow_func( + task=self.task.raw_task, + model=self.chat_model, + ) if not isinstance(workflow_output, WorkflowOutput): raise ValueError( "The 'workflow_func' must return a WorkflowOutput object.", @@ -192,14 +199,14 @@ async def run_async(self) -> List[Experience]: judge_sig = inspect.signature(self.judge_func) if "auxiliary_models" in judge_sig.parameters: judge_output = await self.judge_func( - self.task.raw_task, - workflow_output, - self.auxiliary_chat_models, + task=self.task.raw_task, + response=workflow_output.response, + auxiliary_models=self.auxiliary_chat_models, ) else: judge_output = await self.judge_func( - self.task.raw_task, - workflow_output, + task=self.task.raw_task, + response=workflow_output.response, ) if not isinstance(judge_output, JudgeOutput): raise ValueError( diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index e369c7c17f..b0893b8c52 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -472,7 +472,7 @@ def _init_experience_pipeline(self) -> ray.actor.ActorHandle: ray.remote(ExperiencePipeline) .options( name=f"{self.config.explorer.name}_pipeline", - namespace=ray.get_runtime_context().namespace, + namespace=self.config.ray_namespace, scheduling_strategy=NodeAffinitySchedulingStrategy( node_id=node_id, soft=False, @@ -531,7 +531,7 @@ def get_actor(cls, config: Config): ray.remote(cls) .options( name=config.explorer.name, - namespace=ray.get_runtime_context().namespace, + namespace=config.ray_namespace, ) .remote(config) ) diff --git a/trinity/explorer/scheduler.py b/trinity/explorer/scheduler.py index 16812e2b22..f84bb6ea26 100644 --- a/trinity/explorer/scheduler.py +++ b/trinity/explorer/scheduler.py @@ -88,7 +88,7 @@ def __init__( self.config = config self.retry_times = config.explorer.max_retry_times self.timeout = config.explorer.max_timeout - self.namespace = ray.get_runtime_context().namespace + self.namespace = config.ray_namespace self.runner = self._create_runner() self.state = {} diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index c4901f3a2f..4cb9c52cc4 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -221,7 +221,7 @@ def get_actor(cls, config: Config): """Get a Ray actor 
for the trainer.""" return ( ray.remote(cls) - .options(name=config.trainer.name, namespace=ray.get_runtime_context().namespace) + .options(name=config.trainer.name, namespace=config.ray_namespace) .remote(config) ) From 8cee9599eada7c45e1da602ea19bbaffb5b2c584 Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 5 Jan 2026 19:21:55 +0800 Subject: [PATCH 6/7] fix tests --- tests/explorer/scheduler_test.py | 5 +++++ tests/explorer/workflow_test.py | 5 +++++ tests/trainer/trainer_test.py | 6 +++--- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 7aec80ed23..2cd5e9a08d 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -171,6 +171,11 @@ async def run_async(self) -> List[Experience]: @ray.remote class DummyModel(InferenceModel): + def __init__(self): + from trinity.common.config import InferenceModelConfig + + super().__init__(InferenceModelConfig(model_path="dummy_model")) + def sync_model(self, model_version, update_weight_args_list): return True diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 9156ac0a39..ae5cb5a343 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -15,6 +15,7 @@ from tests.common.vllm_test import CHAT_TEMPLATE from tests.tools import get_model_path, get_template_config, get_unittest_dataset_config +from trinity.common.config import InferenceModelConfig from trinity.common.experience import EID, Experience from trinity.common.models import create_inference_models from trinity.common.models.model import ModelWrapper @@ -730,9 +731,13 @@ async def mock_get_api_server_url_remote(): async def mock_get_model_version_remote(): return 1 + async def mock_get_model_config_remote(): + return InferenceModelConfig(model_path="dummy_model") + model = MagicMock() model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote) model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote) + model.get_model_config.remote = MagicMock(side_effect=mock_get_model_config_remote) runner = WorkflowRunner( config, diff --git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 17556d1fcb..9626b8c4b6 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -1135,10 +1135,10 @@ async def test_serve_with_trainer(self): # noqa: C901 + metrics["rollout/model_1/total_request_count"], metrics["rollout/total_experience_count"], ) - # at least updated to version 2 + # at least updated to version 1 await asyncio.sleep(5) # wait for model version update - self.assertGreaterEqual(metrics["rollout/model_0/model_version"], 2) - self.assertGreaterEqual(metrics["rollout/model_1/model_version"], 2) + self.assertGreaterEqual(metrics["rollout/model_0/model_version"], 1) + self.assertGreaterEqual(metrics["rollout/model_1/model_version"], 1) # check final checkpoint _, step_num = get_checkpoint_dir_with_step_num( checkpoint_root_path=serve_config.checkpoint_job_dir, From 72ef0aa62ceff3d2cf6c23d131ab20f1bd600cfc Mon Sep 17 00:00:00 2001 From: pxc Date: Mon, 5 Jan 2026 20:38:48 +0800 Subject: [PATCH 7/7] fix pre-commit --- tests/buffer/file_test.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/buffer/file_test.py b/tests/buffer/file_test.py index b73ee27579..ea9cbe6abc 100644 --- a/tests/buffer/file_test.py +++ b/tests/buffer/file_test.py @@ -109,7 +109,11 @@ def setUp(self): name="test_buffer", 
storage_type=StorageType.FILE.value ) self.config.check_and_update() - ray.init(ignore_reinit_error=True, runtime_env={"env_vars": self.config.get_envs()}) + ray.init( + ignore_reinit_error=True, + runtime_env={"env_vars": self.config.get_envs()}, + namespace="trinity_unittest", + ) os.makedirs(self.config.buffer.cache_dir, exist_ok=True) file_path = self.config.buffer.trainer_input.experience_buffer.path if os.path.exists(file_path):
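
Usage summary for the new AgentScopeWorkflowAdapterV1 introduced in this series (requires agentscope >= 1.0.12): the sketch below shows the workflow/judge function pair the adapter accepts, distilled from the unit tests added in PATCH 3/7 and PATCH 5/7. The "question"/"answer" task fields, the metric names, and the echo/substring logic are illustrative placeholders, not part of the patch.

from typing import Dict

from agentscope.model import ChatModelBase
from agentscope.tuner import JudgeOutput, WorkflowOutput


async def workflow_func(task: Dict, model: ChatModelBase) -> WorkflowOutput:
    # The adapter inspects this signature and passes arguments by keyword
    # (task=..., model=...); declare an additional `auxiliary_models` parameter
    # only if the workflow needs the auxiliary chat models.
    # Calls made through `model` go through the wrapped OpenAI-compatible
    # client, so experiences are later extracted from its recorded history.
    answer = f"echo: {task['question']}"  # placeholder for real agent logic
    return WorkflowOutput(response=answer, metrics={"num_turns": 1.0})


async def judge_func(task: Dict, response) -> JudgeOutput:
    # `response` receives WorkflowOutput.response, passed by keyword as of
    # PATCH 5/7 (not the whole WorkflowOutput object); `auxiliary_models` is
    # likewise optional here and supplied only when declared.
    correct = str(task.get("answer", "")) in str(response)
    return JudgeOutput(reward=float(correct), metrics={"exact_match": float(correct)})

The pair is wired in either through Task.workflow_args (see test_adapter_v1) or through the higher-level agentscope.tuner.tune(...) entry point exercised in AgentScopeTunerTest.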