Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions trinity/common/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
else:
return ray.get(self.model.convert_messages_to_experience.remote(messages))

def tokenize_text(self, text: str) -> Tensor:
    """Tokenize *text* on the remote model actor and return a token-id tensor.

    Args:
        text (str): Input text to tokenize.

    Returns:
        Tensor: Token ids produced by the remote model's tokenizer.
    """
    # Pick the remote endpoint matching the engine mode; both are blocking
    # from the caller's perspective because we ray.get() the result either way.
    remote_method = (
        self.model.tokenize_text_async if self.use_async else self.model.tokenize_text
    )
    return ray.get(remote_method.remote(text))

def get_ckp_version(self) -> int:
return ray.get(self.model.get_ckp_version.remote())

Expand Down
18 changes: 18 additions & 0 deletions trinity/common/models/vllm_async_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,24 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
action_mask=action_mask,
)

async def tokenize_text_async(self, text: str) -> torch.Tensor:
    """Tokenize *text* into a tensor of token ids.

    Args:
        text (str): Input text to be tokenized.

    Returns:
        torch.Tensor: 1-D int32 tensor of token ids.
    """
    # Lazily fetch and cache the tokenizer from the async engine on first use.
    if self.tokenizer is None:
        self.tokenizer = await self.async_llm.get_tokenizer()
    ids = self.tokenizer.encode(text)
    return torch.tensor(ids, dtype=torch.int32)

def shutdown(self):
"""Shutdown the vLLM v1 engine. This kills child processes forked
by the vLLM engine. If not called, the child processes will be
Expand Down
13 changes: 13 additions & 0 deletions trinity/common/models/vllm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,19 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
action_mask=action_mask,
)

def tokenize_text(self, text: str) -> torch.Tensor:
    """Tokenize *text* and return the token ids as a tensor.

    Args:
        text (str): Input text to tokenize.

    Returns:
        torch.Tensor: 1-D int32 tensor of token ids.
    """
    # Encode with the model's tokenizer and wrap directly in a tensor.
    return torch.tensor(self.tokenizer.encode(text), dtype=torch.int32)

def has_api_server(self) -> bool:
return False

Expand Down