82 changes: 48 additions & 34 deletions tests/common/experience_test.py
@@ -43,14 +43,14 @@ def test_eid_properties(self):
 class TestExperience(unittest.TestCase):
     def test_single_turn_experience(self):
         tokens = torch.tensor([10, 11, 12], dtype=torch.int32)
-        logprobs = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32)
+        logprobs = torch.tensor([0.2, 0.3], dtype=torch.float32)
         exp = Experience(tokens=tokens, logprobs=logprobs, reward=1.0, prompt_length=1)
         self.assertEqual(exp.experience_type.name, "SINGLE_TURN")
         self.assertTrue(torch.equal(exp.tokens, tokens))
         self.assertTrue(torch.equal(exp.logprobs, logprobs))
         self.assertEqual(exp.reward, 1.0)
         self.assertEqual(exp.prompt_length, 1)
-        self.assertTrue(torch.equal(exp.action_mask, torch.tensor([0, 1, 1], dtype=torch.bool)))
+        self.assertTrue(torch.equal(exp.action_mask, torch.tensor([1, 1], dtype=torch.bool)))

     def test_multi_turn_experience(self):
         tokens = torch.tensor([1, 2, 3, 4])
@@ -171,13 +171,17 @@ def test_batch_conversion(self):
                 tokens=torch.tensor([1, 2]),
                 prompt_length=1,
                 reward=float(0.1),
-                logprobs=torch.tensor([0, 0.1]),
+                logprobs=torch.tensor([0.1]),
+                advantages=torch.tensor([0.1]),
+                returns=torch.tensor([0.4]),
             ),
             Experience(
                 tokens=torch.tensor([1, 2, 3]),
                 prompt_length=2,
                 reward=float(0.2),
-                logprobs=torch.tensor([0, 0, 0.1]),
+                logprobs=torch.tensor([0.1]),
+                advantages=torch.tensor([0.3]),
+                returns=torch.tensor([0.2]),
             ),
         ]
         batch = Experiences.gather_experiences(exps)
@@ -199,45 +203,53 @@ def test_batch_conversion(self):
             )
             self.assertTrue(
                 torch.all(
-                    batch.logprobs[i][
-                        prompt_length
-                        - exps[i].prompt_length : prompt_length
-                        + exps[i].tokens.size(0)
-                        - exps[i].prompt_length
-                    ]
+                    batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
                     == exps[i].logprobs
                 )
             )
             self.assertTrue(
                 torch.all(
-                    batch.action_masks[i][
-                        prompt_length
-                        - exps[i].prompt_length : prompt_length
-                        - exps[i].prompt_length
-                        + exps[i].action_mask.size(0)
-                    ]
+                    batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
                     == exps[i].action_mask
                 )
             )
+            self.assertTrue(
+                torch.all(
+                    batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
+                    == exps[i].advantages
+                )
+            )
+            self.assertTrue(
+                torch.all(
+                    batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
+                    == exps[i].returns
+                )
+            )

     def test_multiturn_experience_batch_converstion(self):
         exps = [
             Experience(
-                tokens=torch.tensor([1, 2, 3, 4]),
+                tokens=torch.tensor([1, 2, 3, 4, 5, 6]),
                 reward=float(0.3),
-                logprobs=torch.tensor([0, 0, 0.1, 0.2]),
-                action_mask=torch.tensor([1, 0, 1, 0]),
+                logprobs=torch.tensor([0, 0.1, 0.2, 0.3]),
+                prompt_length=2,
+                action_mask=torch.tensor([1, 0, 1, 1]),
+                advantages=torch.tensor([0.1, 0, 0.2, 0.3]),
+                returns=torch.tensor([0.5, 0, 0.7, 0.8]),
             ),
             Experience(
                 tokens=torch.tensor([1, 2, 3, 4]),
                 reward=float(0.4),
-                logprobs=torch.tensor([0, 0, 0, 0.1]),
-                action_mask=torch.tensor([1, 0, 0, 1]),
+                logprobs=torch.tensor([0, 0.1]),
+                prompt_length=2,
+                action_mask=torch.tensor([1, 1]),
+                advantages=torch.tensor([0.2, 0.3]),
+                returns=torch.tensor([0.6, 0.9]),
             ),
         ]
         batch = Experiences.gather_experiences(exps)
         self.assertEqual(batch.batch_size, 2)
-        self.assertEqual(batch.prompt_length, 1)
+        self.assertEqual(batch.prompt_length, 2)
         prompt_length = batch.prompt_length
         for i in range(batch.batch_size):
             self.assertEqual(batch.rewards[i], exps[i].reward)
@@ -254,26 +266,28 @@ def test_multiturn_experience_batch_converstion(self):
             )
             self.assertTrue(
                 torch.all(
-                    batch.logprobs[i][
-                        prompt_length
-                        - exps[i].prompt_length : prompt_length
-                        + exps[i].tokens.size(0)
-                        - exps[i].prompt_length
-                    ]
+                    batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
                     == exps[i].logprobs
                 )
             )
             self.assertTrue(
                 torch.all(
-                    batch.action_masks[i][
-                        prompt_length
-                        - exps[i].prompt_length : prompt_length
-                        - exps[i].prompt_length
-                        + exps[i].action_mask.size(0)
-                    ]
+                    batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
                     == exps[i].action_mask
                 )
             )
+            self.assertTrue(
+                torch.all(
+                    batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
+                    == exps[i].advantages
+                )
+            )
+            self.assertTrue(
+                torch.all(
+                    batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length]
+                    == exps[i].returns
+                )
+            )

     def test_dpo_experience_batch_conversion(self):
         exps = [
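Taken together, these test changes encode a single convention shift: per-token fields on an Experience (logprobs, action_mask, advantages, returns) now cover only the response, with length tokens.size(0) - prompt_length, instead of spanning the padded full sequence, and a gathered batch indexes them from position 0 while prompts stay right-aligned to the shared prompt_length. A minimal plain-torch sketch of the layout the updated assertions imply (gather() is a hypothetical stand-in, not the repo's Experiences.gather_experiences):

import torch

def gather(exps):
    # Hypothetical stand-in: left-pad prompts (right-align them) to the
    # longest prompt; right-pad response-aligned fields from index 0.
    max_prompt = max(e["prompt_length"] for e in exps)
    max_resp = max(e["tokens"].size(0) - e["prompt_length"] for e in exps)
    tokens = torch.zeros(len(exps), max_prompt + max_resp, dtype=torch.long)
    logprobs = torch.zeros(len(exps), max_resp)
    for i, e in enumerate(exps):
        p = e["prompt_length"]
        resp = e["tokens"].size(0) - p
        tokens[i, max_prompt - p : max_prompt] = e["tokens"][:p]     # prompt, right-aligned
        tokens[i, max_prompt : max_prompt + resp] = e["tokens"][p:]  # response
        logprobs[i, :resp] = e["logprobs"]                           # starts at index 0
    return tokens, logprobs

exps = [
    {"tokens": torch.tensor([1, 2]), "prompt_length": 1, "logprobs": torch.tensor([0.1])},
    {"tokens": torch.tensor([1, 2, 3]), "prompt_length": 2, "logprobs": torch.tensor([0.1])},
]
tokens, logprobs = gather(exps)
# Mirrors the updated assertion batch.logprobs[i][: resp_len] == exps[i].logprobs
assert torch.all(logprobs[1][:1] == torch.tensor([0.1]))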
5 changes: 1 addition & 4 deletions tests/common/vllm_test.py
@@ -159,10 +159,7 @@ async def test_generate(
         self.assertEqual(exp.prompt_length, history_exp.prompt_length)
         self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist())
         for result in results:
-            input_logprobs = result.logprobs[: result.prompt_length]
-            output_logprobs = result.logprobs[result.prompt_length :]
-            self.assertTrue(torch.all(input_logprobs == 0))
-            self.assertTrue(torch.any(output_logprobs != 0))
+            self.assertTrue(torch.any(result.logprobs != 0))
         if self.use_async:
             logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist())
         else:
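Because result.logprobs no longer carries zero-filled prompt positions, the old split into input and output logprobs has nothing left to verify; one non-zero check over the whole tensor covers it. A tiny illustration with made-up values:

import torch

logprobs = torch.tensor([-0.7, -1.2, -0.3])  # hypothetical response-only logprobs
# Every entry is now a sampled-token logprob; no prompt slots to assert zero on.
assert torch.any(logprobs != 0)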
6 changes: 3 additions & 3 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -57,7 +57,7 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]:
         expert_exp_list = self.expert_exp_buffer.read()
         for exp in expert_exp_list:
             exp.reward = 0.0
-            exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
+            exp.logprobs = torch.zeros_like(exp.tokens[exp.prompt_length:], dtype=torch.float32)
             if exp.info is None:
                 exp.info = {}
             exp.info["is_expert"] = True
@@ -98,7 +98,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor):
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
         "attention_mask": attention_mask.long(),
         "response_mask": (
-            experiences.action_masks[:, experiences.prompt_length :].long()
+            experiences.action_masks.long()
             if hasattr(experiences, "action_masks") and experiences.action_masks is not None
             else attention_mask[:, experiences.prompt_length :].long()
         ),
@@ -114,7 +114,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor):
     batch_dict.update(
         {
             "token_level_scores": token_level_rewards,
-            "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
+            "old_log_probs": experiences.logprobs,  # type: ignore
         }
     )
     return DataProto.from_single_dict(batch_dict)
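Both hunks here follow from the same convention: expert experiences get one zero logprob per response token, and action_masks/logprobs arrive already response-shaped, so the [:, prompt_length:] slices disappear; the identical pair of changes in utils.py just below has the same rationale. A sketch with made-up shapes showing why the slicing is no longer needed:

import torch

batch_size, prompt_len, resp_len = 2, 3, 4
tokens = torch.randint(0, 100, (batch_size, prompt_len + resp_len))
action_masks = torch.ones(batch_size, resp_len, dtype=torch.bool)  # response-only now
logprobs = torch.randn(batch_size, resp_len)                       # response-only now

responses = tokens[:, prompt_len:]   # tokens still span prompt + response
response_mask = action_masks.long()  # no [:, prompt_len:] slice needed
old_log_probs = logprobs             # used as-is
assert response_mask.shape == responses.shape == old_log_probs.shape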
4 changes: 2 additions & 2 deletions trinity/algorithm/sample_strategy/utils.py
@@ -20,7 +20,7 @@ def to_data_proto(experiences: Experiences) -> DataProto:
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
         "attention_mask": attention_mask.long(),
         "response_mask": (
-            experiences.action_masks[:, experiences.prompt_length :].long()
+            experiences.action_masks.long()
             if hasattr(experiences, "action_masks") and experiences.action_masks is not None
             else attention_mask[:, experiences.prompt_length :].long()
         ),
@@ -35,7 +35,7 @@ def to_data_proto(experiences: Experiences) -> DataProto:
     batch_dict.update(
         {
             "token_level_scores": token_level_rewards,
-            "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
+            "old_log_probs": experiences.logprobs,  # type: ignore
         }
     )
     return DataProto.from_single_dict(batch_dict)
3 changes: 2 additions & 1 deletion trinity/buffer/schema/sql_schema.py
@@ -85,14 +85,15 @@ def from_messages(
         """Convert a list of messages into a single instance of SFT data."""
         from trinity.common.models.utils import tokenize_and_mask_messages_hf

-        tokens, action_mask = tokenize_and_mask_messages_hf(
+        tokens, action_mask, prompt_length = tokenize_and_mask_messages_hf(
             tokenizer=tokenizer,
             messages=messages,
             chat_template=chat_template,
         )
         exp = Experience(
             tokens=tokens,
             action_mask=action_mask,
+            prompt_length=prompt_length,
             info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])},
         )
         return cls(
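tokenize_and_mask_messages_hf now also returns a prompt_length, which from_messages forwards into the Experience. The diff does not show how that value is computed; one plausible derivation (assuming a full-sequence action mask with assistant tokens set to 1) is the index of the first unmasked position. A hypothetical sketch, not the repo's implementation:

import torch

def prompt_length_from_mask(action_mask: torch.Tensor) -> int:
    # First position the model is trained on (mask == 1); if no assistant
    # tokens exist, the whole sequence counts as prompt.
    nonzero = torch.nonzero(action_mask, as_tuple=False)
    return int(nonzero[0].item()) if nonzero.numel() > 0 else action_mask.numel()

mask = torch.tensor([0, 0, 0, 1, 1, 0, 1], dtype=torch.bool)
assert prompt_length_from_mask(mask) == 3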