Commit 6e8ab38

add use base format
1 parent 3830dd3 commit 6e8ab38

5 files changed: +56 -2 lines changed

trinity/common/config.py

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,8 @@ class FormatConfig:
     # for unpaired preference dataset
     label_key: str = ""
 
+    use_base_format: bool = False
+
 
 @dataclass
 class GenerationConfig:
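For orientation, here is a minimal usage sketch of the new flag, assuming FormatConfig can be imported from trinity.common.config as the file path suggests (the label_key value below is made up):

# Hypothetical sketch: only the two fields shown in the diff are used;
# the import path mirrors the file path of this commit.
from trinity.common.config import FormatConfig

fmt = FormatConfig(label_key="preference", use_base_format=True)
assert fmt.use_base_format is True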

trinity/common/models/model.py

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,12 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
         else:
             return ray.get(self.model.convert_messages_to_experience.remote(messages))
 
+    def tokenize_text(self, text: str) -> Tensor:
+        if self.use_async:
+            return ray.get(self.model.tokenize_text_async.remote(text))
+        else:
+            return ray.get(self.model.tokenize_text.remote(text))
+
     def get_ckp_version(self) -> int:
         return ray.get(self.model.get_ckp_version.remote())
 
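The new wrapper method mirrors the dispatch pattern of convert_messages_to_experience above: one synchronous entry point that routes to either the async or the sync actor method and blocks on ray.get. A self-contained sketch of the same routing, with a plain Python object standing in for the Ray actor (SyncEngine and its toy encoding are illustrative, not the repo's API):

import torch

class SyncEngine:
    # Stand-in for the model actor; Ray's .remote()/ray.get() are elided.
    def tokenize_text(self, text: str) -> torch.Tensor:
        # Toy byte-level encoding in place of a real tokenizer.
        return torch.tensor(list(text.encode("utf-8")), dtype=torch.int32)

class ModelWrapper:
    def __init__(self, model, use_async: bool = False):
        self.model = model
        self.use_async = use_async

    def tokenize_text(self, text: str) -> torch.Tensor:
        # The real code wraps both branches in ray.get(... .remote(text)).
        if self.use_async:
            return self.model.tokenize_text_async(text)
        return self.model.tokenize_text(text)

print(ModelWrapper(SyncEngine()).tokenize_text("hi"))  # tensor([104, 105], dtype=torch.int32)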

trinity/common/models/vllm_async_model.py

Lines changed: 18 additions & 0 deletions
@@ -227,6 +227,24 @@ async def convert_messages_to_experience_async(self, messages: List[dict]) -> Experience:
             action_mask=action_mask,
         )
 
+    async def tokenize_text_async(self, text: str) -> torch.Tensor:
+        """Convert text to token ids tensor.
+
+        Args:
+            text (str): Input text to be tokenized
+
+        Returns:
+            torch.Tensor: Token ids tensor
+        """
+        if self.tokenizer is None:
+            self.tokenizer = await self.async_llm.get_tokenizer()
+
+        # Tokenize the text
+        token_ids = self.tokenizer.encode(text)
+
+        # Convert to tensor
+        return torch.tensor(token_ids, dtype=torch.int32)
+
     def shutdown(self):
         """Shutdown the vLLM v1 engine. This kills child processes forked
         by the vLLM engine. If not called, the child processes will be
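One detail worth noting: the tokenizer is fetched lazily from the async engine on first use, then cached on self. A small runnable sketch of that lazy-init flow, with fake engine and tokenizer classes standing in for vLLM's (everything below is illustrative):

import asyncio
import torch

class FakeTokenizer:
    def encode(self, text: str) -> list:
        # Toy byte-level ids; a real tokenizer returns vocabulary ids.
        return list(text.encode("utf-8"))

class FakeAsyncEngine:
    async def get_tokenizer(self) -> FakeTokenizer:
        return FakeTokenizer()

class Model:
    def __init__(self, async_llm: FakeAsyncEngine):
        self.async_llm = async_llm
        self.tokenizer = None

    async def tokenize_text_async(self, text: str) -> torch.Tensor:
        if self.tokenizer is None:  # fetched once, cached afterwards
            self.tokenizer = await self.async_llm.get_tokenizer()
        return torch.tensor(self.tokenizer.encode(text), dtype=torch.int32)

print(asyncio.run(Model(FakeAsyncEngine()).tokenize_text_async("ok")))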

trinity/common/models/vllm_model.py

Lines changed: 13 additions & 0 deletions
@@ -268,6 +268,19 @@ def convert_messages_to_experience(self, messages: List[dict]) -> Experience:
             action_mask=action_mask,
         )
 
+    def tokenize_text(self, text: str) -> torch.Tensor:
+        """
+        Convert text to token ids.
+
+        Args:
+            text (str): Input text to be tokenized
+
+        Returns:
+            torch.Tensor: token ids tensor
+        """
+        token_ids = self.tokenizer.encode(text)
+        return torch.tensor(token_ids, dtype=torch.int32)
+
     def has_api_server(self) -> bool:
         return False
273286

trinity/common/workflows/workflow.py

Lines changed: 17 additions & 2 deletions
@@ -190,12 +190,27 @@ def format_messages(self):
             messages.append({"role": "assistant", "content": self.reply_prefix})
         return messages
 
+    def format_prompt(self):
+        prompt_text = ""
+        if self.system_prompt:
+            prompt_text += self.system_prompt
+            prompt_text += "\nTask:\n" + self.task_desc + "\nResponse:\n"
+        else:
+            prompt_text += "\nTask:\n" + self.task_desc + "\nResponse:\n"
+        return prompt_text
+
     def run(self) -> List[Experience]:
         # TODO: Optimize the generate function
-        messages = self.format_messages()
+        if self.format_args.use_base_format:
+            prompt_text = self.format_prompt()
+        else:
+            messages = self.format_messages()
 
         logger.debug("start chat")
-        responses = self.model.chat(messages, **self.rollout_args)
+        if self.format_args.use_base_format:
+            responses = self.model.generate([prompt_text], **self.rollout_args)
+        else:
+            responses = self.model.chat(messages, **self.rollout_args)
         for response in responses:
             reward = self.reward_fn(  # type: ignore [misc]
                 response=response.response_text,  # type: ignore [arg-type]
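To make the base-format path concrete: with use_base_format set, run skips the chat template entirely and passes raw text from format_prompt to model.generate. The snippet below reconstructs the resulting prompt from the concatenation in the diff (sample values are made up; both branches of the diff append the same Task/Response suffix, so it is written once here):

system_prompt = "You are a helpful assistant."  # sample value
task_desc = "Add 2 and 3."                      # sample value

prompt_text = ""
if system_prompt:
    prompt_text += system_prompt
prompt_text += "\nTask:\n" + task_desc + "\nResponse:\n"

print(prompt_text)
# You are a helpful assistant.
# Task:
# Add 2 and 3.
# Response: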
