@@ -31,6 +31,7 @@ def reset(self, task: Task):
3131 self .is_eval = task .is_eval
3232
3333 self .workflow_args = task .workflow_args
34+ self .reward_fn_args = task .reward_fn_args
3435
3536 self .use_base = self .workflow_args .get ("use_base" , False )
3637 self .with_think = self .workflow_args .get ("with_think" , False )
@@ -49,9 +50,18 @@ def reset(self, task: Task):
4950 self .system_prompt = default_prompt
5051
5152 if task .reward_fn is None :
52- self .reward_fn = MathBoxedRewardFn ()
53+ self .reward_fn = MathBoxedRewardFn (** self . reward_fn_args )
5354 else :
54- self .reward_fn = task .reward_fn
55+ self .reward_fn = task .reward_fn (** self .reward_fn_args )
56+
def format_prompt(self):
    """Build the plain-text prompt for this task.

    When ``self.system_prompt`` is truthy the prompt opens with a
    ``System:`` section followed by the ``User:`` marker; otherwise it
    opens directly with the ``User:`` marker. In both cases
    ``self.task_desc`` follows, and the text ends with an ``Assistant:``
    marker.

    Returns:
        The assembled prompt string.
    """
    system_prompt = self.system_prompt
    # Pick the leading section once; everything after it is shared.
    header = (
        "System:" + system_prompt + "\n User:\n "
        if system_prompt
        else "User:\n "
    )
    return header + self.task_desc + "\n Assistant:\n "
5565
5666 def run (self ) -> List [Experience ]:
5767 if not self .use_base :
@@ -71,6 +81,7 @@ def run(self) -> List[Experience]:
7181 truth = self .truth ,
7282 with_think = self .with_think ,
7383 format_score_coef = self .format_score_coef ,
84+ response_token = response .tokens [response .prompt_length :],
7485 )
7586
7687 if response .metrics is None :
@@ -79,7 +90,12 @@ def run(self) -> List[Experience]:
7990 reward = sum (reward_dict .values ())
8091 response .reward = reward
8192
82- logger .debug (
83- f"self.task_desc: { self .task_desc } , messages: { messages } , response: { response .response_text } , reward: { reward } "
84- )
93+ if not self .use_base :
94+ logger .debug (
95+ f"self.task_desc: { self .task_desc } , messages: { messages } , response: { response .response_text } , reward: { reward } "
96+ )
97+ else :
98+ logger .debug (
99+ f"self.task_desc: { self .task_desc } , prompt_text: { prompt_text } , response: { response .response_text } , reward: { reward } "
100+ )
85101 return responses
0 commit comments