2 changes: 2 additions & 0 deletions trinity/common/config.py

@@ -91,6 +91,8 @@ class ModelConfig:
     max_response_tokens: int = 2048
     # The checkpoint directory, contains a latest dir link and multiple checkpoint dirs.
     checkpoint_path: str = ""
+    # for models that support both thinking and non-thinking modes, e.g., Qwen3
+    enable_thinking: bool = False
 
 
 @dataclass
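
The new flag only matters for checkpoints whose chat template understands it (Qwen3 is the example the comment gives). A minimal sketch of setting the option when building the config directly in Python; the checkpoint path and the chosen value are illustrative, not part of this change:

```python
# Illustrative only: construct ModelConfig in code; in practice the field is
# usually populated from the experiment's config file.
from trinity.common.config import ModelConfig

model_cfg = ModelConfig(
    model_path="/path/to/Qwen3-8B",  # hypothetical checkpoint path
    enable_thinking=False,           # keep Qwen3 in non-thinking mode during rollout
)
```
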
2 changes: 2 additions & 0 deletions trinity/common/models/vllm_async_model.py

@@ -65,6 +65,7 @@ def __init__(
             output_kind=RequestOutputKind.FINAL_ONLY,
             logprobs=config.explorer.logprobs,
         )
+        self.enable_thinking = config.model.enable_thinking
         self.request_id = 0
         engine_args = vllm.AsyncEngineArgs(
             model=config.model.model_path,
@@ -137,6 +138,7 @@ async def chat_async(self, messages: List[Dict], **kwargs) -> List[Experience]:
             tokenize=False,
             add_generation_prompt=True,
             chat_template=self.chat_template,
+            enable_thinking=self.enable_thinking,
         )
         return await self.generate_async(prompt=prompt, **kwargs)
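
For context: `apply_chat_template` forwards extra keyword arguments to the Jinja chat template, and Qwen3's template uses `enable_thinking` to decide whether the assistant turn opens a `<think>` block. A minimal standalone sketch of the effect; the model id and messages are illustrative:

```python
# Standalone sketch (not part of the diff); assumes a Qwen3 checkpoint.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # illustrative model id
messages = [{"role": "user", "content": "What is 17 * 24?"}]

# Extra kwargs such as enable_thinking are passed through to the chat template.
prompt_thinking = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
)
prompt_plain = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
)
# With enable_thinking=False, Qwen3's template emits an empty <think></think>
# block in the generation prompt, so the model answers without a reasoning trace.
```

The same flag is threaded through the synchronous wrapper in vllm_model.py below, so both rollout paths honor a single setting.
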
2 changes: 2 additions & 0 deletions trinity/common/models/vllm_model.py

@@ -71,6 +71,7 @@ def __init__(self, config: Config, **kwargs):
         )
         self.tokenizer = self.llm.get_tokenizer()
         self.chat_template = self.tokenizer.get_chat_template()
+        self.enable_thinking = config.model.enable_thinking
         if self.config.explorer.chat_template:
             self.chat_template = self.config.explorer.chat_template
         if not re.search(r"\{\%-?\s*generation\s*-?\%\}", self.chat_template):
@@ -233,6 +234,7 @@ def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
             tokenize=False,
             add_generation_prompt=True,
             chat_template=self.chat_template,
+            enable_thinking=self.enable_thinking,
         )
         return self.generate([prompt], **kwargs)
8 changes: 5 additions & 3 deletions trinity/explorer/runner_pool.py

@@ -49,13 +49,15 @@ def __init__(self, config: Config, models: List):
         self._create_actors(config.explorer.runner_num)
 
     def _create_actors(self, num: int = 1):
+        new_actors = []
         for _ in range(num):
             engine_index = self.engine_status.index(min(self.engine_status))
             new_actor = WorkflowRunner.remote(self.config, self.models[engine_index])
             ray.get(new_actor.__ray_ready__.remote())
+            new_actors.append(new_actor)
             self.engine_status[engine_index] += 1
             self.actor_to_engine_index[new_actor] = engine_index
-            self._return_actor(new_actor)
+        for actor in new_actors:
+            self._return_actor(actor)
 
     def _kill_actors(self, actors):
         if not isinstance(actors, list):
@@ -234,7 +236,7 @@ def get_next(self) -> Status:
 
     def _return_actor(self, actor):
         try:
-            actor.is_alive.remote()
+            ray.get(actor.is_alive.remote())
             self._idle_actors.append(actor)
         except Exception:
             self.logger.info("The actor is not alive, restart a new actor")
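
Two behavioral fixes are bundled here. First, `_create_actors` now finishes creating and accounting for every runner (updating `engine_status` and `actor_to_engine_index`) before any of them is handed back to the idle pool. Second, `_return_actor` previously called `actor.is_alive.remote()` without awaiting the result; a bare `.remote()` call returns an ObjectRef immediately and never raises, so the `except` branch that restarts dead actors could not trigger. Wrapping the call in `ray.get` makes the liveness check actually observable. A minimal sketch of that difference, with an illustrative actor class:

```python
# Standalone sketch (not part of the diff) showing why the ray.get wrapper matters.
import time

import ray
from ray.exceptions import RayActorError

ray.init()

@ray.remote
class Runner:
    def is_alive(self) -> bool:
        return True

actor = Runner.remote()
ray.kill(actor)   # simulate a dead runner
time.sleep(1.0)   # give the kill a moment to take effect

ref = actor.is_alive.remote()  # returns an ObjectRef right away; no exception here
try:
    ray.get(ref)               # raises RayActorError because the actor is gone
except RayActorError:
    print("actor is not alive; a replacement should be started")
```
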