11115. Train with importance_sampling loss
1212"""
1313
14+ from dataclasses import asdict
1415from typing import List , Optional
1516
1617import openai
1718
1819from trinity .common .experience import Experience
1920from trinity .common .models .model import ModelWrapper
20- from trinity .common .workflows .workflow import WORKFLOWS , BaseSimpleWorkflow , Task
21+ from trinity .common .rewards .qwen25_eval import verify_math_answer
22+ from trinity .common .workflows .workflow import WORKFLOWS , Task , Workflow
2123
2224
2325@WORKFLOWS .register_module ("on_policy_distill_workflow" )
24- class OnPolicyDistillWorkflow (BaseSimpleWorkflow ):
26+ class OnPolicyDistillWorkflow (Workflow ):
2527 """On-policy distillation workflow.
2628
2729 Computes and stores teacher_logprobs in experience.info.
2830 The advantage_fn in trainer will compute:
2931 advantages = teacher_logprobs - student_logprobs
32+
33+ Note: This workflow does NOT use reward_fn because:
34+ - Advantage is computed from teacher-student logprobs difference
35+ - No external reward signal is needed
3036 """
3137
3238 is_async : bool = True
@@ -41,8 +47,13 @@ def __init__(
4147 auxiliary_models : Optional [List [openai .OpenAI ]] = None ,
4248 auxiliary_model_wrappers : Optional [List [ModelWrapper ]] = None ,
4349 ):
44- super ().__init__ (task = task , model = model , auxiliary_models = auxiliary_models )
45- self .auxiliary_model_wrappers = auxiliary_model_wrappers
50+ super ().__init__ (
51+ task = task ,
52+ model = model ,
53+ auxiliary_models = auxiliary_models ,
54+ auxiliary_model_wrappers = auxiliary_model_wrappers ,
55+ )
56+ self .reset (task )
4657
4758 assert (
4859 auxiliary_model_wrappers is not None and len (auxiliary_model_wrappers ) >= 1
@@ -51,6 +62,49 @@ def __init__(
5162
5263 self .temperature = task .workflow_args .get ("temperature" , 1.0 )
5364
65+ def reset (self , task : Task ):
66+ """Reset the workflow with a new task.
67+
68+ Unlike BaseSimpleWorkflow, this does NOT require reward_fn.
69+ """
70+ self .task = task
71+ self .format_args = task .format_args
72+ self .system_prompt = task .format_args .system_prompt
73+ self .reply_prefix = task .format_args .reply_prefix
74+ self .raw_task = task .raw_task
75+ self .task_desc = task .task_desc
76+ self .truth = task .truth
77+
78+ def set_repeat_times (self , repeat_times , run_id_base ):
79+ self .repeat_times = repeat_times
80+ self .task .rollout_args .n = repeat_times
81+ self .run_id_base = run_id_base
82+
83+ @property
84+ def rollout_args (self ):
85+ return asdict (self .task .rollout_args )
86+
87+ def format_messages (self ):
88+ """Format messages for the instruct model.
89+
90+ Default format: system_prompt (optional) + task_desc + reply_prefix (optional)
91+ """
92+ messages = []
93+ if self .system_prompt :
94+ messages .append ({"role" : "system" , "content" : self .system_prompt })
95+ messages .append ({"role" : "user" , "content" : self .task_desc })
96+ if self .reply_prefix :
97+ messages .append ({"role" : "assistant" , "content" : self .reply_prefix })
98+ return messages
99+
100+ def compute_reward (self , response : Experience ) -> float :
101+ """Compute reward for a response.
102+
103+ In base class, returns 0.0 as advantage is computed from teacher-student logprobs.
104+ Subclasses can override this to compute actual rewards.
105+ """
106+ return 0.0
107+
54108 async def run_async (self ) -> List [Experience ]:
55109 messages = self .format_messages ()
56110
@@ -79,13 +133,16 @@ async def run_async(self) -> List[Experience]:
79133 # Step 3: Store teacher_logprobs for advantage_fn
80134 response .teacher_logprobs = teacher_resp_logprobs
81135
82- # Set a dummy reward (actual advantage computed by advantage_fn)
83- response .reward = 0.0
84- response .eid .run = i + self .run_id_base
85-
86- # Metrics for monitoring
136+ # Initialize metrics
87137 if response .metrics is None :
88138 response .metrics = {}
139+
140+ # Compute reward (subclasses can override compute_reward)
141+ response .reward = self .compute_reward (response )
142+
143+ response .eid .run = i + self .run_id_base
144+
145+ # KL divergence for monitoring
89146 kl = (student_resp_logprobs - teacher_resp_logprobs ).sum ().item ()
90147 response .metrics ["kl_divergence" ] = kl
91148
@@ -94,4 +151,53 @@ async def run_async(self) -> List[Experience]:
94151
95152@WORKFLOWS .register_module ("async_on_policy_distill_workflow" )
96153class AsyncOnPolicyDistillWorkflow (OnPolicyDistillWorkflow ):
154+ """Alias for OnPolicyDistillWorkflow (already async)."""
155+
156+ pass
157+
158+
159+ @WORKFLOWS .register_module ("on_policy_distill_math_workflow" )
160+ class OnPolicyDistillMathWorkflow (OnPolicyDistillWorkflow ):
161+ """On-policy distillation workflow with Qwen2.5-Math style format.
162+
163+ This workflow:
164+ - Uses Qwen2.5-Math style prompt format (same as math_eval_workflow)
165+ - Computes accuracy using verify_math_answer as reward
166+ - Suitable for math reasoning tasks like GSM8K, MATH, etc.
167+ """
168+
169+ def format_messages (self ):
170+ """Format messages using Qwen2.5-Math style.
171+
172+ System prompt: "You are a helpful assistant."
173+ User prompt: "{question}\n Please reason step by step, and put your final answer within \\ boxed{}."
174+ """
175+ system_prompt = "You are a helpful assistant."
176+ user_prompt = f"{ self .task_desc } \n Please reason step by step, and put your final answer within \\ boxed{{}}."
177+ return [
178+ {"role" : "system" , "content" : system_prompt },
179+ {"role" : "user" , "content" : user_prompt },
180+ ]
181+
182+ def compute_reward (self , response : Experience ) -> float :
183+ """Compute accuracy as reward using Qwen2.5-Math evaluation.
184+
185+ Returns 1.0 if answer is correct, 0.0 otherwise.
186+ """
187+ if response .response_text and self .truth :
188+ accuracy , _ = verify_math_answer (
189+ response_text = response .response_text , ground_truth = self .truth
190+ )
191+ # Store accuracy in metrics
192+ if response .metrics is None :
193+ response .metrics = {}
194+ response .metrics ["accuracy" ] = accuracy
195+ return float (accuracy )
196+ return 0.0
197+
198+
199+ @WORKFLOWS .register_module ("async_on_policy_distill_math_workflow" )
200+ class AsyncOnPolicyDistillMathWorkflow (OnPolicyDistillMathWorkflow ):
201+ """Alias for OnPolicyDistillMathWorkflow (already async)."""
202+
97203 pass
0 commit comments