diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
new file mode 100644
index 0000000000..9dadc76b40
--- /dev/null
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -0,0 +1,303 @@
+# Integrate a New Algorithm
+
+
+This guide introduces how to integrate a new algorithm into Trinity-RFT.
+As an example, we incorporate "expert" data generated by a more advanced LLM and propose an algorithm named MIX, which optimizes the following policy objective:
+
+$$
+\mathcal{J}_{\text{Mix}}(\theta) =
+\mathcal{J}_{\text{GRPO}}(\theta)
++
+\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'}
+\left[
+  \frac{1}{T'_b} \sum_{t=1}^{T'_b}
+  \log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b,<t})
+\right]}_{\text{SFT objective on expert data}}
+$$
+
+Here $\mathcal{J}_{\text{GRPO}}(\theta)$ is the standard GRPO objective computed on the usual (on-policy) experiences, $B'$ is the batch size of expert experiences, $q'_b$ and $o'_b$ are the prompt and response of the $b$-th expert experience, $T'_b$ is the length of $o'_b$, and $\mu$ is a weighting factor that balances the two terms.
+
+The integration takes four steps: define the algorithm, define the sampling strategy, define the policy loss function, and run the experiment.
+
+
+## Step 1: Define the Algorithm
+
+We register a `MIXAlgorithm` class in `trinity/algorithm/algorithm.py`, which declares the components used by the algorithm and their default configuration.
+
+```python
+@ALGORITHM_TYPE.register_module("mix")
+class MIXAlgorithm(AlgorithmType):
+    """MIX algorithm."""
+
+    use_critic: bool = False
+    use_reference: bool = True
+    use_advantage: bool = True
+    use_rollout: bool = True
+    can_balance_batch: bool = True
+    schema: type = ExperienceModel
+
+    @classmethod
+    def get_default_config(cls) -> Dict:
+        return {
+            "repeat_times": 8,
+            "policy_loss_fn": "mix",
+            "advantage_fn": "grpo",
+            "sample_strategy": "mix",
+        }
+```
+
+
+## Step 2: Define the Sampling Strategy
+
+In each training step, we need to read two kinds of experiences: usual experiences and expert experiences. For this purpose, we define a new experience sampling strategy named `MixSampleStrategy`.
+
+
+```python
+@SAMPLE_STRATEGY.register_module("mix")
+class MixSampleStrategy(SampleStrategy):
+    """The default sample strategy."""
+
+    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
+        super().__init__(buffer_config, trainer_type)
+        self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5)
+        tot_batch_size = buffer_config.read_batch_size
+        expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size)
+
+        # experience buffer
+        usual_buffer_config = copy.deepcopy(buffer_config)
+        usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size
+        self.usual_exp_buffer = get_buffer_reader(
+            buffer_config.trainer_input.experience_buffer, usual_buffer_config  # type: ignore
+        )
+
+        if buffer_config.trainer_input.sft_warmup_dataset is None:
+            raise ValueError(
+                "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm"
+            )
+
+        # expert experience buffer
+        expert_buffer_config = copy.deepcopy(buffer_config)
+        expert_buffer_config.read_batch_size = expert_batch_size
+        self.expert_exp_buffer = get_buffer_reader(
+            buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config
+        )
+
+    def sample(self, step: int) -> Tuple[Any, Dict, List]:
+        metrics = {}
+        with Timer(metrics, "read_time"):
+            usual_exp_list = self.usual_exp_buffer.read()
+            for exp in usual_exp_list:
+                if exp.info is None:
+                    exp.info = {}
+                exp.info["is_expert"] = False
+
+            expert_exp_list = self.expert_exp_buffer.read()
+            for exp in expert_exp_list:
+                exp.reward = 0.0
+                exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32)
+                if exp.info is None:
+                    exp.info = {}
+                exp.info["is_expert"] = True
+
+            exp_list = usual_exp_list + expert_exp_list
+            repr_samples = representative_sample(exp_list)
+
+        is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool)
+
+        with Timer(metrics, "gather_time"):
+            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
+
+        if self.trainer_type == "verl":
+            with Timer(metrics, "convert_time"):
+                data = to_data_proto_mix(exps, is_expert_mask)
+            return data, metrics, repr_samples
+        else:
+            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
+```
+
+We also need to add an `is_expert_mask` field when converting the gathered experiences to `DataProto`, so that the loss function can tell the two kinds of experiences apart.
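+
+For intuition, the mask is simply a boolean tensor with one entry per experience, following the concatenation order `usual_exp_list + expert_exp_list`. A minimal sketch (the batch sizes below are made up for illustration; the real split is controlled by `expert_data_ratio`):
+
+```python
+import torch
+
+# Suppose one step reads 6 usual experiences followed by 2 expert experiences.
+is_expert_mask = torch.tensor([False] * 6 + [True] * 2, dtype=torch.bool)
+print(is_expert_mask.sum().item())  # 2 expert experiences in this batch
+```
+
+The diff below attaches this mask to the `DataProto` under the `is_expert_mask` key.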
+
+```diff
++ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto:
+     attention_mask = experiences.attention_masks
+     cumsum = torch.cumsum(attention_mask, dim=-1)
+     position_ids = torch.clip(cumsum - 1, 0, None).long()
+     batch_dict = {
+         "uid": np.array(experiences.run_ids),
+         "position_ids": position_ids,
+         "input_ids": experiences.tokens.long(),
+         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
+         "attention_mask": attention_mask.long(),
+         "response_mask": (
+             experiences.action_masks[:, experiences.prompt_length :].long()
+             if hasattr(experiences, "action_masks") and experiences.action_masks is not None
+             else attention_mask[:, experiences.prompt_length :].long()
+         ),
++        "is_expert_mask": is_expert_mask,
+     }
+     if experiences.rewards is not None:
+         token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype)
+         eos_mask_idx = cumsum.argmax(dim=-1)
+         token_level_rewards[
+             torch.arange(experiences.batch_size), eos_mask_idx
+         ] = experiences.rewards
+         token_level_rewards = token_level_rewards[:, experiences.prompt_length :]
+         batch_dict.update(
+             {
+                 "token_level_scores": token_level_rewards,
+                 "old_log_probs": experiences.logprobs[:, experiences.prompt_length :],  # type: ignore
+             }
+         )
+     return DataProto.from_single_dict(batch_dict)
+```
+
+
+## Step 3: Define the Policy Loss Function
+
+We define a `MIXPolicyLossFn` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes a weighted sum of two loss terms: a GRPO loss on the usual experiences and an SFT loss on the expert experiences.
+
+```python
+@POLICY_LOSS_FN.register_module("mix")
+class MIXPolicyLossFn(PolicyLossFn):
+    def __init__(
+        self,
+        mu: float = 0.1,
+        clip_range: Optional[float] = None,
+        clip_range_low: Optional[float] = None,
+        clip_range_high: Optional[float] = None,
+        use_dynamic_bsz: Optional[bool] = None,
+        repeat_times: Optional[int] = None,
+        ppo_mini_batch_size: Optional[int] = None,
+        ppo_micro_batch_size_per_gpu: Optional[int] = None,
+        ngpus_trainer: Optional[int] = None,
+        read_batch_size_usual: Optional[int] = None,
+        read_batch_size_expert: Optional[int] = None,
+        use_token_level_loss_in_sft: bool = True,
+    ) -> None:
+        self.mu = mu
+        self.use_dynamic_bsz = use_dynamic_bsz
+        self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer  # type: ignore
+        self.gradient_accumulation = (
+            ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu  # type: ignore
+        )
+        self.read_batch_size_usual = read_batch_size_usual
+        self.read_batch_size_expert = read_batch_size_expert
+        self.grpo_loss_fn = PPOPolicyLossFn(
+            clip_range=clip_range,
+            clip_range_low=clip_range_low,
+            clip_range_high=clip_range_high,
+        )
+        self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft)
+
+    def __call__(  # type: ignore
+        self,
+        logprob: torch.Tensor,
+        old_logprob: torch.Tensor,
+        action_mask: torch.Tensor,
+        advantages: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict]:
+        is_expert_mask = kwargs.get("is_expert_mask", None)
+        if is_expert_mask is None:
+            raise ValueError("is_expert_mask is required in MIX")
+        assert (
+            len(is_expert_mask) == logprob.shape[0]
+        ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}"
+
+        n_usual_exp = torch.sum(~is_expert_mask).item()
+        n_expert_exp = torch.sum(is_expert_mask).item()
+
+        if self.use_dynamic_bsz:
+            per_micro_batch_weight_usual = self.experience_per_gpu / (
+                logprob.shape[0] * self.read_batch_size_usual
+            )
+            per_micro_batch_weight_expert = self.experience_per_gpu / (
+                logprob.shape[0] * self.read_batch_size_expert
+            )
+        else:
+            per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual  # type: ignore
+            per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert  # type: ignore
+
+        if n_usual_exp > 0:
+            grpo_loss, grpo_metrics = self.grpo_loss_fn(
+                logprob[~is_expert_mask],
+                old_logprob[~is_expert_mask],
+                action_mask[~is_expert_mask],
+                advantages[~is_expert_mask],
+                **kwargs,
+            )
+            grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual
+            grpo_metrics = {
+                k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()
+            }
+        else:
+            grpo_loss = torch.tensor(0.0, device=logprob.device)
+            grpo_metrics = {}
+
+        # SFT Loss (expert)
+        if n_expert_exp > 0:
+            sft_loss, sft_metrics = self.sft_loss_fn(
+                logprob[is_expert_mask],
+                action_mask[is_expert_mask],
+            )
+            sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert
+            sft_metrics = {
+                k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()
+            }
+        else:
+            sft_loss = torch.tensor(0.0, device=logprob.device)
+            sft_metrics = {}
+
+        loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss
+
+        metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
+        metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
+        metrics.update({"loss": loss.item()})
+
+        return loss, metrics
+
+    @classmethod
+    def default_args(cls) -> Dict:
+        return {
+            "mu": 0.1,
+            "clip_range": 0.2,
+        }
+
+    @property
+    def select_keys(self) -> List[str]:
+        return ["old_logprob", "action_mask", "advantages", "is_expert_mask"]
+```
+
+## Step 4: Run the Experiment
+
+With the classes and functions defined above, we can run the experiment without modifying any other part of the training process.
+The example below shows some important configurations, including the weighting factor $\mu$ (`algorithm.policy_loss_fn_args['mu']`) and the batch size of expert experiences $B'$, which equals the product of `buffer.batch_size`, `algorithm.sample_strategy_args['expert_data_ratio']` and `algorithm.repeat_times`.
+For the full configuration, please refer to [`mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/mix_math.yaml) and [`train_mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/train_mix_math.yaml).
+
+```yaml
+algorithm:
+  algorithm_type: mix
+  repeat_times: 8
+  sample_strategy_args:
+    expert_data_ratio: 0.25
+  policy_loss_fn_args:
+    mu: 0.1
+    clip_range: 0.2
+    use_token_level_loss_in_sft: False
+    use_dynamic_bsz: False
+    repeat_times: 8
+    ppo_mini_batch_size: 32
+    ppo_micro_batch_size_per_gpu: 4
+    ngpus_trainer: 4
+    read_batch_size_expert: 64
+    read_batch_size_usual: 192
+```
diff --git a/examples/mix_math/README.md b/examples/mix_math/README.md
new file mode 100644
index 0000000000..8e84f233bc
--- /dev/null
+++ b/examples/mix_math/README.md
@@ -0,0 +1,7 @@
+# Example: MIX on the MATH dataset
+
+This example shows how to use the new MIX algorithm on the MATH dataset.
+
+For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md).
+
+The config files are located in [`mix_math.yaml`](mix_math.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml).
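+
+As a quick sanity check of the batch split described in the tutorial, the sketch below reproduces how `MixSampleStrategy` divides one read batch between the expert and usual buffers. The `read_batch_size` value of 256 is an assumption chosen only to illustrate the arithmetic; it is not set explicitly in the config files above.
+
+```python
+from math import ceil
+
+read_batch_size = 256        # total experiences read per training step (assumed)
+expert_data_ratio = 0.25     # algorithm.sample_strategy_args['expert_data_ratio']
+
+expert_batch_size = ceil(expert_data_ratio * read_batch_size)  # 64 expert experiences
+usual_batch_size = read_batch_size - expert_batch_size         # 192 usual experiences
+print(expert_batch_size, usual_batch_size)
+```
+
+These two numbers match the `read_batch_size_expert: 64` and `read_batch_size_usual: 192` values passed to the policy loss function in `mix_math.yaml`.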
diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml new file mode 100644 index 0000000000..339d8df394 --- /dev/null +++ b/examples/mix_math/mix_math.yaml @@ -0,0 +1,88 @@ +project: "mix_math" +name: "expert0.25_mu0.1" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: mix + repeat_times: 8 + sample_strategy_args: + expert_data_ratio: 0.25 + policy_loss_fn_args: + mu: 0.1 + clip_range: 0.2 + use_token_level_loss_in_sft: False + use_dynamic_bsz: False + repeat_times: 8 + ppo_mini_batch_size: 32 + ppo_micro_batch_size_per_gpu: 4 + ngpus_trainer: 4 + read_batch_size_expert: 64 + read_batch_size_usual: 192 +model: + model_path: /PATH/TO/MODEL/ + max_prompt_tokens: 1024 + max_response_tokens: 10240 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 40 + explore_batch_size: 36 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: math_train + storage_type: file + path: /PATH/TO/DATASET/ + split: 'train' + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: + - name: math_eval + storage_type: file + path: /PATH/TO/DATASET/ + split: 'test' + format: + prompt_key: 'problem' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + path: /PATH/TO/BUFFER/ + sft_warmup_dataset: + name: math_sft + storage_type: file + algorithm_type: sft + path: /PATH/TO/EXPERT_DATA/ + split: 'train' + format: + prompt_type: messages + messages_key: 'messages' +explorer: + eval_interval: 10 + runner_num: 16 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/mix_math/train_math.yaml' + save_interval: 50 +monitor: + monitor_type: wandb diff --git a/examples/mix_math/train_mix_math.yaml b/examples/mix_math/train_mix_math.yaml new file mode 100644 index 0000000000..7b14a87fad --- /dev/null +++ b/examples/mix_math/train_mix_math.yaml @@ -0,0 +1,70 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 128 + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 25600 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: True # True for GRPO + kl_loss_coef: 0.0001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + # --- below: opmd --- + tau: 0.000 # strength of regularization w.r.t. 
old / ref policy + opmd_baseline: mean # mean / logavgexp, applicable to opmd + use_uid: False # True / False, applicable to pairwise_opmd + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +custom_reward_function: + path: null + name: compute_score + +algorithm: + gamma: 1.0 + lam: 1.0 + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.0001 + +trainer: + balance_batch: True + # auto: find the last ckpt to resume. If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 88b9b946b7..6f0a2d19a7 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -180,3 +180,24 @@ def check_config(cls, config: Config) -> None: logger.warning( "DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2." ) # no need to warn + + +@ALGORITHM_TYPE.register_module("mix") +class MIXAlgorithm(AlgorithmType): + """MIX algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + use_rollout: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def get_default_config(cls) -> Dict: + return { + "repeat_times": 8, + "policy_loss_fn": "mix", + "advantage_fn": "grpo", + "sample_strategy": "mix", + } diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 66dce16cab..705fb2525a 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -1,4 +1,5 @@ from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn +from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn @@ -11,4 +12,5 @@ "OPMDPolicyLossFn", "DPOLossFn", "SFTLossFn", + "MIXPolicyLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py new file mode 100644 index 0000000000..84679b0ea8 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -0,0 +1,133 @@ +"""Mix policy loss function.""" + +from typing import Dict, List, Optional, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn + + +@POLICY_LOSS_FN.register_module("mix") +class MIXPolicyLossFn(PolicyLossFn): + """Implements a mixed policy loss combining GRPO and SFT losses. + + This loss function applies different loss components to data based on whether + it comes from an expert or not, as indicated by `is_expert_mask`. 
It combines: + - GRPO loss (self.grpo_loss_fn) for non-expert data + - SFT loss (self.sft_loss_fn) for expert data + - Weighting parameter `mu` + + The per-sample weights are normalized using either `experience_per_gpu` or + `gradient_accumulation`, depending on whether dynamic batch sizing is enabled, + to ensure consistent weighting across different batches of the same type experiences. + """ + + def __init__( + self, + mu: float = 0.1, + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + use_dynamic_bsz: Optional[bool] = None, + repeat_times: Optional[int] = None, + ppo_mini_batch_size: Optional[int] = None, + ppo_micro_batch_size_per_gpu: Optional[int] = None, + ngpus_trainer: Optional[int] = None, + read_batch_size_usual: Optional[int] = None, + read_batch_size_expert: Optional[int] = None, + use_token_level_loss_in_sft: bool = True, + ) -> None: + self.mu = mu + self.use_dynamic_bsz = use_dynamic_bsz + self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore + self.gradient_accumulation = ( + ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore + ) + self.read_batch_size_usual = read_batch_size_usual + self.read_batch_size_expert = read_batch_size_expert + self.grpo_loss_fn = PPOPolicyLossFn( + clip_range=clip_range, + clip_range_low=clip_range_low, + clip_range_high=clip_range_high, + ) + self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft) + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + is_expert_mask = kwargs.get("is_expert_mask", None) + if is_expert_mask is None: + raise ValueError("is_expert_mask is required in MIX") + assert ( + len(is_expert_mask) == logprob.shape[0] + ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" + + n_usual_exp = torch.sum(~is_expert_mask).item() + n_expert_exp = torch.sum(is_expert_mask).item() + + if self.use_dynamic_bsz: + per_micro_batch_weight_usual = self.experience_per_gpu / ( + logprob.shape[0] * self.read_batch_size_usual + ) + per_micro_batch_weight_expert = self.experience_per_gpu / ( + logprob.shape[0] * self.read_batch_size_expert + ) + else: + per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore + + if n_usual_exp > 0: + grpo_loss, grpo_metrics = self.grpo_loss_fn( + logprob[~is_expert_mask], + old_logprob[~is_expert_mask], + action_mask[~is_expert_mask], + advantages[~is_expert_mask], + **kwargs, + ) + grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual + grpo_metrics = { + k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items() + } + else: + grpo_loss = torch.tensor(0.0, device=logprob.device) + grpo_metrics = {} + + # SFT Loss (expert) + if n_expert_exp > 0: + sft_loss, sft_metrics = self.sft_loss_fn( + logprob[is_expert_mask], + action_mask[is_expert_mask], + ) + sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert + sft_metrics = { + k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items() + } + else: + sft_loss = torch.tensor(0.0, device=logprob.device) + sft_metrics = {} + + loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss + + metrics = {f"usual/{k}": v for k, v in 
grpo_metrics.items()} + metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({"loss": loss.item()}) + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "mu": 0.1, + "clip_range": 0.2, + } + + @property + def select_keys(self) -> List[str]: + return ["old_logprob", "action_mask", "advantages", "is_expert_mask"] diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py index 60f2e268ae..cd4b9e0d66 100644 --- a/trinity/algorithm/sample_strategy/__init__.py +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -1,3 +1,4 @@ +from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy from trinity.algorithm.sample_strategy.sample_strategy import ( SAMPLE_STRATEGY, DefaultSampleStrategy, @@ -10,4 +11,5 @@ "SampleStrategy", "DefaultSampleStrategy", "WarmupSampleStrategy", + "MixSampleStrategy", ] diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py new file mode 100644 index 0000000000..acdd340b24 --- /dev/null +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -0,0 +1,118 @@ +import copy +from math import ceil +from typing import Any, Dict, List, Tuple + +import numpy as np +import torch +from verl.trainer.ppo.ray_trainer import DataProto + +from trinity.algorithm.sample_strategy.sample_strategy import ( + SAMPLE_STRATEGY, + SampleStrategy, +) +from trinity.algorithm.sample_strategy.utils import representative_sample +from trinity.buffer import get_buffer_reader +from trinity.common.config import BufferConfig +from trinity.common.experience import Experiences +from trinity.utils.timer import Timer + + +@SAMPLE_STRATEGY.register_module("mix") +class MixSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) + tot_batch_size = buffer_config.read_batch_size + expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + + # experience buffer + usual_buffer_config = copy.deepcopy(buffer_config) + usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size + self.usual_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore + ) + + if buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError( + "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" + ) + + # expert experience buffer + expert_buffer_config = copy.deepcopy(buffer_config) + expert_buffer_config.read_batch_size = expert_batch_size + self.expert_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config + ) + + def sample(self, step: int) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + usual_exp_list = self.usual_exp_buffer.read() + for exp in usual_exp_list: + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = False + + expert_exp_list = self.expert_exp_buffer.read() + for exp in expert_exp_list: + exp.reward = 0.0 + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = True + + exp_list = usual_exp_list + expert_exp_list + repr_samples = representative_sample(exp_list) + + is_expert_mask = torch.tensor([exp.info["is_expert"] 
for exp in exp_list], dtype=torch.bool) + + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto_mix(exps, is_expert_mask) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + @classmethod + def get_default_config(cls) -> Dict: + return { + "expert_data_ratio": 0.5, + } + + +def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + "is_expert_mask": is_expert_mask, + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 644fe9a8f5..74fa419db9 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -105,6 +105,8 @@ class Rollout: val_kwargs: _ValKwargs = field(default_factory=_ValKwargs) temperature: float = 1.0 n: int = 1 # > 1 for grpo + log_prob_micro_batch_size: Optional[int] = None + log_prob_micro_batch_size_per_gpu: int = 1 @dataclass diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 616234d0d6..9c13c528d7 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -297,7 +297,12 @@ def update_policy(self, data: DataProto): # noqa: C901 } select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()} for trinity_key in self.policy_loss_fn.select_keys: - verl_key = select_keys_trinity2verl[trinity_key] + if trinity_key in select_keys_trinity2verl: + verl_key = select_keys_trinity2verl[trinity_key] + else: + verl_key = trinity_key + select_keys_verl2trinity.update({verl_key: trinity_key}) + select_keys_trinity2verl.update({trinity_key: verl_key}) select_keys.append(verl_key) if not isinstance(self.kl_loss_fn, DummyKLFn): select_keys.append("ref_log_prob")
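The `dp_actor.py` change above lets a non-standard select key such as `is_expert_mask` pass through the trinity-to-verl key mapping unchanged, so the mask produced by `MixSampleStrategy` reaches `MIXPolicyLossFn`. To see how the pieces fit together, here is a self-contained toy sketch of the masking-and-mixing pattern (plain PyTorch, not Trinity-RFT code; it omits the per-micro-batch weighting and the metrics dictionary of the real implementation): split the micro-batch with `is_expert_mask`, compute a clipped policy-gradient loss on the usual part, a token-level negative log-likelihood on the expert part, and combine the two with the weight `mu`.

```python
import torch

def toy_mix_loss(logprob, old_logprob, advantages, action_mask, is_expert_mask,
                 mu=0.1, clip_range=0.2):
    """Toy MIX objective: (1 - mu) * clipped PG loss on usual data
    + mu * token-level SFT loss on expert data. All tensors are [batch, seq]."""
    usual, expert = ~is_expert_mask, is_expert_mask

    # PPO/GRPO-style clipped surrogate on the usual (on-policy) experiences.
    ratio = torch.exp(logprob[usual] - old_logprob[usual])
    adv = advantages[usual]
    surrogate = -torch.min(
        ratio * adv,
        torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * adv,
    )
    mask_u = action_mask[usual]
    pg_loss = (surrogate * mask_u).sum() / mask_u.sum().clamp(min=1)

    # SFT (negative log-likelihood) on the expert experiences.
    mask_e = action_mask[expert]
    sft_loss = -(logprob[expert] * mask_e).sum() / mask_e.sum().clamp(min=1)

    return (1 - mu) * pg_loss + mu * sft_loss

# Tiny fabricated batch: 3 usual rollouts and 1 expert trajectory, 5 tokens each.
B, T = 4, 5
logprob = torch.randn(B, T)
old_logprob = logprob + 0.05 * torch.randn(B, T)
advantages = torch.randn(B, T)
action_mask = torch.ones(B, T)
is_expert_mask = torch.tensor([False, False, False, True])
print(toy_mix_loss(logprob, old_logprob, advantages, action_mask, is_expert_mask))
```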