@@ -76,8 +76,8 @@ We need to read two kinds of experiences: usual experiences and expert experiences
 class MixSampleStrategy(SampleStrategy):
     """The default sample strategy."""
 
-    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
-        super().__init__(buffer_config, trainer_type)
+    def __init__(self, buffer_config: BufferConfig, **kwargs):
+        super().__init__(buffer_config)
         self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
         tot_batch_size = buffer_config.read_batch_size
         expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)
@@ -101,7 +101,7 @@ class MixSampleStrategy(SampleStrategy):
             buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
         )
 
-    def sample(self, step: int) -> Tuple[Any, Dict, List]:
+    def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         metrics = {}
         with Timer(metrics, "read_time"):
             usual_exp_list = self.usual_exp_buffer.read()
@@ -113,63 +113,32 @@ class MixSampleStrategy(SampleStrategy):
             expert_exp_list = self.expert_exp_buffer.read()
             for exp in expert_exp_list:
                 exp.reward = 0.0
-                exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
+                exp.logprobs = torch.zeros_like(
+                    exp.tokens[exp.prompt_length :], dtype=torch.float32
+                )
                 if exp.info is None:
                     exp.info = {}
                 exp.info["is_expert"] = True
 
         exp_list = usual_exp_list + expert_exp_list
         repr_samples = representative_sample(exp_list)
 
-        is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)
-
         with Timer(metrics, "gather_time"):
-            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
-
-        if self.trainer_type == "verl":
-            with Timer(metrics, "convert_time"):
-                data = to_data_proto_mix(exps, is_expert_mask)
-            return data, metrics, repr_samples
-        else:
-            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+            exps = Experiences.gather_experiences(
+                experiences=exp_list,
+                pad_token_id=self.pad_token_id,  # type: ignore[arg-type]
+                custom_fields=[
+                    CustomField(
+                        source_field="is_expert",
+                        destination_field="expert_mask",
+                        data_type=torch.bool,
+                    )
+                ],
+            )  # type: ignore
+        return exps, metrics, repr_samples
 ```
 
-We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type.
-
-```diff
-+ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
-    attention_mask = experiences.attention_masks
-    cumsum = torch.cumsum(attention_mask, dim=-1)
-    position_ids = torch.clip(cumsum - 1, 0, None).long()
-    batch_dict = {
-        "uid": np.array([eid.tid for eid in experiences.eids]),
-        "unique_ids": np.array([eid.uid for eid in experiences.eids]),
-        "position_ids": position_ids,
-        "input_ids": experiences.tokens.long(),
-        "responses": experiences.tokens[:, experiences.prompt_length :].long(),
-        "attention_mask": attention_mask.long(),
-        "response_mask": (
-            experiences.action_masks[:, experiences.prompt_length :].long()
-            if hasattr(experiences, "action_masks") and experiences.action_masks is not None
-            else attention_mask[:, experiences.prompt_length :].long()
-        ),
-+        "is_expert_mask": is_expert_mask,
-    }
-    if experiences.rewards is not None:
-        token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
-        eos_mask_idx = cumsum.argmax(dim=-1)
-        token_level_rewards[
-            torch.arange(experiences.batch_size), eos_mask_idx
-        ] = experiences.rewards
-        token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
-        batch_dict.update(
-            {
-                "token_level_scores": token_level_rewards,
-                "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
-            }
-        )
-    return DataProto.from_single_dict(batch_dict)
-```
+Here we use the `custom_fields` argument of `Experiences.gather_experiences` to add a new field, `expert_mask`, which indicates whether the experience comes from an expert. This field is used in the policy loss function to distinguish between usual and expert experiences.
 
 
 ## Step 3: Define the Policy Loss Function
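
To see what the `CustomField` above produces, here is a minimal sketch of the gathering idea. It assumes each experience carries an `info["is_expert"]` flag; `FakeExperience` and `gather_expert_mask` are hypothetical stand-ins used only for illustration, not the framework's `Experience` class or `Experiences.gather_experiences`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch


@dataclass
class FakeExperience:
    """Illustrative stand-in for a single rollout experience."""
    tokens: torch.Tensor
    info: Dict[str, Any] = field(default_factory=dict)


def gather_expert_mask(exp_list: List[FakeExperience]) -> torch.Tensor:
    """Collect each experience's info["is_expert"] flag into one boolean tensor,
    mirroring what CustomField(source_field="is_expert",
    destination_field="expert_mask", data_type=torch.bool) requests above."""
    return torch.tensor(
        [bool(exp.info.get("is_expert", False)) for exp in exp_list],
        dtype=torch.bool,
    )


usual = FakeExperience(tokens=torch.tensor([1, 2, 3]))                        # usual rollout
expert = FakeExperience(tokens=torch.tensor([4, 5]), info={"is_expert": True})  # expert sample
print(gather_expert_mask([usual, expert]))  # tensor([False,  True])
```
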
@@ -217,15 +186,15 @@ class MIXPolicyLossFn(PolicyLossFn):
         old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
-        is_expert_mask: torch.Tensor,
+        expert_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         assert (
-            len(is_expert_mask) == logprob.shape[0]
-        ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
+            len(expert_mask) == logprob.shape[0]
+        ), f"Error: {len(expert_mask)=} != {logprob.shape[0]=}"
 
-        n_usual_exp = torch.sum(~is_expert_mask).item()
-        n_expert_exp = torch.sum(is_expert_mask).item()
+        n_usual_exp = torch.sum(~expert_mask).item()
+        n_expert_exp = torch.sum(expert_mask).item()
 
         if self.use_dynamic_bsz:
             per_micro_batch_weight_usual = self.experience_per_gpu / (
@@ -240,10 +209,10 @@ class MIXPolicyLossFn(PolicyLossFn):
 
         if n_usual_exp > 0:
             grpo_loss, grpo_metrics = self.grpo_loss_fn(
-                logprob[~is_expert_mask],
-                old_logprob[~is_expert_mask],
-                action_mask[~is_expert_mask],
-                advantages[~is_expert_mask],
+                logprob[~expert_mask],
+                old_logprob[~expert_mask],
+                action_mask[~expert_mask],
+                advantages[~expert_mask],
                 **kwargs,
             )
             grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
@@ -257,8 +226,8 @@ class MIXPolicyLossFn(PolicyLossFn):
         # SFT Loss (expert)
         if n_expert_exp > 0:
             sft_loss, sft_metrics = self.sft_loss_fn(
-                logprob[is_expert_mask],
-                action_mask[is_expert_mask],
+                logprob[expert_mask],
+                action_mask[expert_mask],
             )
             sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
             sft_metrics = {
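
To make the loss mixing in the hunks above concrete, the following is a minimal, self-contained sketch of routing one batch through two losses by a boolean `expert_mask`. `simple_pg_loss` and `simple_sft_loss` are simplified stand-ins for the repository's `grpo_loss_fn` and `sft_loss_fn` (no clipping and no dynamic micro-batch weighting), and the fixed `expert_weight` is an assumption for illustration.

```python
from typing import Dict, Tuple

import torch


def simple_sft_loss(logprob: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    """Masked negative log-likelihood over response tokens (illustrative SFT loss)."""
    return -(logprob * action_mask).sum() / action_mask.sum().clamp(min=1)


def simple_pg_loss(
    logprob: torch.Tensor,
    old_logprob: torch.Tensor,
    action_mask: torch.Tensor,
    advantages: torch.Tensor,
) -> torch.Tensor:
    """Vanilla importance-weighted policy-gradient surrogate (illustrative RL loss)."""
    ratio = torch.exp(logprob - old_logprob)
    return -(ratio * advantages * action_mask).sum() / action_mask.sum().clamp(min=1)


def mixed_loss(
    logprob: torch.Tensor,      # (batch, response_len)
    old_logprob: torch.Tensor,  # (batch, response_len)
    action_mask: torch.Tensor,  # (batch, response_len)
    advantages: torch.Tensor,   # (batch, response_len)
    expert_mask: torch.Tensor,  # (batch,) bool, True for expert/SFT samples
    expert_weight: float = 0.5,  # assumed fixed mixing weight, for illustration only
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """Route usual rollouts to the RL loss and expert samples to the SFT loss."""
    loss = logprob.new_zeros(())
    metrics: Dict[str, float] = {}

    if (~expert_mask).any():
        pg = simple_pg_loss(
            logprob[~expert_mask],
            old_logprob[~expert_mask],
            action_mask[~expert_mask],
            advantages[~expert_mask],
        )
        loss = loss + (1.0 - expert_weight) * pg
        metrics["pg_loss"] = pg.item()

    if expert_mask.any():
        sft = simple_sft_loss(logprob[expert_mask], action_mask[expert_mask])
        loss = loss + expert_weight * sft
        metrics["sft_loss"] = sft.item()

    return loss, metrics


if __name__ == "__main__":
    B, T = 4, 6
    logprob = torch.randn(B, T)
    old_logprob = logprob + 0.01 * torch.randn(B, T)
    action_mask = torch.ones(B, T)
    advantages = torch.randn(B, T)
    expert_mask = torch.tensor([False, False, True, True])
    loss, metrics = mixed_loss(logprob, old_logprob, action_mask, advantages, expert_mask)
    print(loss, metrics)
```

The point mirrored from the diff is that a single gathered batch is split by the mask, so one training step can blend policy-gradient updates on rollouts with supervised updates on expert data.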