diff --git a/trinity/common/config.py b/trinity/common/config.py
index a7a7b67d78..0c44f7c2e6 100644
--- a/trinity/common/config.py
+++ b/trinity/common/config.py
@@ -91,6 +91,8 @@ class ModelConfig:
     max_response_tokens: int = 2048
     # The checkpoint directory, contains a latest dir link and multiple checkpoint dirs.
     checkpoint_path: str = ""
+    # for models support both thinking and non-thinking mode, e.g., Qwen3
+    enable_thinking: bool = False
 
 
 @dataclass
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index 50a774d8d4..2f9249e0e9 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -65,6 +65,7 @@ def __init__(
             output_kind=RequestOutputKind.FINAL_ONLY,
             logprobs=config.explorer.logprobs,
         )
+        self.enable_thinking = config.model.enable_thinking
         self.request_id = 0
         engine_args = vllm.AsyncEngineArgs(
             model=config.model.model_path,
@@ -137,6 +138,7 @@ async def chat_async(self, messages: List[Dict], **kwargs) -> List[Experience]:
             tokenize=False,
             add_generation_prompt=True,
             chat_template=self.chat_template,
+            enable_thinking=self.enable_thinking,
         )
         return await self.generate_async(prompt=prompt, **kwargs)
 
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 0964484e00..c4baf567a6 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -71,6 +71,7 @@ def __init__(self, config: Config, **kwargs):
         )
         self.tokenizer = self.llm.get_tokenizer()
         self.chat_template = self.tokenizer.get_chat_template()
+        self.enable_thinking = config.model.enable_thinking
         if self.config.explorer.chat_template:
             self.chat_template = self.config.explorer.chat_template
         if not re.search(r"\{\%-?\s*generation\s*-?\%\}", self.chat_template):
@@ -233,6 +234,7 @@ def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
             tokenize=False,
             add_generation_prompt=True,
             chat_template=self.chat_template,
+            enable_thinking=self.enable_thinking,
         )
         return self.generate([prompt], **kwargs)
 
diff --git a/trinity/explorer/runner_pool.py b/trinity/explorer/runner_pool.py
index e7641387c8..5805f71907 100644
--- a/trinity/explorer/runner_pool.py
+++ b/trinity/explorer/runner_pool.py
@@ -49,13 +49,15 @@ def __init__(self, config: Config, models: List):
         self._create_actors(config.explorer.runner_num)
 
     def _create_actors(self, num: int = 1):
+        new_actors = []
         for _ in range(num):
             engine_index = self.engine_status.index(min(self.engine_status))
             new_actor = WorkflowRunner.remote(self.config, self.models[engine_index])
-            ray.get(new_actor.__ray_ready__.remote())
+            new_actors.append(new_actor)
             self.engine_status[engine_index] += 1
             self.actor_to_engine_index[new_actor] = engine_index
-            self._return_actor(new_actor)
+        for actor in new_actors:
+            self._return_actor(actor)
 
     def _kill_actors(self, actors):
         if not isinstance(actors, list):
@@ -234,7 +236,7 @@ def get_next(self) -> Status:
 
     def _return_actor(self, actor):
         try:
-            actor.is_alive.remote()
+            ray.get(actor.is_alive.remote())
             self._idle_actors.append(actor)
         except Exception:
             self.logger.info("The actor is not alive, restart a new actor")