diff --git a/trinity/common/config.py b/trinity/common/config.py index c168bb48b1..fbb1c2d148 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -49,6 +49,8 @@ class FormatConfig: # for unpaired preference dataset label_key: str = "" + use_base_format: bool = False + @dataclass class GenerationConfig: diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 603bd1ced4..1ec255f13f 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -190,12 +190,27 @@ def format_messages(self): messages.append({"role": "assistant", "content": self.reply_prefix}) return messages + def format_prompt(self): + prompt_text = "" + if self.system_prompt: + prompt_text += self.system_prompt + prompt_text += "\nTask:\n" + self.task_desc + "\nResponse:\n" + else: + prompt_text += "\nTask:\n" + self.task_desc + "\nResponse:\n" + return prompt_text + def run(self) -> List[Experience]: # TODO: Optimize the generate function - messages = self.format_messages() + if self.format_args.use_base_format: + prompt_text = self.format_prompt() + else: + messages = self.format_messages() logger.debug("start chat") - responses = self.model.chat(messages, **self.rollout_args) + if self.format_args.use_base_format: + responses = self.model.generate([prompt_text], **self.rollout_args) + else: + responses = self.model.chat(messages, **self.rollout_args) for response in responses: reward = self.reward_fn( # type: ignore [misc] response=response.response_text, # type: ignore [arg-type]