diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 52aa212b0a..099ee323e4 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -76,8 +76,8 @@ We need to read two kinds of experiences: usual experiences and expert experienc class MixSampleStrategy(SampleStrategy): """The default sample strategy.""" - def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): - super().__init__(buffer_config, trainer_type) + def __init__(self, buffer_config: BufferConfig, **kwargs): + super().__init__(buffer_config) self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) tot_batch_size = buffer_config.read_batch_size expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) @@ -101,7 +101,7 @@ class MixSampleStrategy(SampleStrategy): buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config ) - def sample(self, step: int) -> Tuple[Any, Dict, List]: + def sample(self, step: int) -> Tuple[Experiences, Dict, List]: metrics = {} with Timer(metrics, "read_time"): usual_exp_list = self.usual_exp_buffer.read() @@ -113,7 +113,9 @@ class MixSampleStrategy(SampleStrategy): expert_exp_list = self.expert_exp_buffer.read() for exp in expert_exp_list: exp.reward = 0.0 - exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + exp.logprobs = torch.zeros_like( + exp.tokens[exp.prompt_length :], dtype=torch.float32 + ) if exp.info is None: exp.info = {} exp.info["is_expert"] = True @@ -121,55 +123,22 @@ class MixSampleStrategy(SampleStrategy): exp_list = usual_exp_list + expert_exp_list repr_samples = representative_sample(exp_list) - is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) - with Timer(metrics, "gather_time"): - exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - - if self.trainer_type == "verl": - with Timer(metrics, "convert_time"): - data = to_data_proto_mix(exps, is_expert_mask) - return data, metrics, repr_samples - else: - raise NotImplementedError(f"backend {self.trainer_type} is not supported") + exps = Experiences.gather_experiences( + experiences=exp_list, + pad_token_id=self.pad_token_id, # type: ignore [arg-type] + custom_fields=[ + CustomField( + source_field="is_expert", + destination_field="expert_mask", + data_type=torch.bool, + ) + ], + ) # type: ignore + return exps, metrics, repr_samples ``` -We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type. 
- -```diff -+ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - batch_dict = { - "uid": np.array([eid.tid for eid in experiences.eids]), - "unique_ids": np.array([eid.uid for eid in experiences.eids]), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), -+ "is_expert_mask": is_expert_mask, - } - if experiences.rewards is not None: - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) - eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] - batch_dict.update( - { - "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore - } - ) - return DataProto.from_single_dict(batch_dict) -``` +Here we use the `custom_fields` argument of `Experiences.gather_experiences` to add a new field `expert_mask`, which indicates whether the experience is from an expert or not. This field will be used in the policy loss function to distinguish between usual and expert experiences. ## Step 3: Define the Policy Loss Function @@ -217,15 +186,15 @@ class MIXPolicyLossFn(PolicyLossFn): old_logprob: torch.Tensor, action_mask: torch.Tensor, advantages: torch.Tensor, - is_expert_mask: torch.Tensor, + expert_mask: torch.Tensor, **kwargs, ) -> Tuple[torch.Tensor, Dict]: assert ( - len(is_expert_mask) == logprob.shape[0] - ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" + len(expert_mask) == logprob.shape[0] + ), f"Error: {len(expert_mask)=} != {logprob.shape[0]=}" - n_usual_exp = torch.sum(~is_expert_mask).item() - n_expert_exp = torch.sum(is_expert_mask).item() + n_usual_exp = torch.sum(~expert_mask).item() + n_expert_exp = torch.sum(expert_mask).item() if self.use_dynamic_bsz: per_micro_batch_weight_usual = self.experience_per_gpu / ( @@ -240,10 +209,10 @@ class MIXPolicyLossFn(PolicyLossFn): if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( - logprob[~is_expert_mask], - old_logprob[~is_expert_mask], - action_mask[~is_expert_mask], - advantages[~is_expert_mask], + logprob[~expert_mask], + old_logprob[~expert_mask], + action_mask[~expert_mask], + advantages[~expert_mask], **kwargs, ) grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual @@ -257,8 +226,8 @@ class MIXPolicyLossFn(PolicyLossFn): # SFT Loss (expert) if n_expert_exp > 0: sft_loss, sft_metrics = self.sft_loss_fn( - logprob[is_expert_mask], - action_mask[is_expert_mask], + logprob[expert_mask], + action_mask[expert_mask], ) sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert sft_metrics = { diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py index ba88feb2d7..b6bd5da21d 100644 --- a/tests/algorithm/policy_loss_test.py +++ b/tests/algorithm/policy_loss_test.py @@ -26,7 +26,7 @@ def setUp(self): "ref_log_prob": 2 * torch.rand(shape) - 1, 
"response_mask": torch.rand(shape) > 0.5, "advantages": 2 * torch.rand(shape) - 1, - "is_expert_mask": torch.rand(shape[0]) > 0.5, + "expert_mask": torch.rand(shape[0]) > 0.5, } ) diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 44f96e16f7..212b88fb2d 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -43,14 +43,14 @@ def test_eid_properties(self): class TestExperience(unittest.TestCase): def test_single_turn_experience(self): tokens = torch.tensor([10, 11, 12], dtype=torch.int32) - logprobs = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float32) + logprobs = torch.tensor([0.2, 0.3], dtype=torch.float32) exp = Experience(tokens=tokens, logprobs=logprobs, reward=1.0, prompt_length=1) self.assertEqual(exp.experience_type.name, "SINGLE_TURN") self.assertTrue(torch.equal(exp.tokens, tokens)) self.assertTrue(torch.equal(exp.logprobs, logprobs)) self.assertEqual(exp.reward, 1.0) self.assertEqual(exp.prompt_length, 1) - self.assertTrue(torch.equal(exp.action_mask, torch.tensor([0, 1, 1], dtype=torch.bool))) + self.assertTrue(torch.equal(exp.action_mask, torch.tensor([1, 1], dtype=torch.bool))) def test_multi_turn_experience(self): tokens = torch.tensor([1, 2, 3, 4]) @@ -171,13 +171,17 @@ def test_batch_conversion(self): tokens=torch.tensor([1, 2]), prompt_length=1, reward=float(0.1), - logprobs=torch.tensor([0, 0.1]), + logprobs=torch.tensor([0.1]), + advantages=torch.tensor([0.1]), + returns=torch.tensor([0.4]), ), Experience( tokens=torch.tensor([1, 2, 3]), prompt_length=2, reward=float(0.2), - logprobs=torch.tensor([0, 0, 0.1]), + logprobs=torch.tensor([0.1]), + advantages=torch.tensor([0.3]), + returns=torch.tensor([0.2]), ), ] batch = Experiences.gather_experiences(exps) @@ -199,45 +203,53 @@ def test_batch_conversion(self): ) self.assertTrue( torch.all( - batch.logprobs[i][ - prompt_length - - exps[i].prompt_length : prompt_length - + exps[i].tokens.size(0) - - exps[i].prompt_length - ] + batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length] == exps[i].logprobs ) ) self.assertTrue( torch.all( - batch.action_masks[i][ - prompt_length - - exps[i].prompt_length : prompt_length - - exps[i].prompt_length - + exps[i].action_mask.size(0) - ] + batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length] == exps[i].action_mask ) ) + self.assertTrue( + torch.all( + batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length] + == exps[i].advantages + ) + ) + self.assertTrue( + torch.all( + batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length] + == exps[i].returns + ) + ) def test_multiturn_experience_batch_converstion(self): exps = [ Experience( - tokens=torch.tensor([1, 2, 3, 4]), + tokens=torch.tensor([1, 2, 3, 4, 5, 6]), reward=float(0.3), - logprobs=torch.tensor([0, 0, 0.1, 0.2]), - action_mask=torch.tensor([1, 0, 1, 0]), + logprobs=torch.tensor([0, 0.1, 0.2, 0.3]), + prompt_length=2, + action_mask=torch.tensor([1, 0, 1, 1]), + advantages=torch.tensor([0.1, 0, 0.2, 0.3]), + returns=torch.tensor([0.5, 0, 0.7, 0.8]), ), Experience( tokens=torch.tensor([1, 2, 3, 4]), reward=float(0.4), - logprobs=torch.tensor([0, 0, 0, 0.1]), - action_mask=torch.tensor([1, 0, 0, 1]), + logprobs=torch.tensor([0, 0.1]), + prompt_length=2, + action_mask=torch.tensor([1, 1]), + advantages=torch.tensor([0.2, 0.3]), + returns=torch.tensor([0.6, 0.9]), ), ] batch = Experiences.gather_experiences(exps) self.assertEqual(batch.batch_size, 2) - self.assertEqual(batch.prompt_length, 1) + 
self.assertEqual(batch.prompt_length, 2) prompt_length = batch.prompt_length for i in range(batch.batch_size): self.assertEqual(batch.rewards[i], exps[i].reward) @@ -254,26 +266,28 @@ def test_multiturn_experience_batch_converstion(self): ) self.assertTrue( torch.all( - batch.logprobs[i][ - prompt_length - - exps[i].prompt_length : prompt_length - + exps[i].tokens.size(0) - - exps[i].prompt_length - ] + batch.logprobs[i][: exps[i].tokens.size(0) - exps[i].prompt_length] == exps[i].logprobs ) ) self.assertTrue( torch.all( - batch.action_masks[i][ - prompt_length - - exps[i].prompt_length : prompt_length - - exps[i].prompt_length - + exps[i].action_mask.size(0) - ] + batch.action_masks[i][: exps[i].tokens.size(0) - exps[i].prompt_length] == exps[i].action_mask ) ) + self.assertTrue( + torch.all( + batch.advantages[i][: exps[i].tokens.size(0) - exps[i].prompt_length] + == exps[i].advantages + ) + ) + self.assertTrue( + torch.all( + batch.returns[i][: exps[i].tokens.size(0) - exps[i].prompt_length] + == exps[i].returns + ) + ) def test_dpo_experience_batch_conversion(self): exps = [ diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 71ccf32b7f..fdb0d458ae 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -159,15 +159,12 @@ async def test_generate( self.assertEqual(exp.prompt_length, history_exp.prompt_length) self.assertEqual(exp.logprobs.tolist(), history_exp.logprobs.tolist()) for result in results: - input_logprobs = result.logprobs[: result.prompt_length] - output_logprobs = result.logprobs[result.prompt_length :] - self.assertTrue(torch.all(input_logprobs == 0)) - self.assertTrue(torch.any(output_logprobs != 0)) + self.assertTrue(torch.any(result.logprobs != 0)) if self.use_async: logprobs = await self.model_wrapper.logprobs_async(results[0].tokens.tolist()) else: logprobs = self.model_wrapper.logprobs(results[0].tokens.tolist()) - self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0]) + self.assertEqual(logprobs.shape[0], results[0].tokens.shape[0] - 1) if self.config.explorer.rollout_model.enable_history: history_experiences = self.model_wrapper.extract_experience_from_history() self.assertTrue(len(history_experiences) == 0) @@ -190,7 +187,10 @@ async def test_generate( return_assistant_tokens_mask=True, return_dict=True, ) - self.assertTrue(torch.equal(result_dict["assistant_masks"][0], exp.action_mask)) + prompt_length = torch.argmax(result_dict["assistant_masks"][0]).item() + self.assertTrue( + torch.equal(result_dict["assistant_masks"][0][prompt_length:], exp.action_mask) + ) self.assertTrue(torch.equal(result_dict["input_ids"][0], exp.tokens)) self.assertRaises(ValueError, self.model_wrapper.get_openai_client) if self.config.explorer.rollout_model.enable_history: @@ -284,12 +284,12 @@ def test_assistant_token_mask(self): }, ] tokenizer = AutoTokenizer.from_pretrained(get_model_path()) - token_ids, action_mask = tokenize_and_mask_messages_default( + token_ids, action_mask, prompt_length = tokenize_and_mask_messages_default( tokenizer=tokenizer, messages=messages, chat_template=CHAT_TEMPLATE, ) - token_ids_hf, action_mask_hf = tokenize_and_mask_messages_hf( + token_ids_hf, action_mask_hf, prompt_length_hf = tokenize_and_mask_messages_hf( tokenizer=tokenizer, messages=messages, chat_template=CHAT_TEMPLATE, @@ -298,3 +298,4 @@ def test_assistant_token_mask(self): self.assertEqual(action_mask.shape, action_mask_hf.shape) self.assertTrue(torch.equal(token_ids, token_ids_hf)) self.assertTrue(torch.equal(action_mask, action_mask_hf)) + 
self.assertEqual(prompt_length, prompt_length_hf) diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 76c89c42d9..37f20f0236 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -14,7 +14,7 @@ class MIXPolicyLossFn(PolicyLossFn): """Implements a mixed policy loss combining GRPO and SFT losses. This loss function applies different loss components to data based on whether - it comes from an expert or not, as indicated by `is_expert_mask`. It combines: + it comes from an expert or not, as indicated by `expert_mask`. It combines: - GRPO loss (self.grpo_loss_fn) for non-expert data - SFT loss (self.sft_loss_fn) for expert data - Weighting parameter `mu` @@ -62,15 +62,15 @@ def __call__( # type: ignore old_logprob: torch.Tensor, action_mask: torch.Tensor, advantages: torch.Tensor, - is_expert_mask: torch.Tensor, + expert_mask: torch.Tensor, **kwargs, ) -> Tuple[torch.Tensor, Dict]: assert ( - len(is_expert_mask) == logprob.shape[0] - ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" + len(expert_mask) == logprob.shape[0] + ), f"Error: {len(expert_mask)=} != {logprob.shape[0]=}" - n_usual_exp = torch.sum(~is_expert_mask).item() - n_expert_exp = torch.sum(is_expert_mask).item() + n_usual_exp = torch.sum(~expert_mask).item() + n_expert_exp = torch.sum(expert_mask).item() if self.use_dynamic_bsz: per_micro_batch_weight_usual = self.experience_per_gpu / ( @@ -85,10 +85,10 @@ def __call__( # type: ignore if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( - logprob[~is_expert_mask], - old_logprob[~is_expert_mask], - action_mask[~is_expert_mask], - advantages[~is_expert_mask], + logprob[~expert_mask], + old_logprob[~expert_mask], + action_mask[~expert_mask], + advantages[~expert_mask], **kwargs, ) grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual @@ -102,8 +102,8 @@ def __call__( # type: ignore # SFT Loss (expert) if n_expert_exp > 0: sft_loss, sft_metrics = self.sft_loss_fn( - logprob[is_expert_mask], - action_mask[is_expert_mask], + logprob[expert_mask], + action_mask[expert_mask], ) sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert sft_metrics = { diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 60f908afe2..b45560a269 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -1,8 +1,7 @@ import copy from math import ceil -from typing import Any, Dict, List, Tuple +from typing import Dict, List, Tuple -import numpy as np import torch from trinity.algorithm.sample_strategy.sample_strategy import ( @@ -12,7 +11,7 @@ from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig -from trinity.common.experience import Experiences +from trinity.common.experience import CustomField, Experiences from trinity.utils.timer import Timer @@ -20,8 +19,8 @@ class MixSampleStrategy(SampleStrategy): """The default sample strategy.""" - def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): - super().__init__(buffer_config, trainer_type) + def __init__(self, buffer_config: BufferConfig, **kwargs): + super().__init__(buffer_config) self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) tot_batch_size = 
buffer_config.read_batch_size expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) @@ -45,7 +44,7 @@ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config ) - def sample(self, step: int) -> Tuple[Any, Dict, List]: + def sample(self, step: int) -> Tuple[Experiences, Dict, List]: metrics = {} with Timer(metrics, "read_time"): usual_exp_list = self.usual_exp_buffer.read() @@ -57,7 +56,9 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]: expert_exp_list = self.expert_exp_buffer.read() for exp in expert_exp_list: exp.reward = 0.0 - exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + exp.logprobs = torch.zeros_like( + exp.tokens[exp.prompt_length :], dtype=torch.float32 + ) if exp.info is None: exp.info = {} exp.info["is_expert"] = True @@ -65,56 +66,22 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]: exp_list = usual_exp_list + expert_exp_list repr_samples = representative_sample(exp_list) - is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) - with Timer(metrics, "gather_time"): - exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - - if self.trainer_type == "verl": - with Timer(metrics, "convert_time"): - data = to_data_proto_mix(exps, is_expert_mask) - return data, metrics, repr_samples - else: - raise NotImplementedError(f"backend {self.trainer_type} is not supported") + exps = Experiences.gather_experiences( + experiences=exp_list, + pad_token_id=self.pad_token_id, # type: ignore [arg-type] + custom_fields=[ + CustomField( + source_field="is_expert", + destination_field="expert_mask", + data_type=torch.bool, + ) + ], + ) # type: ignore + return exps, metrics, repr_samples @classmethod def default_args(cls) -> Dict: return { "expert_data_ratio": 0.5, } - - -def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor): - from verl.trainer.ppo.ray_trainer import DataProto - - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - batch_dict = { - "uid": np.array([eid.tid for eid in experiences.eids]), - "unique_ids": np.array([eid.uid for eid in experiences.eids]), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - "is_expert_mask": is_expert_mask, - } - if experiences.rewards is not None: - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) - eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] - batch_dict.update( - { - "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore - } - ) - return DataProto.from_single_dict(batch_dict) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index b923ab17a6..b6b3c1e356 100644 --- 
a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Tuple -from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto +from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig from trinity.common.experience import Experiences @@ -12,36 +12,22 @@ class SampleStrategy(ABC): - def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs) -> None: + def __init__(self, buffer_config: BufferConfig, **kwargs) -> None: self.pad_token_id = buffer_config.pad_token_id - self.trainer_type = trainer_type @abstractmethod - def sample(self, step: int) -> Tuple[Any, Dict, List]: + def sample(self, step: int) -> Tuple[Experiences, Dict, List]: """Sample data from buffer. Args: step (`int`): The step number of current step. Returns: - `Any`: The sampled data. + `Experiences`: The sampled Experiences data. `Dict`: Metrics for logging. `List`: Representative data for logging. """ - # Experimental API - @abstractmethod - def warmup_state(self, step: int) -> Tuple[bool, bool]: - """Check the warmup state of the current step. - - Args: - step (`int`): The step number of current step. - - Returns: - `bool`: Current step is in warmup or not. - `bool`: Warmup is finished on this step or not. - """ - @classmethod @abstractmethod def default_args(cls) -> dict: @@ -52,8 +38,8 @@ def default_args(cls) -> dict: class WarmupSampleStrategy(SampleStrategy): """The default sample strategy.""" - def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): - super().__init__(buffer_config, trainer_type) + def __init__(self, buffer_config: BufferConfig, **kwargs): + super().__init__(buffer_config) self.exp_buffer = get_buffer_reader( buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore ) @@ -67,7 +53,7 @@ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): else: self.sft_buffer = None - def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: + def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: metrics = {} with Timer(metrics, "read_time"): if step <= self.sft_warmup_steps: @@ -77,15 +63,7 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: repr_samples = representative_sample(exp_list) with Timer(metrics, "gather_time"): exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - if self.trainer_type == "verl": - with Timer(metrics, "convert_time"): - data = to_data_proto(exps) - return data, metrics, repr_samples - else: - raise NotImplementedError(f"backend {self.trainer_type} is not supported") - - def warmup_state(self, step: int) -> Tuple[bool, bool]: - return step <= self.sft_warmup_steps, step == self.sft_warmup_steps + return exps, metrics, repr_samples @classmethod def default_args(cls) -> dict: @@ -94,8 +72,8 @@ def default_args(cls) -> dict: @SAMPLE_STRATEGY.register_module("default") class DefaultSampleStrategy(SampleStrategy): - def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): - super().__init__(buffer_config, trainer_type) + def __init__(self, buffer_config: BufferConfig, **kwargs): + super().__init__(buffer_config) self.exp_buffer = get_buffer_reader( buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore ) @@ -107,15 +85,7 @@ def 
sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: repr_samples = representative_sample(exp_list) with Timer(metrics, "gather_time"): exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - if self.trainer_type == "verl": - with Timer(metrics, "convert_time"): - data = to_data_proto(exps) - return data, metrics, repr_samples - else: - raise NotImplementedError(f"backend {self.trainer_type} is not supported") - - def warmup_state(self, step: int) -> Tuple[bool, bool]: - return False, False + return exps, metrics, repr_samples @classmethod def default_args(cls) -> dict: diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py index cba97e6d9e..80c76ec016 100644 --- a/trinity/algorithm/sample_strategy/utils.py +++ b/trinity/algorithm/sample_strategy/utils.py @@ -1,44 +1,7 @@ import random from typing import List -import numpy as np -import torch -from verl.trainer.ppo.ray_trainer import DataProto - -from trinity.common.experience import Experience, Experiences - - -def to_data_proto(experiences: Experiences) -> DataProto: - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - batch_dict = { - "uid": np.array([eid.tid for eid in experiences.eids]), - "unique_ids": np.array([eid.uid for eid in experiences.eids]), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - } - if experiences.rewards is not None: - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) - eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] - batch_dict.update( - { - "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore - } - ) - return DataProto.from_single_dict(batch_dict) +from trinity.common.experience import Experience def representative_sample(experiences: List[Experience]) -> List[dict]: diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index 5ac3da2666..61149b8b25 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -85,7 +85,7 @@ def from_messages( """Convert a list of messages into a single instance of SFT data.""" from trinity.common.models.utils import tokenize_and_mask_messages_hf - tokens, action_mask = tokenize_and_mask_messages_hf( + tokens, action_mask, prompt_length = tokenize_and_mask_messages_hf( tokenizer=tokenizer, messages=messages, chat_template=chat_template, @@ -93,6 +93,7 @@ def from_messages( exp = Experience( tokens=tokens, action_mask=action_mask, + prompt_length=prompt_length, info={"response_num": sum([1 if m["role"] == "assistant" else 0 for m in messages])}, ) return cls( diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 0c5f98e89c..89f8d3c10e 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -80,20 +80,34 @@ class ExperienceType(Enum): DPO 
= "dpo" # DPO experience, e.g., a chosen and rejected response pair +@dataclass(frozen=True) +class CustomField: + """Custom field for Experiences. + + This is used to store additional information into the Experiences class. + """ + + source_field: str # The source field name in the Experience.info + destination_field: str # The destination field name in the Experiences class + data_type: torch.dtype # The data type of the field, e.g., torch.float32, torch.int64, etc. + + @dataclass class Experience: eid: EID = field(default_factory=EID) # Unique identifier for the experience tokens: Optional[Tensor] = None # [seq_length] - logprobs: Optional[Tensor] = None # [seq_length] + logprobs: Optional[Tensor] = None # [resp_length] reward: Optional[float] = None + advantages: Optional[Tensor] = None # [resp_length] + returns: Optional[Tensor] = None # [resp_length] # Type of the experience, automatically set based on the presence of action_mask or chosen/rejected experience_type: ExperienceType = ExperienceType.SINGLE_TURN - info: Optional[dict] = field( + info: dict = field( default_factory=dict - ) # Additional information about the experience - metrics: Optional[dict[str, float]] = field( + ) # Additional information about the experience, can also be used to store custom fields + metrics: dict[str, float] = field( default_factory=dict - ) # Metrics associated with the experience + ) # Metrics associated with the experience, directly used by the monitor # for single-turn experiences prompt_length: int = 1 # Length of the prompt in tokens, used for generating attention masks @@ -101,9 +115,8 @@ class Experience: prompt_text: Optional[str] = None # Text of the prompt # for multi-turn experiences - action_mask: Optional[ - Tensor - ] = None # Action mask which indicates which tokens are generated by the model + # Action mask which indicates which tokens are generated by the model + action_mask: Optional[Tensor] = None # [resp_length] messages: Optional[List[dict]] = None # List of messages # for dpo experiences @@ -119,6 +132,8 @@ def __init__( tokens, logprobs=None, reward=None, + advantages=None, + returns=None, info=None, metrics=None, prompt_length=1, @@ -145,10 +160,7 @@ def __init__( assert ( len(tokens) > prompt_length ), f"Token ids must be longer than the prompt length. Got len(tokens)={len(tokens)}, prompt_length={prompt_length}." 
- action_mask = torch.zeros(len(tokens), dtype=torch.bool) - action_mask[prompt_length:] = 1 - elif experience_type == ExperienceType.MULTI_TURN: - prompt_length = 1 + action_mask = torch.ones(len(tokens) - prompt_length, dtype=torch.bool) elif experience_type == ExperienceType.DPO: prompt_length = len(tokens) @@ -156,6 +168,8 @@ def __init__( self.tokens = tokens self.logprobs = logprobs self.reward = reward + self.advantages = advantages + self.returns = returns self.experience_type = experience_type self.info = info or {} self.metrics = metrics or {} @@ -211,9 +225,14 @@ def to_dict(self) -> dict: return res @classmethod - def gather(cls, experiences: List[Experience], pad_token_id: int = 0) -> Experiences: + def gather( + cls, + experiences: List[Experience], + pad_token_id: int = 0, + custom_fields: Optional[List[CustomField]] = None, + ) -> Experiences: if len(experiences) == 0: - return empty_experiences() + return empty_experiences(custom_fields) exp_type = experiences[0].experience_type if exp_type == ExperienceType.DPO: experiences = split_dpo_experience_to_single_turn(experiences) @@ -231,7 +250,7 @@ def gather(cls, experiences: List[Experience], pad_token_id: int = 0) -> Experie rewards = None # gather action_masks - action_masks = gather_action_masks(experiences, max_prompt_length, max_response_length) + action_masks = gather_action_masks(experiences, max_response_length) # gather attention_masks attention_masks = gather_attention_masks( @@ -241,19 +260,45 @@ def gather(cls, experiences: List[Experience], pad_token_id: int = 0) -> Experie # gather logprobs if all(exp.logprobs is not None for exp in experiences): - logprobs = gather_logprobs(experiences, max_prompt_length, max_response_length) + logprobs = gather_logprobs(experiences, max_response_length) else: logprobs = None - return Experiences( + # gather advantages + if all(exp.advantages is not None for exp in experiences): + advantages = gather_advantages(experiences, max_response_length) + else: + advantages = None + + # gather returns + if all(exp.returns is not None for exp in experiences): + returns = gather_returns(experiences, max_response_length) + else: + returns = None + + exps = Experiences( eids=eids, tokens=tokens, rewards=rewards, + advantages=advantages, + returns=returns, attention_masks=attention_masks, action_masks=action_masks, prompt_length=max_prompt_length, logprobs=logprobs, ) + if custom_fields is not None: + for custom_field in custom_fields: + exps.custom_fields.append(custom_field.destination_field) + setattr( + exps, + custom_field.destination_field, + torch.tensor( + [exp.info[custom_field.source_field] for exp in experiences], + dtype=custom_field.data_type, + ), + ) + return exps def split_dpo_experience_to_single_turn(experiences: List[Experience]) -> List[Experience]: @@ -313,12 +358,17 @@ class Experiences: """ eids: List[EID] # Experience IDs of each experience in the batch - tokens: Tensor - rewards: Tensor - attention_masks: Tensor - action_masks: Optional[Tensor] + tokens: Tensor # [batch_size, seq_length] + rewards: Tensor # [batch_size] + advantages: Optional[Tensor] # [batch_size, response_length] + returns: Optional[Tensor] # [batch_size, response_length] + attention_masks: Tensor # [batch_size, sequence_length] + action_masks: Optional[Tensor] # [batch_size, response_length] prompt_length: int - logprobs: Optional[Tensor] + logprobs: Optional[Tensor] # [batch_size, response_length] + custom_fields: List[str] = field( + default_factory=list + ) # Custom fields to include in the 
gathered experiences @property def batch_size(self) -> int: @@ -327,27 +377,46 @@ def batch_size(self) -> int: @classmethod def gather_experiences( - cls, experiences: list[Experience], pad_token_id: int = 0 + cls, + experiences: list[Experience], + pad_token_id: int = 0, + custom_fields: Optional[List[CustomField]] = None, ) -> Experiences: """Gather a batch of experiences from a list of experiences. This method will automatically pad the `tokens` and `logprobs` of input experiences to the same length. + + Args: + experiences (list[Experience]): A list of experiences to gather. + pad_token_id (int): The token ID to use for padding. Default is 0. + custom_fields (Optional[List[CustomField]]): Custom fields to include in the gathered experiences. """ if len(experiences) == 0: - return empty_experiences() - return experiences[0].__class__.gather(experiences, pad_token_id=pad_token_id) + return empty_experiences(custom_fields) + return experiences[0].__class__.gather( + experiences, pad_token_id=pad_token_id, custom_fields=custom_fields + ) -def empty_experiences() -> Experiences: - return Experiences( +def empty_experiences(custom_fields: Optional[List[CustomField]]) -> Experiences: + exps = Experiences( tokens=torch.empty(0, dtype=torch.int32), rewards=torch.empty(0, dtype=torch.float32), + advantages=torch.empty(0, dtype=torch.float32), + returns=torch.empty(0, dtype=torch.float32), attention_masks=torch.empty(0, dtype=torch.bool), action_masks=torch.empty(0, dtype=torch.bool), logprobs=torch.empty(0, dtype=torch.float32), prompt_length=torch.empty(0, dtype=torch.int32), eids=[], ) + if custom_fields is not None: + for custom_field in custom_fields: + exps.custom_fields.append(custom_field.destination_field) + setattr( + exps, custom_field.destination_field, torch.empty(0, dtype=custom_field.data_type) + ) + return exps def gather_token_ids( @@ -376,19 +445,14 @@ def gather_token_ids( ) -def gather_action_masks(experiences, max_prompt_length: int, max_response_length: int) -> Tensor: +def gather_action_masks(experiences, max_response_length: int) -> Tensor: return torch.stack( [ torch.cat( [ - torch.full( - (max_prompt_length - exp.prompt_length,), - 0, - dtype=torch.bool, - ), exp.action_mask, torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), + (max_response_length - len(exp.action_mask),), 0, dtype=torch.bool, ), @@ -412,22 +476,59 @@ def gather_attention_masks(experiences, max_prompt_length: int, max_response_len return attention_masks -def gather_logprobs(experiences, max_prompt_length: int, max_response_length: int) -> Tensor: +def gather_logprobs(experiences, max_response_length: int) -> Tensor: logprob_dtype = experiences[0].logprobs.dtype # type: ignore [union-attr] return torch.stack( [ torch.cat( [ + exp.logprobs, torch.full( - (max_prompt_length - exp.prompt_length,), + (max_response_length - len(exp.logprobs),), 0.0, dtype=logprob_dtype, ), - exp.logprobs, + ] + ) + for exp in experiences + ] + ) + + +def gather_advantages(experiences, max_response_length: int) -> Optional[Tensor]: + if experiences[0].advantages is None: + return None + advantages_dtype = experiences[0].advantages.dtype + return torch.stack( + [ + torch.cat( + [ + exp.advantages, torch.full( - (max_response_length + exp.prompt_length - len(exp.tokens),), + (max_response_length - len(exp.advantages),), 0.0, - dtype=logprob_dtype, + dtype=advantages_dtype, + ), + ] + ) + for exp in experiences + ] + ) + + +def gather_returns(experiences, max_response_length: int) -> Optional[Tensor]: + 
if experiences[0].returns is None: + return None + returns_dtype = experiences[0].returns.dtype + return torch.stack( + [ + torch.cat( + [ + exp.returns, + torch.full( + (max_response_length - len(exp.returns),), + 0.0, + dtype=returns_dtype, ), ] ) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 96d700678e..1cc6f1c19c 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -196,16 +196,7 @@ def convert_api_output_to_experience( torch.tensor(choice.token_ids, dtype=torch.int32), ) ), - logprobs=torch.cat( - ( - torch.full( - (len(output.prompt_token_ids),), - 0.0, - dtype=torch.float32, - ), - extract_logprobs(choice), - ) - ), + logprobs=extract_logprobs(choice), prompt_length=len(output.prompt_token_ids), response_text=choice.message.content, ) diff --git a/trinity/common/models/utils.py b/trinity/common/models/utils.py index 087b190e86..67e5b59504 100644 --- a/trinity/common/models/utils.py +++ b/trinity/common/models/utils.py @@ -16,7 +16,7 @@ def tokenize_and_mask_messages_hf( tokenizer: Any, messages: List[dict], chat_template: Optional[str] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, int]: """Calculate the assistant token mask with `chat_template`. Args: @@ -25,8 +25,9 @@ def tokenize_and_mask_messages_hf( messages (List[dict]): Messages with `role` and `content` fields. Returns: - Tuple[torch.Tensor, torch.Tensor]: The token_ids (sequence_length) - and assistant_masks (sequence_length). + `torch.Tensor`: The token_ids (sequence_length) + `torch.Tensor`: Assistant_masks (sequence_length). + `int`: Prompt length. """ token_dict = tokenizer.apply_chat_template( messages, @@ -39,14 +40,16 @@ def tokenize_and_mask_messages_hf( return_assistant_tokens_mask=True, return_dict=True, ) - return (token_dict["input_ids"][0], token_dict["assistant_masks"][0]) + # find the first assistant token, the tokens before are prompt tokens + prompt_length = torch.argmax(token_dict["assistant_masks"][0]).item() + return token_dict["input_ids"][0], token_dict["assistant_masks"][0], prompt_length def tokenize_and_mask_messages_default( tokenizer: Any, messages: List[dict], chat_template: Optional[str] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor, int]: """Calculate the assistant token mask. Args: @@ -55,8 +58,9 @@ def tokenize_and_mask_messages_default( messages (List[dict]): Messages with `role` and `content` fields. Returns: - Tuple[torch.Tensor, torch.Tensor]: The token_ids (sequence_length) - and assistant_masks (sequence_length). + `torch.Tensor`: The token_ids (sequence_length) + `torch.Tensor`: Assistant_masks (sequence_length). + `int`: Prompt length. 
Note: This method is based on the assumption that as the number of chat rounds increases, @@ -98,7 +102,8 @@ def tokenize_and_mask_messages_default( ) prompt_response_length = prompt_response_token_ids.shape[1] assistant_token_mask[prompt_length:prompt_response_length] = 1 - return (tokens[0], assistant_token_mask) + prompt_length = torch.argmax(assistant_token_mask).item() + return tokens[0], assistant_token_mask, prompt_length def get_checkpoint_dir_with_step_num( diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index 30c3e00b3e..a1e6070b92 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -154,11 +154,6 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: ), logprobs=torch.cat( ( - torch.full( - (len(output.prompt_token_ids),), - 0.0, - dtype=torch.float32, - ), torch.tensor( [ list(logprob_dict.values())[0].logprob @@ -177,7 +172,14 @@ async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: return experiences async def logprobs(self, token_ids: List[int]) -> torch.Tensor: - """Calculate the logprobs of the given tokens in async.""" + """Calculate the logprobs of the given tokens in async. + + Args: + token_ids (List[int]): The input token ids (seq_length). + + Returns: + A tensor of logprobs (seq_length - 1). + """ output = await self._generate_internal( prompt={"prompt_token_ids": token_ids}, n=1, @@ -185,11 +187,7 @@ async def logprobs(self, token_ids: List[int]) -> torch.Tensor: prompt_logprobs=0, # vLLM return `prompt_logprobs + 1` logrpobs for each token ) return torch.tensor( - [0] - + [ - list(logprob_dict.values())[0].logprob - for logprob_dict in output.prompt_logprobs[1:] - ], + [list(logprob_dict.values())[0].logprob for logprob_dict in output.prompt_logprobs[1:]], dtype=torch.float32, ) @@ -217,14 +215,15 @@ async def convert_messages_to_experience(self, messages: List[dict]) -> Experien self.tokenizer = await self.async_llm.get_tokenizer() if self.chat_template is None: self.chat_template = self.tokenizer.get_chat_template() - token_ids, action_mask = self.action_mask_method( + token_ids, action_mask, prompt_length = self.action_mask_method( self.tokenizer, messages, self.chat_template ) logprobs = await self.logprobs(token_ids=token_ids.tolist()) return Experience( tokens=token_ids, logprobs=logprobs, - action_mask=action_mask, + prompt_length=prompt_length, + action_mask=action_mask[prompt_length:], # Exclude the prompt tokens ) def shutdown(self): diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index 608db16abb..18e3ad768d 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -42,6 +42,7 @@ def __init__( self.model_wrapper = ModelWrapper( model, config.explorer.rollout_model.engine_type, + enable_history=config.explorer.rollout_model.enable_history, ) self.auxiliary_models = [] if auxiliary_models is not None: diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 1378449cf2..53d0ce94be 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -7,12 +7,18 @@ import os import traceback from abc import ABC, abstractmethod +from typing import Dict, List, Tuple +import pandas as pd import ray +from trinity.algorithm import SAMPLE_STRATEGY +from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config from trinity.common.constants import RunningStatus, SyncMethod +from trinity.common.experience import 
Experiences from trinity.utils.log import get_logger +from trinity.utils.monitor import MONITOR class Trainer: @@ -23,6 +29,17 @@ def __init__(self, config: Config) -> None: self.logger = get_logger(__name__) self.engine = get_trainer_wrapper(config) self.explorer_ref = None + self.monitor = MONITOR.get(config.monitor.monitor_type)( + project=config.project, + name=config.name, + role=config.trainer.name, + config=config, + ) + self._sample_exps_to_log = [] + self.sample_strategy = SAMPLE_STRATEGY.get(config.algorithm.sample_strategy)( + buffer_config=config.buffer, + **config.algorithm.sample_strategy_args, + ) def prepare(self) -> None: """Prepare the trainer.""" @@ -49,7 +66,26 @@ def train_step(self) -> bool: Returns: bool: Whether to continue training. """ - return self.engine.train_step() + try: + batch, sample_metrics, repr_samples = self.sample_strategy.sample( + self.train_step_num + 1 + ) + except StopIteration: + self.logger.info("No more samples to train. Stopping training.") + if ( + self.config.trainer.save_interval == 0 + or self.train_step_num % self.config.trainer.save_interval != 0 + ): + self.logger.info(f"Saving at step {self.train_step_num}.") + self.engine.save_checkpoint() + self.logger.info(f"Saved at step {self.train_step_num}.") + return False + continue_run, metrics = self.engine.train_step(batch) + prefix_metrics(sample_metrics, "sample", metrics) + self.monitor.log(data=metrics, step=self.train_step_num) + if self.config.trainer.enable_preview: + self._log_experiences(repr_samples) + return continue_run def need_sync(self) -> bool: """Whether to sync the model weight.""" @@ -73,14 +109,26 @@ def sync_weight(self) -> None: f"Trainer synchronizing weights at step {self.engine.train_step_num} end." ) + def _log_experiences(self, samples: List[Dict]) -> None: + self._sample_exps_to_log.extend(samples) + if self.train_step_num % self.config.synchronizer.sync_interval == 0: + self.monitor.log_table( + "rollout_examples", pd.DataFrame(self._sample_exps_to_log), self.train_step_num + ) + self._sample_exps_to_log.clear() + def shutdown(self) -> None: # if checkpoint not saved, save the last checkpoint - step_num = self.engine.train_step_num - path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{step_num}") + path = os.path.join(self.config.checkpoint_job_dir, f"global_step_{self.train_step_num}") if not os.path.isdir(path) or len(os.listdir(path)) == 0: self.engine.save_checkpoint() self.engine.monitor.close() + @property + def train_step_num(self) -> int: + """Get the current training step number.""" + return self.engine.train_step_num + class TrainEngineWrapper(ABC): """A wrapper class to wrap various training engines.""" @@ -95,8 +143,16 @@ def train_step_num(self) -> int: """Get the current training step number.""" @abstractmethod - def train_step(self) -> bool: - """Training.""" + def train_step(self, batch: Experiences) -> Tuple[bool, Dict]: + """Training one step. + + Args: + batch (Experiences): A batch of experiences to train. + + Returns: + bool: Whether to continue training. + Dict: Metrics of the training step. 
+ """ @abstractmethod def save_checkpoint(self) -> None: diff --git a/trinity/trainer/verl/converter.py b/trinity/trainer/verl/converter.py new file mode 100644 index 0000000000..c7b2e92763 --- /dev/null +++ b/trinity/trainer/verl/converter.py @@ -0,0 +1,44 @@ +"""Convert Experiences to verl.DataProto.""" + +import numpy as np +import torch +from verl import DataProto + +from trinity.common.experience import Experiences + + +def to_data_proto(experiences: Experiences) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array([eid.tid for eid in experiences.eids]), + "unique_ids": np.array([eid.uid for eid in experiences.eids]), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks.long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs, # type: ignore + } + ) + if experiences.custom_fields: + for field in experiences.custom_fields: + if hasattr(experiences, field): + batch_dict[field] = getattr(experiences, field) + return DataProto.from_single_dict(batch_dict) diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 2198f4d2d1..a281840593 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -5,10 +5,8 @@ """ import os import sys -from pprint import pprint -from typing import Dict, List +from typing import Dict, Tuple -import pandas as pd import ray import torch from omegaconf import OmegaConf @@ -38,8 +36,8 @@ from trinity.common.config import Config from trinity.common.experience import Experiences from trinity.trainer.trainer import TrainEngineWrapper +from trinity.trainer.verl.converter import to_data_proto from trinity.utils.log import get_logger -from trinity.utils.monitor import MONITOR class _InternalDataLoader: @@ -147,13 +145,6 @@ def __init__( ray_worker_group_cls, ) self.init_workers() - self.monitor = MONITOR.get(global_config.monitor.monitor_type)( - project=config.trainer.project_name, - name=config.trainer.experiment_name, - role=global_config.trainer.name, - config=global_config, - ) - self.reset_experiences_example_table() self.logger = get_logger(__name__) def _validate_config(self): # TODO @@ -273,36 +264,15 @@ def prepare(self): # load checkpoint before doing anything self._load_checkpoint() - # perform validation before training - # currently, we only support validation using the reward_function. 
- if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): - val_metrics = self._validate() - pprint(f"Initial validation metrics: {val_metrics}") - self.monitor.log(data=val_metrics, step=self.global_steps) - if self.config.trainer.get("val_only", False): - return - def _create_dataloader(self, train_dataset, val_dataset, collate_fn, train_sampler): self.train_dataloader = _InternalDataLoader(self.config) # TODO: compute total training steps self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize - def train_step(self) -> bool: # noqa C901 + def train_step(self, batch: Experiences) -> Tuple[bool, Dict]: # noqa C901 self.logger.info(f"Training at step {self.global_steps + 1} started.") + batch = to_data_proto(batch) metrics = {} - try: - batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1) - prefix_metrics(sample_metrics, "sample", metrics) - except StopIteration: - print("No more data to train. Stop training.") - if ( - self.config.trainer.save_freq == 0 - or self.global_steps % self.config.trainer.save_freq != 0 - ): - self.logger.info(f"Saving at step {self.global_steps}.") - self._save_checkpoint() - self.logger.info(f"Saved at step {self.global_steps}.") - return False self.global_steps += 1 self.logger.info(f"Sampling at step {self.global_steps} done.") timing_raw = {} @@ -381,12 +351,6 @@ def train_step(self) -> bool: # noqa C901 compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus) ) - if self.algorithm.use_advantage and self.config.enable_preview: # TODO - self._log_experiences(exp_samples) - - # TODO: make a canonical logger that supports various backend - self.monitor.log(data=metrics, step=self.global_steps) - train_status = self.global_steps < self.total_training_steps if not train_status or self.algorithm_manager.need_save(self.global_steps): if ( @@ -398,40 +362,7 @@ def train_step(self) -> bool: # noqa C901 self._save_checkpoint() self.logger.info(f"Saved at step {self.global_steps}.") self.logger.info(f"Training at step {self.global_steps} finished.") - return train_status - - def _log_single_experience( - self, experiences: Experiences, idx: int, skip_special_tokens: bool - ) -> None: - reward = experiences.rewards[idx] - attn_mask = experiences.attention_masks[idx].bool() - prompt_token = experiences.tokens[idx][: experiences.prompt_length][ - attn_mask[: experiences.prompt_length] - ] - response_token = experiences.tokens[idx][experiences.prompt_length :][ - attn_mask[experiences.prompt_length :] - ] - prompt_text = self.tokenizer.decode(prompt_token, skip_special_tokens=skip_special_tokens) - response_text = self.tokenizer.decode( - response_token, skip_special_tokens=skip_special_tokens - ) - new_row = pd.DataFrame( - { - "step": [self.global_steps], - "reward": [reward], - "prompt": [prompt_text], - "response": [response_text], - } - ) - self.sample_exps_to_log = pd.concat([self.sample_exps_to_log, new_row], ignore_index=True) - - def _log_experiences(self, samples: List[Dict]) -> None: - self.sample_exps_to_log.extend(samples) - if self.global_steps % self.config.trainer.sync_freq == 0: - self.monitor.log_table( - "rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps - ) - self.reset_experiences_example_table() + return train_status, metrics def save_checkpoint(self) -> None: self._save_checkpoint()
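For reference, a minimal sketch of how the `CustomField` / `custom_fields` path added above is exercised. All values below are toy placeholders (token ids, rewards, logprobs are made up); it only assumes the `Experience` and `Experiences.gather_experiences` signatures shown in this diff.

```python
import torch

from trinity.common.experience import CustomField, Experience, Experiences

# Two toy experiences; after this change, logprobs/action_mask cover only
# the response tokens ([resp_length]), not the full sequence.
usual = Experience(
    tokens=torch.tensor([1, 2, 3, 4], dtype=torch.int32),
    prompt_length=2,
    reward=1.0,
    logprobs=torch.tensor([0.1, 0.2]),
    info={"is_expert": False},
)
expert = Experience(
    tokens=torch.tensor([5, 6, 7], dtype=torch.int32),
    prompt_length=1,
    reward=0.0,
    logprobs=torch.zeros(2, dtype=torch.float32),
    info={"is_expert": True},
)

# `custom_fields` copies Experience.info["is_expert"] from each experience
# into a batched tensor exposed as `batch.expert_mask`, which is the field
# MIXPolicyLossFn uses to split GRPO (usual) and SFT (expert) samples.
batch = Experiences.gather_experiences(
    experiences=[usual, expert],
    pad_token_id=0,
    custom_fields=[
        CustomField(
            source_field="is_expert",
            destination_field="expert_mask",
            data_type=torch.bool,
        )
    ],
)
print(batch.expert_mask)  # expected: tensor([False,  True])
```

Because `to_data_proto` now forwards every name listed in `Experiences.custom_fields`, the same `expert_mask` tensor ends up in the verl `DataProto` batch without any strategy-specific converter.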