Commit bd235db

Standardize Experience and Sample Strategy (#141)
1 parent 6cc30e9 commit bd235db

17 files changed: +411 -398 lines changed

docs/sphinx_doc/source/tutorial/example_mix_algo.md

Lines changed: 30 additions & 61 deletions
@@ -76,8 +76,8 @@ We need to read two kinds of experiences: usual experiences and expert experiences
 class MixSampleStrategy(SampleStrategy):
     """The default sample strategy."""

-    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
-        super().__init__(buffer_config, trainer_type)
+    def __init__(self, buffer_config: BufferConfig, **kwargs):
+        super().__init__(buffer_config)
         self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
         tot_batch_size = buffer_config.read_batch_size
         expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)
@@ -101,7 +101,7 @@ class MixSampleStrategy(SampleStrategy):
             buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
         )

-    def sample(self, step: int) -> Tuple[Any, Dict, List]:
+    def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         metrics = {}
         with Timer(metrics, "read_time"):
             usual_exp_list = self.usual_exp_buffer.read()
@@ -113,63 +113,32 @@ class MixSampleStrategy(SampleStrategy):
             expert_exp_list = self.expert_exp_buffer.read()
             for exp in expert_exp_list:
                 exp.reward = 0.0
-                exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
+                exp.logprobs = torch.zeros_like(
+                    exp.tokens[exp.prompt_length :], dtype=torch.float32
+                )
                 if exp.info is None:
                     exp.info = {}
                 exp.info["is_expert"] = True

         exp_list = usual_exp_list + expert_exp_list
         repr_samples = representative_sample(exp_list)

-        is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)
-
         with Timer(metrics, "gather_time"):
-            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
-
-        if self.trainer_type == "verl":
-            with Timer(metrics, "convert_time"):
-                data = to_data_proto_mix(exps, is_expert_mask)
-            return data, metrics, repr_samples
-        else:
-            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+            exps = Experiences.gather_experiences(
+                experiences=exp_list,
+                pad_token_id=self.pad_token_id,  # type: ignore [arg-type]
+                custom_fields=[
+                    CustomField(
+                        source_field="is_expert",
+                        destination_field="expert_mask",
+                        data_type=torch.bool,
+                    )
+                ],
+            )  # type: ignore
+        return exps, metrics, repr_samples
 ```

-We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type.
-
-```diff
-+ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
-      attention_mask = experiences.attention_masks
-      cumsum = torch.cumsum(attention_mask, dim=-1)
-      position_ids = torch.clip(cumsum - 1, 0, None).long()
-      batch_dict = {
-          "uid": np.array([eid.tid for eid in experiences.eids]),
-          "unique_ids": np.array([eid.uid for eid in experiences.eids]),
-          "position_ids": position_ids,
-          "input_ids": experiences.tokens.long(),
-          "responses": experiences.tokens[:, experiences.prompt_length :].long(),
-          "attention_mask": attention_mask.long(),
-          "response_mask": (
-              experiences.action_masks[:, experiences.prompt_length :].long()
-              if hasattr(experiences, "action_masks") and experiences.action_masks is not None
-              else attention_mask[:, experiences.prompt_length :].long()
-          ),
-+         "is_expert_mask": is_expert_mask,
-      }
-      if experiences.rewards is not None:
-          token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
-          eos_mask_idx = cumsum.argmax(dim=-1)
-          token_level_rewards[
-              torch.arange(experiences.batch_size), eos_mask_idx
-          ] = experiences.rewards
-          token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
-          batch_dict.update(
-              {
-                  "token_level_scores": token_level_rewards,
-                  "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
-              }
-          )
-      return DataProto.from_single_dict(batch_dict)
-```
+Here we use the `custom_fields` argument of `Experiences.gather_experiences` to add a new field `expert_mask`, which indicates whether the experience is from an expert or not. This field will be used in the policy loss function to distinguish between usual and expert experiences.


 ## Step 3: Define the Policy Loss Function
@@ -217,15 +186,15 @@ class MIXPolicyLossFn(PolicyLossFn):
         old_logprob: torch.Tensor,
         action_mask: torch.Tensor,
         advantages: torch.Tensor,
-        is_expert_mask: torch.Tensor,
+        expert_mask: torch.Tensor,
         **kwargs,
     ) -> Tuple[torch.Tensor, Dict]:
         assert (
-            len(is_expert_mask) == logprob.shape[0]
-        ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
+            len(expert_mask) == logprob.shape[0]
+        ), f"Error: {len(expert_mask)=} != {logprob.shape[0]=}"

-        n_usual_exp = torch.sum(~is_expert_mask).item()
-        n_expert_exp = torch.sum(is_expert_mask).item()
+        n_usual_exp = torch.sum(~expert_mask).item()
+        n_expert_exp = torch.sum(expert_mask).item()

         if self.use_dynamic_bsz:
             per_micro_batch_weight_usual = self.experience_per_gpu / (
@@ -240,10 +209,10 @@ class MIXPolicyLossFn(PolicyLossFn):

         if n_usual_exp > 0:
             grpo_loss, grpo_metrics = self.grpo_loss_fn(
-                logprob[~is_expert_mask],
-                old_logprob[~is_expert_mask],
-                action_mask[~is_expert_mask],
-                advantages[~is_expert_mask],
+                logprob[~expert_mask],
+                old_logprob[~expert_mask],
+                action_mask[~expert_mask],
+                advantages[~expert_mask],
                 **kwargs,
             )
             grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
@@ -257,8 +226,8 @@ class MIXPolicyLossFn(PolicyLossFn):
         # SFT Loss (expert)
         if n_expert_exp > 0:
             sft_loss, sft_metrics = self.sft_loss_fn(
-                logprob[is_expert_mask],
-                action_mask[is_expert_mask],
+                logprob[expert_mask],
+                action_mask[expert_mask],
             )
             sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
             sft_metrics = {
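
The tutorial change above replaces the verl-specific `to_data_proto_mix` conversion with the generic `custom_fields` hook of `Experiences.gather_experiences`. The sketch below is a rough illustration of the idea, not the repository's implementation: only the `CustomField` field names (`source_field`, `destination_field`, `data_type`) come from the diff; the gathering logic is assumed.

```python
from dataclasses import dataclass
from types import SimpleNamespace

import torch


@dataclass
class CustomField:
    # Field names mirror the tutorial diff; everything else here is illustrative.
    source_field: str        # key looked up in each experience's `info` dict
    destination_field: str   # name of the resulting batch-level tensor
    data_type: torch.dtype   # dtype of that tensor


def gather_custom_fields(exp_list, custom_fields):
    """Sketch: collect per-experience info entries into batch-aligned tensors."""
    out = {}
    for cf in custom_fields:
        values = [exp.info.get(cf.source_field, False) for exp in exp_list]
        out[cf.destination_field] = torch.tensor(values, dtype=cf.data_type)
    return out


# Toy usage: experiences tagged with info["is_expert"] yield a boolean expert_mask
# aligned with the batch dimension.
exps = [SimpleNamespace(info={"is_expert": False}), SimpleNamespace(info={"is_expert": True})]
masks = gather_custom_fields(exps, [CustomField("is_expert", "expert_mask", torch.bool)])
print(masks["expert_mask"])  # tensor([False,  True])
```

The resulting `expert_mask` is what the policy-loss hunks above index on: `logprob[~expert_mask]` feeds the GRPO branch and `logprob[expert_mask]` feeds the SFT branch.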

tests/algorithm/policy_loss_test.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def setUp(self):
                 "ref_log_prob": 2 * torch.rand(shape) - 1,
                 "response_mask": torch.rand(shape) > 0.5,
                 "advantages": 2 * torch.rand(shape) - 1,
-                "is_expert_mask": torch.rand(shape[0]) > 0.5,
+                "expert_mask": torch.rand(shape[0]) > 0.5,
             }
         )

tests/common/experience_test.py

Lines changed: 48 additions & 34 deletions
@@ -43,14 +43,14 @@ def test_eid_properties(self):
 class TestExperience(unittest.TestCase):
     def test_single_turn_experience(self):
         tokens = torch.tensor([10, 11, 12], dtype=torch.int32)
-        logprobs = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
+        logprobs = torch.tensor([0.2, 0.3], dtype=torch.float32)
         exp = Experience(tokens=tokens, logprobs=logprobs, reward=1.0, prompt_length=1)
         self.assertEqual(exp.experience_type.name, "SINGLE_TURN")
         self.assertTrue(torch.equal(exp.tokens, tokens))
         self.assertTrue(torch.equal(exp.logprobs, logprobs))
         self.assertEqual(exp.reward, 1.0)
         self.assertEqual(exp.prompt_length, 1)
-        self.assertTrue(torch.equal(exp.action_mask, torch.tensor([0, 1, 1], dtype=torch.bool)))
+        self.assertTrue(torch.equal(exp.action_mask, torch.tensor([1, 1], dtype=torch.bool)))

     def test_multi_turn_experience(self):
         tokens = torch.tensor([1, 2, 3, 4])
@@ -171,13 +171,17 @@ def test_batch_conversion(self):
                 tokens=torch.tensor([1, 2]),
                 prompt_length=1,
                 reward=float(0.1),
-                logprobs=torch.tensor([0, 0.1]),
+                logprobs=torch.tensor([0.1]),
+                advantages=torch.tensor([0.1]),
+                returns=torch.tensor([0.4]),
             ),
             Experience(
                 tokens=torch.tensor([1, 2, 3]),
                 prompt_length=2,
                 reward=float(0.2),
-                logprobs=torch.tensor([0, 0, 0.1]),
+                logprobs=torch.tensor([0.1]),
+                advantages=torch.tensor([0.3]),
+                returns=torch.tensor([0.2]),
             ),
         ]
         batch = Experiences.gather_experiences(exps)
@@ -199,45 +203,53 @@ def test_batch_conversion(self):
             )
             self.assertTrue(
                 torch.all(
-                    batch.logprobs[i][
-                        prompt_length
-                        - exps[i].prompt_length : prompt_length
-                        + exps[i].tokens.size(0)
-                        - exps[i].prompt_length
-                    ]
+                    batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
                     == exps[i].logprobs
                 )
             )
             self.assertTrue(
                 torch.all(
-                    batch.action_masks[i][
-                        prompt_length
-                        - exps[i].prompt_length : prompt_length
-                        - exps[i].prompt_length
-                        + exps[i].action_mask.size(0)
-                    ]
+                    batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
                     == exps[i].action_mask
                 )
             )
+            self.assertTrue(
+                torch.all(
+                    batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
+                    == exps[i].advantages
+                )
+            )
+            self.assertTrue(
+                torch.all(
+                    batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
+                    == exps[i].returns
+                )
+            )

     def test_multiturn_experience_batch_converstion(self):
         exps = [
             Experience(
-                tokens=torch.tensor([1, 2, 3, 4]),
+                tokens=torch.tensor([1, 2, 3, 4, 5, 6]),
                 reward=float(0.3),
-                logprobs=torch.tensor([0, 0, 0.1, 0.2]),
-                action_mask=torch.tensor([1, 0, 1, 0]),
+                logprobs=torch.tensor([0, 0.1, 0.2, 0.3]),
+                prompt_length=2,
+                action_mask=torch.tensor([1, 0, 1, 1]),
+                advantages=torch.tensor([0.1, 0, 0.2, 0.3]),
+                returns=torch.tensor([0.5, 0, 0.7, 0.8]),
             ),
             Experience(
                 tokens=torch.tensor([1, 2, 3, 4]),
                 reward=float(0.4),
-                logprobs=torch.tensor([0, 0, 0, 0.1]),
-                action_mask=torch.tensor([1, 0, 0, 1]),
+                logprobs=torch.tensor([0, 0.1]),
+                prompt_length=2,
+                action_mask=torch.tensor([1, 1]),
+                advantages=torch.tensor([0.2, 0.3]),
+                returns=torch.tensor([0.6, 0.9]),
             ),
         ]
         batch = Experiences.gather_experiences(exps)
         self.assertEqual(batch.batch_size, 2)
-        self.assertEqual(batch.prompt_length, 1)
+        self.assertEqual(batch.prompt_length, 2)
         prompt_length = batch.prompt_length
         for i in range(batch.batch_size):
             self.assertEqual(batch.rewards[i], exps[i].reward)
@@ -254,26 +266,28 @@ def test_multiturn_experience_batch_converstion(self):
             )
             self.assertTrue(
                 torch.all(
-                    batch.logprobs[i][
-                        prompt_length
-                        - exps[i].prompt_length : prompt_length
-                        + exps[i].tokens.size(0)
-                        - exps[i].prompt_length
-                    ]
+                    batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
                     == exps[i].logprobs
                 )
             )
             self.assertTrue(
                 torch.all(
-                    batch.action_masks[i][
-                        prompt_length
-                        - exps[i].prompt_length : prompt_length
-                        - exps[i].prompt_length
-                        + exps[i].action_mask.size(0)
-                    ]
+                    batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
                     == exps[i].action_mask
                 )
             )
+            self.assertTrue(
+                torch.all(
+                    batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
+                    == exps[i].advantages
+                )
+            )
+            self.assertTrue(
+                torch.all(
+                    batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
+                    == exps[i].returns
+                )
+            )

     def test_dpo_experience_batch_conversion(self):
         exps = [
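
Taken together, these test updates pin down the standardized layout: per-token fields (`logprobs`, `action_mask`, `advantages`, `returns`) now cover only the response part of a sequence, and in the gathered batch the first `tokens.size(0) - prompt_length` positions of row `i` hold experience `i`'s values. Below is a toy sketch of that invariant with made-up numbers; it does not use the library's gathering code.

```python
import torch

# One experience: 2 prompt tokens followed by 3 response tokens.
tokens = torch.tensor([11, 12, 21, 22, 23])
prompt_length = 2
n_response = tokens.size(0) - prompt_length

# Response-aligned per-token fields, as the updated tests expect.
logprobs = torch.tensor([-0.1, -0.2, -0.3])
action_mask = torch.tensor([1, 1, 1], dtype=torch.bool)
assert logprobs.size(0) == n_response and action_mask.size(0) == n_response

# After gathering, the values occupy the leading n_response positions of the
# batch row, with padding after them (a padded length of 4 is assumed here).
batch_row = torch.zeros(4)
batch_row[:n_response] = logprobs
assert torch.all(batch_row[:n_response] == logprobs)
```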

tests/common/vllm_test.py

Lines changed: 9 additions & 8 deletions
@@ -159,15 +159,12 @@ async def test_generate(
             self.assertEqual(exp.prompt_length, history_exp.prompt_length)
             self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist())
         for result in results:
-            input_logprobs = result.logprobs[: result.prompt_length]
-            output_logprobs = result.logprobs[result.prompt_length :]
-            self.assertTrue(torch.all(input_logprobs == 0))
-            self.assertTrue(torch.any(output_logprobs != 0))
+            self.assertTrue(torch.any(result.logprobs != 0))
         if self.use_async:
             logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist())
         else:
             logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
-        self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0])
+        self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0] - 1)
         if self.config.explorer.rollout_model.enable_history:
             history_experiences = self.model_wrapper.extract_experience_from_history()
             self.assertTrue(len(history_experiences) == 0)
@@ -190,7 +187,10 @@ async def test_generate(
             return_assistant_tokens_mask=True,
             return_dict=True,
         )
-        self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask))
+        prompt_length = torch.argmax(result_dict["assistant_masks"][0]).item()
+        self.assertTrue(
+            torch.equal(result_dict["assistant_masks"][0][prompt_length:], exp.action_mask)
+        )
         self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
         self.assertRaises(ValueError, self.model_wrapper.get_openai_client)
         if self.config.explorer.rollout_model.enable_history:
@@ -284,12 +284,12 @@ def test_assistant_token_mask(self):
             },
         ]
         tokenizer = AutoTokenizer.from_pretrained(get_model_path())
-        token_ids, action_mask = tokenize_and_mask_messages_default(
+        token_ids, action_mask, prompt_length = tokenize_and_mask_messages_default(
             tokenizer=tokenizer,
             messages=messages,
             chat_template=CHAT_TEMPLATE,
         )
-        token_ids_hf, action_mask_hf = tokenize_and_mask_messages_hf(
+        token_ids_hf, action_mask_hf, prompt_length_hf = tokenize_and_mask_messages_hf(
             tokenizer=tokenizer,
             messages=messages,
             chat_template=CHAT_TEMPLATE,
@@ -298,3 +298,4 @@ def test_assistant_token_mask(self):
         self.assertEqual(action_mask.shape, action_mask_hf.shape)
         self.assertTrue(torch.equal(token_ids, token_ids_hf))
         self.assertTrue(torch.equal(action_mask, action_mask_hf))
+        self.assertEqual(prompt_length, prompt_length_hf)
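
The adjusted expectation `logprobs.shape[0] == tokens.shape[0] - 1` is consistent with next-token scoring: a causal model assigns a log-probability to each token conditioned on the tokens before it, so the first token gets none. A minimal sketch of that off-by-one relationship with toy values, independent of the `model_wrapper.logprobs` API:

```python
import torch
import torch.nn.functional as F

# Toy setup: random logits over a 10-token vocabulary for a 4-token sequence.
tokens = torch.tensor([3, 7, 1, 4])
logits = torch.randn(tokens.shape[0], 10)

# Score token t+1 from the logits at position t: one log-prob per predicted token.
log_probs = F.log_softmax(logits[:-1], dim=-1)
token_logprobs = log_probs.gather(1, tokens[1:].unsqueeze(1)).squeeze(1)

assert token_logprobs.shape[0] == tokens.shape[0] - 1
```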
