diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py
index baf5ffdd8c..e63c83f4ef 100644
--- a/trinity/common/models/model.py
+++ b/trinity/common/models/model.py
@@ -99,6 +99,13 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
         else:
             return ray.get(self.model.convert_messages_to_experience.remote(messages))
 
+    def tokenize_text(self, text: str) -> Tensor:
+        """Convert text to token ids tensor."""
+        if self.use_async:
+            return ray.get(self.model.tokenize_text_async.remote(text))
+        else:
+            return ray.get(self.model.tokenize_text.remote(text))
+
     def get_ckp_version(self) -> int:
         return ray.get(self.model.get_ckp_version.remote())
 
diff --git a/trinity/common/models/vllm_async_model.py b/trinity/common/models/vllm_async_model.py
index ae5c4db9c1..37d956e72e 100644
--- a/trinity/common/models/vllm_async_model.py
+++ b/trinity/common/models/vllm_async_model.py
@@ -227,6 +227,24 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Ex
             action_mask=action_mask,
         )
 
+    async def tokenize_text_async(self, text: str) -> torch.Tensor:
+        """Convert text to token ids tensor.
+
+        Args:
+            text (str): Input text to be tokenized
+
+        Returns:
+            torch.Tensor: Token ids tensor
+        """
+        if self.tokenizer is None:
+            self.tokenizer = await self.async_llm.get_tokenizer()
+
+        # Tokenize the text
+        token_ids = self.tokenizer.encode(text)
+
+        # Convert to tensor
+        return torch.tensor(token_ids, dtype=torch.int32)
+
     def shutdown(self):
         """Shutdown the vLLM v1 engine. This kills child processes forked by the vLLM engine.
         If not called, the child processes will be
diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py
index 32ab98fe8a..cc46852c75 100644
--- a/trinity/common/models/vllm_model.py
+++ b/trinity/common/models/vllm_model.py
@@ -268,6 +268,19 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
             action_mask=action_mask,
         )
 
+    def tokenize_text(self, text: str) -> torch.Tensor:
+        """
+        Convert text to token ids tensor.
+
+        Args:
+            text (str): Input text to be tokenized
+
+        Returns:
+            torch.Tensor: token ids tensor
+        """
+        token_ids = self.tokenizer.encode(text)
+        return torch.tensor(token_ids, dtype=torch.int32)
+
     def has_api_server(self) -> bool:
         return False
 