|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +"""We include the customized math workflows in this file.""" |
| 3 | + |
| 4 | +from dataclasses import asdict |
| 5 | +from typing import List |
| 6 | + |
| 7 | +from trinity.common.experience import Experience |
| 8 | +from trinity.common.rewards.reward_fn import MathBoxedRewardFn |
| 9 | +from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task |
| 10 | +from trinity.utils.log import get_logger |
| 11 | + |
| 12 | +logger = get_logger(__name__) |
| 13 | + |
| 14 | + |
@WORKFLOWS.register_module("math_boxed_workflow")
class MathBoxedWorkflow(SimpleWorkflow):
    """A workflow for math tasks that give answers in boxed format."""

    def reset(self, task: Task):
        """(Re)initialize workflow state from *task*.

        Reads prompt-formatting options and rollout arguments, pulls the
        workflow-specific flags (``use_base``, ``with_think``,
        ``format_score_coef``) out of ``task.workflow_args``, and installs a
        default system prompt when the task does not provide one.
        """
        self.format_args = task.format_args
        self.system_prompt = task.format_args.system_prompt
        self.reply_prefix = task.format_args.reply_prefix

        self.raw_task = task.raw_task
        self.task_desc = task.task_desc
        self.truth = task.truth

        # Rollout args are forwarded verbatim to model.chat()/generate().
        self.rollout_args = asdict(task.rollout_args)
        self.is_eval = task.is_eval

        self.workflow_args = task.workflow_args

        # use_base: query a base (non-chat) model via generate() with a raw
        # text prompt instead of chat() with structured messages.
        self.use_base = self.workflow_args.get("use_base", False)
        self.with_think = self.workflow_args.get("with_think", False)
        self.format_score_coef = self.workflow_args.get("format_score_coef", 0.1)

        default_prompt = (
            """Please reason step by step, and put your final answer within \\boxed{}."""
        )

        default_prompt_with_think = """You are a helpful assistant that solves MATH problems. You should first thinks about the reasoning process in mind and then provides the user with the answer. You should present your reasoning process using the format: <think>\n ...your reasoning process here... </think>\n first. You should always include your final answer in \\boxed{} as closed-form results."""

        if self.system_prompt is None:
            self.system_prompt = (
                default_prompt_with_think if self.with_think else default_prompt
            )

        # Single reward-function instance, reused for every response in run().
        self.reward_fn = MathBoxedRewardFn()

    def format_prompt(self):
        """Build the raw text prompt used when querying a base model."""
        prompt_text = ""
        if self.system_prompt:
            prompt_text += "System:" + self.system_prompt
            prompt_text += "\nUser:\n" + self.task_desc + "\nAssistant:\n"
        else:
            prompt_text += "User:\n" + self.task_desc + "\nAssistant:\n"
        return prompt_text

    def run(self) -> List[Experience]:
        """Roll out the model on the task and attach rewards.

        Returns:
            The model responses with ``reward`` populated; in eval mode the
            per-component reward dict is also merged into ``metrics``.
        """
        # TODO: Optimize the generate function
        if self.use_base:
            prompt = self.format_prompt()
            logger.debug("start chat")
            responses = self.model.generate([prompt], **self.rollout_args)
        else:
            prompt = self.format_messages()
            logger.debug("start chat")
            responses = self.model.chat(prompt, **self.rollout_args)

        for response in responses:
            # Reuse the reward fn built in reset() rather than instantiating
            # a new MathBoxedRewardFn per response as the original did.
            reward = self.reward_fn(  # type: ignore [misc]
                response=response.response_text,  # type: ignore [arg-type]
                truth=self.truth,
                return_dict=self.is_eval,
                with_think=self.with_think,
                format_score_coef=self.format_score_coef,
            )
            # BUGFIX: the original f-string referenced `messages`, which is
            # undefined when use_base is True (NameError); log the prompt
            # actually sent in either mode.
            logger.debug(
                f"self.task_desc: {self.task_desc}, prompt: {prompt}, response: {response.response_text}, reward: {reward}"
            )
            if isinstance(reward, dict):
                # Eval mode: reward fn returns per-component scores; record
                # them as metrics and sum to a scalar reward.
                if response.metrics is None:
                    response.metrics = {}
                response.metrics.update(reward)
                reward = sum(reward.values())
            response.reward = reward
        return responses
0 commit comments