91 changes: 30 additions & 61 deletions docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -76,8 +76,8 @@ We need to read two kinds of experiences: usual experiences and expert experiences
class MixSampleStrategy(SampleStrategy):
"""The default sample strategy."""

def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
super().__init__(buffer_config, trainer_type)
def __init__(self, buffer_config: BufferConfig, **kwargs):
super().__init__(buffer_config)
self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
tot_batch_size = buffer_config.read_batch_size
expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)
@@ -101,7 +101,7 @@ class MixSampleStrategy(SampleStrategy):
buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
)

def sample(self, step: int) -> Tuple[Any, Dict, List]:
def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
metrics = {}
with Timer(metrics, "read_time"):
usual_exp_list = self.usual_exp_buffer.read()
@@ -113,63 +113,32 @@ class MixSampleStrategy(SampleStrategy):
expert_exp_list = self.expert_exp_buffer.read()
for exp in expert_exp_list:
exp.reward = 0.0
exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
exp.logprobs = torch.zeros_like(
exp.tokens[exp.prompt_length :], dtype=torch.float32
)
if exp.info is None:
exp.info = {}
exp.info["is_expert"] = True

exp_list = usual_exp_list + expert_exp_list
repr_samples = representative_sample(exp_list)

is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)

with Timer(metrics, "gather_time"):
exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore

if self.trainer_type == "verl":
with Timer(metrics, "convert_time"):
data = to_data_proto_mix(exps, is_expert_mask)
return data, metrics, repr_samples
else:
raise NotImplementedError(f"backend {self.trainer_type} is not supported")
exps = Experiences.gather_experiences(
experiences=exp_list,
pad_token_id=self.pad_token_id, # type: ignore [arg-type]
custom_fields=[
CustomField(
source_field="is_expert",
destination_field="expert_mask",
data_type=torch.bool,
)
],
) # type: ignore
return exps, metrics, repr_samples
```

We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type.

```diff
+ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
attention_mask = experiences.attention_masks
cumsum = torch.cumsum(attention_mask, dim=-1)
position_ids = torch.clip(cumsum - 1, 0, None).long()
batch_dict = {
"uid": np.array([eid.tid for eid in experiences.eids]),
"unique_ids": np.array([eid.uid for eid in experiences.eids]),
"position_ids": position_ids,
"input_ids": experiences.tokens.long(),
"responses": experiences.tokens[:, experiences.prompt_length :].long(),
"attention_mask": attention_mask.long(),
"response_mask": (
experiences.action_masks[:, experiences.prompt_length :].long()
if hasattr(experiences, "action_masks") and experiences.action_masks is not None
else attention_mask[:, experiences.prompt_length :].long()
),
+ "is_expert_mask": is_expert_mask,
}
if experiences.rewards is not None:
token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
eos_mask_idx = cumsum.argmax(dim=-1)
token_level_rewards[
torch.arange(experiences.batch_size), eos_mask_idx
] = experiences.rewards
token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
batch_dict.update(
{
"token_level_scores": token_level_rewards,
"old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore
}
)
return DataProto.from_single_dict(batch_dict)
```
Here we use the `custom_fields` argument of `Experiences.gather_experiences` to add a new field `expert_mask`, which indicates whether the experience is from an expert or not. This field will be used in the policy loss function to distinguish between usual and expert experiences.
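If it is unclear how a custom field ends up in the gathered batch, the toy sketch below illustrates the mechanism with plain PyTorch. It is only a minimal stand-in, not the library's actual implementation of `CustomField` / `Experiences.gather_experiences`: the `source_field`, `destination_field`, and `data_type` names come from the snippet above, while `ToyExperience`, `ToyCustomField`, and `toy_gather` are hypothetical helpers introduced just for illustration.

```python
# Toy illustration only: how a per-experience info entry ("is_expert") can be
# collected into a batched tensor ("expert_mask"). The real gather also pads
# tokens/logprobs and handles many more fields.
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch


@dataclass
class ToyExperience:
    tokens: torch.Tensor
    info: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ToyCustomField:
    source_field: str       # key read from each experience's `info` dict
    destination_field: str  # field name in the gathered batch
    data_type: torch.dtype  # dtype of the resulting batched tensor


def toy_gather(
    exps: List[ToyExperience], custom_fields: List[ToyCustomField]
) -> Dict[str, torch.Tensor]:
    """Collect one value per experience into a batched tensor per custom field."""
    batch: Dict[str, torch.Tensor] = {}
    for cf in custom_fields:
        values = [exp.info.get(cf.source_field, False) for exp in exps]
        batch[cf.destination_field] = torch.tensor(values, dtype=cf.data_type)
    return batch


exps = [
    ToyExperience(tokens=torch.tensor([1, 2, 3]), info={"is_expert": True}),
    ToyExperience(tokens=torch.tensor([4, 5]), info={"is_expert": False}),
]
batch = toy_gather(exps, [ToyCustomField("is_expert", "expert_mask", torch.bool)])
print(batch["expert_mask"])  # tensor([ True, False])
```

The gathered `expert_mask` tensor is then passed along to the policy loss function, which is how `MIXPolicyLossFn` below can split each micro-batch into usual and expert experiences.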


## Step 3: Define the Policy Loss Function
@@ -217,15 +186,15 @@ class MIXPolicyLossFn(PolicyLossFn):
old_logprob: torch.Tensor,
action_mask: torch.Tensor,
advantages: torch.Tensor,
is_expert_mask: torch.Tensor,
expert_mask: torch.Tensor,
**kwargs,
) -> Tuple[torch.Tensor, Dict]:
assert (
len(is_expert_mask) == logprob.shape[0]
), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
len(expert_mask) == logprob.shape[0]
), f"Error: {len(expert_mask)=} != {logprob.shape[0]=}"

n_usual_exp = torch.sum(~is_expert_mask).item()
n_expert_exp = torch.sum(is_expert_mask).item()
n_usual_exp = torch.sum(~expert_mask).item()
n_expert_exp = torch.sum(expert_mask).item()

if self.use_dynamic_bsz:
per_micro_batch_weight_usual = self.experience_per_gpu / (
@@ -240,10 +209,10 @@

if n_usual_exp > 0:
grpo_loss, grpo_metrics = self.grpo_loss_fn(
logprob[~is_expert_mask],
old_logprob[~is_expert_mask],
action_mask[~is_expert_mask],
advantages[~is_expert_mask],
logprob[~expert_mask],
old_logprob[~expert_mask],
action_mask[~expert_mask],
advantages[~expert_mask],
**kwargs,
)
grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
@@ -257,8 +226,8 @@
# SFT Loss (expert)
if n_expert_exp > 0:
sft_loss, sft_metrics = self.sft_loss_fn(
logprob[is_expert_mask],
action_mask[is_expert_mask],
logprob[expert_mask],
action_mask[expert_mask],
)
sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
sft_metrics = {
2 changes: 1 addition & 1 deletion tests/algorithm/policy_loss_test.py
@@ -26,7 +26,7 @@ def setUp(self):
"ref_log_prob": 2 * torch.rand(shape) - 1,
"response_mask": torch.rand(shape) > 0.5,
"advantages": 2 * torch.rand(shape) - 1,
"is_expert_mask": torch.rand(shape[0]) > 0.5,
"expert_mask": torch.rand(shape[0]) > 0.5,
}
)

82 changes: 48 additions & 34 deletions tests/common/experience_test.py
@@ -43,14 +43,14 @@ def test_eid_properties(self):
class TestExperience(unittest.TestCase):
def test_single_turn_experience(self):
tokens = torch.tensor([10, 11, 12], dtype=torch.int32)
logprobs = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
logprobs = torch.tensor([0.2, 0.3], dtype=torch.float32)
exp = Experience(tokens=tokens, logprobs=logprobs, reward=1.0, prompt_length=1)
self.assertEqual(exp.experience_type.name, "SINGLE_TURN")
self.assertTrue(torch.equal(exp.tokens, tokens))
self.assertTrue(torch.equal(exp.logprobs, logprobs))
self.assertEqual(exp.reward, 1.0)
self.assertEqual(exp.prompt_length, 1)
self.assertTrue(torch.equal(exp.action_mask, torch.tensor([0, 1, 1], dtype=torch.bool)))
self.assertTrue(torch.equal(exp.action_mask, torch.tensor([1, 1], dtype=torch.bool)))

def test_multi_turn_experience(self):
tokens = torch.tensor([1, 2, 3, 4])
@@ -171,13 +171,17 @@ def test_batch_conversion(self):
tokens=torch.tensor([1, 2]),
prompt_length=1,
reward=float(0.1),
logprobs=torch.tensor([0, 0.1]),
logprobs=torch.tensor([0.1]),
advantages=torch.tensor([0.1]),
returns=torch.tensor([0.4]),
),
Experience(
tokens=torch.tensor([1, 2, 3]),
prompt_length=2,
reward=float(0.2),
logprobs=torch.tensor([0, 0, 0.1]),
logprobs=torch.tensor([0.1]),
advantages=torch.tensor([0.3]),
returns=torch.tensor([0.2]),
),
]
batch = Experiences.gather_experiences(exps)
@@ -199,45 +203,53 @@ def test_batch_conversion(self):
)
self.assertTrue(
torch.all(
batch.logprobs[i][
prompt_length
- exps[i].prompt_length : prompt_length
+ exps[i].tokens.size(0)
- exps[i].prompt_length
]
batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
== exps[i].logprobs
)
)
self.assertTrue(
torch.all(
batch.action_masks[i][
prompt_length
- exps[i].prompt_length : prompt_length
- exps[i].prompt_length
+ exps[i].action_mask.size(0)
]
batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
== exps[i].action_mask
)
)
self.assertTrue(
torch.all(
batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
== exps[i].advantages
)
)
self.assertTrue(
torch.all(
batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
== exps[i].returns
)
)

def test_multiturn_experience_batch_converstion(self):
exps = [
Experience(
tokens=torch.tensor([1, 2, 3, 4]),
tokens=torch.tensor([1, 2, 3, 4, 5, 6]),
reward=float(0.3),
logprobs=torch.tensor([0, 0, 0.1, 0.2]),
action_mask=torch.tensor([1, 0, 1, 0]),
logprobs=torch.tensor([0, 0.1, 0.2, 0.3]),
prompt_length=2,
action_mask=torch.tensor([1, 0, 1, 1]),
advantages=torch.tensor([0.1, 0, 0.2, 0.3]),
returns=torch.tensor([0.5, 0, 0.7, 0.8]),
),
Experience(
tokens=torch.tensor([1, 2, 3, 4]),
reward=float(0.4),
logprobs=torch.tensor([0, 0, 0, 0.1]),
action_mask=torch.tensor([1, 0, 0, 1]),
logprobs=torch.tensor([0, 0.1]),
prompt_length=2,
action_mask=torch.tensor([1, 1]),
advantages=torch.tensor([0.2, 0.3]),
returns=torch.tensor([0.6, 0.9]),
),
]
batch = Experiences.gather_experiences(exps)
self.assertEqual(batch.batch_size, 2)
self.assertEqual(batch.prompt_length, 1)
self.assertEqual(batch.prompt_length, 2)
prompt_length = batch.prompt_length
for i in range(batch.batch_size):
self.assertEqual(batch.rewards[i], exps[i].reward)
@@ -254,26 +266,28 @@ def test_multiturn_experience_batch_converstion(self):
)
self.assertTrue(
torch.all(
batch.logprobs[i][
prompt_length
- exps[i].prompt_length : prompt_length
+ exps[i].tokens.size(0)
- exps[i].prompt_length
]
batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
== exps[i].logprobs
)
)
self.assertTrue(
torch.all(
batch.action_masks[i][
prompt_length
- exps[i].prompt_length : prompt_length
- exps[i].prompt_length
+ exps[i].action_mask.size(0)
]
batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
== exps[i].action_mask
)
)
self.assertTrue(
torch.all(
batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
== exps[i].advantages
)
)
self.assertTrue(
torch.all(
batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
== exps[i].returns
)
)

def test_dpo_experience_batch_conversion(self):
exps = [
17 changes: 9 additions & 8 deletions tests/common/vllm_test.py
@@ -159,15 +159,12 @@ async def test_generate(
self.assertEqual(exp.prompt_length, history_exp.prompt_length)
self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist())
for result in results:
input_logprobs = result.logprobs[: result.prompt_length]
output_logprobs = result.logprobs[result.prompt_length :]
self.assertTrue(torch.all(input_logprobs == 0))
self.assertTrue(torch.any(output_logprobs != 0))
self.assertTrue(torch.any(result.logprobs != 0))
if self.use_async:
logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist())
else:
logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist())
self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0])
self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0] - 1)
if self.config.explorer.rollout_model.enable_history:
history_experiences = self.model_wrapper.extract_experience_from_history()
self.assertTrue(len(history_experiences) == 0)
@@ -190,7 +187,10 @@
return_assistant_tokens_mask=True,
return_dict=True,
)
self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask))
prompt_length = torch.argmax(result_dict["assistant_masks"][0]).item()
self.assertTrue(
torch.equal(result_dict["assistant_masks"][0][prompt_length:], exp.action_mask)
)
self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens))
self.assertRaises(ValueError, self.model_wrapper.get_openai_client)
if self.config.explorer.rollout_model.enable_history:
@@ -284,12 +284,12 @@ def test_assistant_token_mask(self):
},
]
tokenizer = AutoTokenizer.from_pretrained(get_model_path())
token_ids, action_mask = tokenize_and_mask_messages_default(
token_ids, action_mask, prompt_length = tokenize_and_mask_messages_default(
tokenizer=tokenizer,
messages=messages,
chat_template=CHAT_TEMPLATE,
)
token_ids_hf, action_mask_hf = tokenize_and_mask_messages_hf(
token_ids_hf, action_mask_hf, prompt_length_hf = tokenize_and_mask_messages_hf(
tokenizer=tokenizer,
messages=messages,
chat_template=CHAT_TEMPLATE,
@@ -298,3 +298,4 @@
self.assertEqual(action_mask.shape, action_mask_hf.shape)
self.assertTrue(torch.equal(token_ids, token_ids_hf))
self.assertTrue(torch.equal(action_mask, action_mask_hf))
self.assertEqual(prompt_length, prompt_length_hf)