From a1bd5193e4d2a3b2728a383123fe54163841fcbc Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 13 Jun 2025 09:59:21 +0800 Subject: [PATCH 1/9] prepare for adding mix algorithm --- .../source/tutorial/example_mix_algo.md | 311 ++++++++++++++++++ examples/grpo_math/math_mix.yaml | 76 +++++ examples/grpo_math/train_math.yaml | 4 +- trinity/algorithm/advantage_fn/__init__.py | 2 + .../algorithm/advantage_fn/mix_advantage.py | 52 +++ trinity/algorithm/algorithm.py | 26 ++ trinity/algorithm/policy_loss_fn/__init__.py | 2 + .../policy_loss_fn/mix_policy_loss.py | 116 +++++++ trinity/algorithm/sample_strategy/__init__.py | 2 + .../sample_strategy/mix_sample_strategy.py | 111 +++++++ trinity/common/config.py | 2 + trinity/common/verl_config.py | 18 + trinity/trainer/verl/dp_actor.py | 3 + 13 files changed, 723 insertions(+), 2 deletions(-) create mode 100644 docs/sphinx_doc/source/tutorial/example_mix_algo.md create mode 100644 examples/grpo_math/math_mix.yaml create mode 100644 trinity/algorithm/advantage_fn/mix_advantage.py create mode 100644 trinity/algorithm/policy_loss_fn/mix_policy_loss.py create mode 100644 trinity/algorithm/sample_strategy/mix_sample_strategy.py diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md new file mode 100644 index 0000000000..2c17b79f77 --- /dev/null +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -0,0 +1,311 @@ +# Integrate An New Algorithm + + +This guide introduces how to integrate a new algorithm to Trinity-RFT. +As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX , which optimizes the following policy objective: + +$$ +\mathcal{J}_{\text{Mix}}(\theta) = +\mathcal{J}_{\text{GRPO}}(\theta) ++ +\mu \cdot \underbrace{\frac{1}{B'} \sum_{b=1}^{B'} +\left[ + \frac{1}{T'_b} \sum_{t=1}^{T'_b} + \log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b, Experiences: + return Experiences.gather_experiences(exps, pad_token_id) + + @classmethod + def get_default_config(cls) -> Dict: + return { + "repeat_times": 8, + "policy_loss_fn": "mix", + "mu": 0.1, + } +``` + +We also define some necessary configuration parameters for later use, including the weighting factor $\mu$ and the batch size of expert experiences $B'$, calculated by the product of `expert_data_ratio` and `batch_size`. + + +```python +class AlgorithmConfig: + """Config for algorithm.""" + ... + mu: float = 0.1 + expert_data_ratio: float = 0.5 +``` + +## Step 2: Define the Sampling Strategy + +We need to read two kinds of experiences: usual experiences and expert experiences in each step. For this purpose, we define a new experience sampling strategy named `MixSampleStrategy`. 
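Before looking at the full class, here is a minimal standalone sketch of how one training batch is split between the two buffers. The helper `split_batch` is hypothetical and exists only for illustration; the numbers are made up. Because of the `ceil`, any positive `expert_data_ratio` yields at least one expert sample per step.

```python
from math import ceil


def split_batch(batch_size: int, expert_data_ratio: float) -> tuple:
    """Return (usual_batch_size, expert_batch_size) for one training step."""
    expert_batch_size = ceil(expert_data_ratio * batch_size)
    return batch_size - expert_batch_size, expert_batch_size


print(split_batch(32, 0.5))   # -> (16, 16)
print(split_batch(32, 0.25))  # -> (24, 8)
```
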
+ + +```python +@SAMPLE_STRATEGY.register_module("mix") +class MixSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.expert_data_ratio = buffer_config.expert_data_ratio + self.usual_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore + ) + if buffer_config.trainer_input.expert_dataset is None: + raise ValueError("`buffer_config.trainer_input.expert_dataset` is required in MIX algorithm") + + self.expert_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.expert_dataset, buffer_config + ) + tot_batch_size = buffer_config.batch_size + self.expert_exp_buffer.read_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + self.usual_exp_buffer.read_batch_size = tot_batch_size - self.expert_exp_buffer.read_batch_size + + + def sample(self, step: int, **kwargs) -> DataProto: + usual_exp_list = self.exp_buffer.read() + for exp in usual_exp_list: + exp.info["is_expert"] = False + + expert_exp_list = self.expert_exp_buffer.read() + for exp in expert_exp_list: + exp.info["is_expert"] = True + + exp_list = usual_exp_list + expert_exp_list + is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) + + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto_mix(exps, is_expert_mask) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + @classmethod + def get_default_config(cls) -> Dict: + return { + "expert_data_ratio": 0.5, + } +``` + +We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type. + +```diff +def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), ++ "is_expert_mask": is_expert_mask, + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) +``` + + + +## Step 3: Define the Avantage Function + +We define a `MIXAdvantageFn` class in `trinity/algorithm/advantage_fn/mix_advantage.py`, which computes the advantage function for only usual experiences. 
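The core trick is plain boolean-mask indexing: expert rows are filtered out before the GRPO advantage computation, and the same mask (converted to NumPy) is applied to the non-tensor fields. A tiny standalone illustration with made-up field names and values:

```python
import numpy as np
import torch

# A made-up batch of four experiences, two of which come from the expert dataset.
is_expert_mask = torch.tensor([False, True, False, True])
rewards = torch.tensor([1.0, 0.0, 0.5, 0.0])
run_ids = np.array(["r0", "r1", "r2", "r3"])

print(rewards[~is_expert_mask])          # tensor([1.0000, 0.5000])
print(run_ids[~is_expert_mask.numpy()])  # ['r0' 'r2']
```
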
+ +```python +@ADVANTAGE_FN.register_module("mix") +class MIXAdvantageFn(GRPOAdvantageFn): + """MIX advantage computation""" + + def __init__( + self, + epsilon: float = 1e-6, + ) -> None: + super().__init__(epsilon) + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + is_expert_mask = exps.batch["is_expert_mask"] + + # Process tensors + tensors = { + k: tensor[~is_expert_mask] for k, tensor in exps.batch.items() + } + + # Process non-tensors + non_tensors = { + k: v[~is_expert_mask.detach().cpu().numpy()] for k, v in exps.non_tensor_batch.items() + } + + # Build new DataProto + exps = DataProto.from_dict( + tensors=tensors, + non_tensors=non_tensors, + meta_info=exps.meta_info + ) + return super().__call__(exps, **kwargs) + + @classmethod + def default_args(cls) -> Dict: + return { + "epsilon": 1e-6, + } +``` + + +## Step 4: Define the Policy Loss Function + +We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively. + +```python + +@POLICY_LOSS_FN.register_module("mix") +class MIXPolicyLossFn(PolicyLossFn): + def __init__( + self, + mu: float = 0.1, + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + use_token_level_loss_in_sft: Optional[bool] = True + ) -> None: + self.mu = mu + self.grpo_loss_fn = PPOPolicyLossFn( + clip_range=clip_range, + clip_range_low=clip_range_low, + clip_range_high=clip_range_high, + ) + self.sft_loss_fn = SFTLossFn( + use_token_level_loss=use_token_level_loss_in_sft + ) + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + is_expert_mask = kwargs.get("is_expert_mask", None) + per_micro_batch_weight = kwargs.get("per_micro_batch_weight", None) + if is_expert_mask is None: + raise ValueError("is_expert_mask is required in MIX") + assert len(is_expert_mask) == logprob.shape[0], f"{len(is_expert_mask)=} != {logprob.shape[0]=}" + + n_usual_exp = torch.sum(~is_expert_mask) + n_expert_exp = torch.sum(is_expert_mask) + + if n_usual_exp > 0: + grpo_loss, grpo_metrics = self.grpo_loss_fn( + logprob[~is_expert_mask], + old_logprob[~is_expert_mask], + action_mask[~is_expert_mask], + advantages[~is_expert_mask], + **kwargs, + ) + grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight + grpo_metrics = {k: v * n_usual_exp * per_micro_batch_weight for k, v in grpo_metrics.items()} + else: + grpo_loss = torch.tensor(0.0, device=logprob.device) + grpo_metrics = {} + + # SFT Loss (expert) + if n_expert_exp > 0: + sft_loss, sft_metrics = self.sft_loss_fn( + logprob[is_expert_mask], + action_mask[is_expert_mask], + ) + sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight + sft_metrics = {k: v * n_expert_exp * per_micro_batch_weight for k, v in sft_metrics.items()} + else: + sft_loss = torch.tensor(0.0, device=logprob.device) + sft_metrics = {} + + loss = grpo_loss + self.mu * sft_loss + + metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} + sft_metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "mu": 0.1, + "clip_range": 0.2, + } + + @property + def select_keys(self) -> List[str]: + return [ + "old_logprob", + "action_mask", + "advantages", + ] +``` + +## Step 5: Run the 
Experiment + +With the above newly-defined classes, we can run the experiments without modifying other process. +An example showing some important configurations is shown below. + +```yaml +TODO +``` \ No newline at end of file diff --git a/examples/grpo_math/math_mix.yaml b/examples/grpo_math/math_mix.yaml new file mode 100644 index 0000000000..8a34a1a235 --- /dev/null +++ b/examples/grpo_math/math_mix.yaml @@ -0,0 +1,76 @@ +project: "rft_sft_mixed" +name: "test" +checkpoint_root_dir: /mnt/yuchang/checkpoints/ +algorithm: + algorithm_type: mix + mu: 0.5 # NEW + repeat_times: 8 +model: + model_path: /mnt/checkpoint/qwen25/Qwen2.5-1.5B-Instruct + max_prompt_tokens: 1024 + max_response_tokens: 16392 +cluster: + node_num: 1 + gpu_per_node: 4 +buffer: + total_epochs: 1 + batch_size: 32 + expert_data_ratio: 0.5 # NEW + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: openr1 + storage_type: file + path: /mnt/yuchang/datasets/openr1_data/openr1_data_filtered_int + split: 'train' + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: + - name: openr1 + storage_type: file + path: /mnt/yuchang/datasets/openr1_data/openr1_data_filtered_int + split: 'test' + format: + prompt_key: 'problem' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: openr1_buffer + storage_type: queue + path: 'sqlite:////mnt/yuchang/checkpoints/${project}/${name}/openr1.db' + sft_warmup_dataset: + name: openr1_sft + storage_type: file + algorithm_type: sft + path: /mnt/yuchang/datasets/openr1_data_sft + split: 'train' + format: + prompt_type: messages + messages_key: 'messages' +explorer: + eval_interval: 10 + runner_num: 32 + rollout_model: + engine_type: vllm_async + engine_num: 2 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 +trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/grpo_math/train_math.yaml' + save_interval: 50 +monitor: + monitor_type: wandb \ No newline at end of file diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml index 78bcb862c6..7b14a87fad 100644 --- a/examples/grpo_math/train_math.yaml +++ b/examples/grpo_math/train_math.yaml @@ -10,7 +10,7 @@ actor_rollout_ref: ppo_mini_batch_size: 128 ppo_micro_batch_size_per_gpu: 4 use_dynamic_bsz: True # False - ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + ppo_max_token_len_per_gpu: 25600 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 clip_ratio: 0.2 entropy_coeff: 0.001 @@ -21,7 +21,7 @@ actor_rollout_ref: shuffle: False ulysses_sequence_parallel_size: 1 # sp size optim: - lr: 5e-7 + lr: 1e-6 lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime # min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 7bcf682e4b..7e6b076843 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -7,6 +7,7 @@ ) from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn +from trinity.algorithm.advantage_fn.mix_advantage import MIXAdvantageFn __all__ = [ "ADVANTAGE_FN", @@ -17,4 +18,5 @@ "REMAXAdvantageFn", "RLOOAdvantageFn", "OPMDAdvantageFn", + "MIXAdvantageFn", ] diff --git a/trinity/algorithm/advantage_fn/mix_advantage.py b/trinity/algorithm/advantage_fn/mix_advantage.py new file mode 100644 index 0000000000..91fcff587f --- /dev/null +++ b/trinity/algorithm/advantage_fn/mix_advantage.py @@ -0,0 +1,52 @@ +"""Mix advantage computation""" + +from typing import Dict, Tuple + +import torch + +from verl import DataProto + +from trinity.algorithm.advantage_fn import ADVANTAGE_FN +from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn + + +@ADVANTAGE_FN.register_module("mix") +class MIXAdvantageFn(GRPOAdvantageFn): + """MIX advantage computation""" + + def __init__( + self, + epsilon: float = 1e-6, + ) -> None: + super().__init__(epsilon) + + def __call__( + self, + exps: DataProto, + **kwargs, + ) -> Tuple[DataProto, Dict]: + is_expert_mask = exps.batch["is_expert_mask"] + + # Process tensors + tensors = { + k: tensor[~is_expert_mask] for k, tensor in exps.batch.items() + } + + # Process non-tensors + non_tensors = { + k: v[~is_expert_mask.detach().cpu().numpy()] for k, v in exps.non_tensor_batch.items() + } + + # Build new DataProto + exps = DataProto.from_dict( + tensors=tensors, + non_tensors=non_tensors, + meta_info=exps.meta_info + ) + return super().__call__(exps, **kwargs) + + @classmethod + def default_args(cls) -> Dict: + return { + "epsilon": 1e-6, + } diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 88b9b946b7..8f62890211 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -180,3 +180,29 @@ def check_config(cls, config: Config) -> None: logger.warning( "DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2." 
) # no need to warn + + +@ALGORITHM_TYPE.register_module("mix") +class MIXAlgorithm(AlgorithmType): + """MIX algorithm.""" + + use_critic: bool = False + use_reference: bool = True + use_advantage: bool = True + use_rollout: bool = True + can_balance_batch: bool = True + schema: type = ExperienceModel + + @classmethod + def check_config(cls, config: Config) -> None: + pass + + @classmethod + def get_default_config(cls) -> Dict: + return { + "repeat_times": 8, + "policy_loss_fn": "mix", + "advantage_fn": "mix", + "sample_strategy": "mix", + "mu": 0.1, + } diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 66dce16cab..200a2dbe4f 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -3,6 +3,7 @@ from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn +from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn __all__ = [ "POLICY_LOSS_FN", @@ -11,4 +12,5 @@ "OPMDPolicyLossFn", "DPOLossFn", "SFTLossFn", + "MIXPolicyLossFn", ] diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py new file mode 100644 index 0000000000..2d7d55c470 --- /dev/null +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -0,0 +1,116 @@ +"""Mix policy loss function.""" + +from typing import Dict, List, Optional, Tuple + +import torch + +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn + + +@POLICY_LOSS_FN.register_module("mix") +class MIXPolicyLossFn(PolicyLossFn): + def __init__( + self, + mu: float = 0.1, + clip_range: Optional[float] = None, + clip_range_low: Optional[float] = None, + clip_range_high: Optional[float] = None, + use_dynamic_bsz: Optional[int] = None, + ppo_mini_batch_size: Optional[int] = None, + gradient_accumulation: Optional[int] = None, + read_batch_size_usual: Optional[int] = None, + read_batch_size_expert: Optional[int] = None, + use_token_level_loss_in_sft: Optional[bool] = True + ) -> None: + self.mu = mu + self.use_dynamic_bsz = use_dynamic_bsz + self.ppo_mini_batch_size = ppo_mini_batch_size + self.gradient_accumulation = gradient_accumulation + self.read_batch_size_usual = read_batch_size_usual + self.read_batch_size_expert = read_batch_size_expert + self.grpo_loss_fn = PPOPolicyLossFn( + clip_range=clip_range, + clip_range_low=clip_range_low, + clip_range_high=clip_range_high, + ) + self.sft_loss_fn = SFTLossFn( + use_token_level_loss=use_token_level_loss_in_sft + ) + + def __call__( # type: ignore + self, + logprob: torch.Tensor, + old_logprob: torch.Tensor, + action_mask: torch.Tensor, + advantages: torch.Tensor, + **kwargs, + ) -> Tuple[torch.Tensor, Dict]: + is_expert_mask = kwargs.get("is_expert_mask", None) + if is_expert_mask is None: + raise ValueError("is_expert_mask is required in MIX") + assert len(is_expert_mask) == logprob.shape[0], f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" + + n_usual_exp = torch.sum(~is_expert_mask).item() + n_expert_exp = torch.sum(is_expert_mask).item() + + if self.use_dynamic_bsz: + per_micro_batch_weight_usual = self.ppo_mini_batch_size / (logprob.shape[0] * self.read_batch_size_usual) + 
per_micro_batch_weight_expert = self.ppo_mini_batch_size / (logprob.shape[0] * self.read_batch_size_expert) + else: + per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual + per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert + + print(f"debug: {per_micro_batch_weight_usual=}, = {self.ppo_mini_batch_size} / ({logprob.shape[0]} * {self.read_batch_size_usual})") + print(f"debug: {per_micro_batch_weight_expert=}, = {self.ppo_mini_batch_size} / ({logprob.shape[0]} * {self.read_batch_size_expert})") + print(f"debug: {n_usual_exp=}, {n_expert_exp=}") + + if n_usual_exp > 0: + grpo_loss, grpo_metrics = self.grpo_loss_fn( + logprob[~is_expert_mask], + old_logprob[~is_expert_mask], + action_mask[~is_expert_mask], + advantages[~is_expert_mask], + **kwargs, + ) + grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual + grpo_metrics = {k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()} + else: + grpo_loss = torch.tensor(0.0, device=logprob.device) + grpo_metrics = {} + + # SFT Loss (expert) + if n_expert_exp > 0: + sft_loss, sft_metrics = self.sft_loss_fn( + logprob[is_expert_mask], + action_mask[is_expert_mask], + ) + sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert + sft_metrics = {k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()} + else: + sft_loss = torch.tensor(0.0, device=logprob.device) + sft_metrics = {} + + loss = grpo_loss + self.mu * sft_loss + + metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} + sft_metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + + return loss, metrics + + @classmethod + def default_args(cls) -> Dict: + return { + "mu": 0.1, + "clip_range": 0.2, + } + + @property + def select_keys(self) -> List[str]: + return [ + "old_logprob", + "action_mask", + "advantages", + "is_expert_mask" + ] diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py index 60f2e268ae..1155a2467c 100644 --- a/trinity/algorithm/sample_strategy/__init__.py +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -4,10 +4,12 @@ SampleStrategy, WarmupSampleStrategy, ) +from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy __all__ = [ "SAMPLE_STRATEGY", "SampleStrategy", "DefaultSampleStrategy", "WarmupSampleStrategy", + "MixSampleStrategy", ] diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py new file mode 100644 index 0000000000..d0776c143d --- /dev/null +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -0,0 +1,111 @@ +import torch +import numpy as np + +from math import ceil +from typing import Any, Dict, List, Tuple + +from verl.trainer.ppo.ray_trainer import DataProto + +from trinity.algorithm.sample_strategy.utils import representative_sample +from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy +from trinity.buffer import get_buffer_reader +from trinity.common.config import BufferConfig +from trinity.common.experience import Experiences +from trinity.utils.timer import Timer + + +@SAMPLE_STRATEGY.register_module("mix") +class MixSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.expert_data_ratio = buffer_config.expert_data_ratio + self.usual_exp_buffer = 
get_buffer_reader( + buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore + ) + if buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError("`buffer_config.trainer_input.expert_dataset` is required in MIX algorithm") + + self.expert_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.sft_warmup_dataset, buffer_config + ) + tot_batch_size = buffer_config.read_batch_size + expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + self.expert_exp_buffer.read_batch_size = expert_batch_size + self.usual_exp_buffer.read_batch_size = tot_batch_size - expert_batch_size + print(f"debug: {self.usual_exp_buffer.read_batch_size=}, {self.expert_exp_buffer.read_batch_size=}") + + def sample(self, step: int) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + usual_exp_list = self.usual_exp_buffer.read() + for exp in usual_exp_list: + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = False + + expert_exp_list = self.expert_exp_buffer.read() + for exp in expert_exp_list: + exp.reward = 0.0 + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = True + + exp_list = usual_exp_list + expert_exp_list + repr_samples = representative_sample(exp_list) + + is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) + print(f"debug: {len(usual_exp_list)=}, {len(expert_exp_list)=}") + + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto_mix(exps, is_expert_mask) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + @classmethod + def get_default_config(cls) -> Dict: + return { + "expert_data_ratio": 0.5, + } + + +def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + "is_expert_mask": is_expert_mask, + } + # print(f"debug: the last one {uid=}") + print(f"debug: (to_data_proto_mix) {is_expert_mask=}") + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) diff --git a/trinity/common/config.py b/trinity/common/config.py index 7c371f4bcb..d31df1dae2 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ 
-175,6 +175,7 @@ class AlgorithmConfig: algorithm_type: str = "ppo" # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 + mu: float = 0.1 # for mix training sample_strategy: Optional[str] = None sample_strategy_args: Optional[dict] = None @@ -250,6 +251,7 @@ class BufferConfig: batch_size: int = 1 total_epochs: int = 1 + expert_data_ratio: float = 0.0 # for explorer explorer_input: ExplorerInput = field(default_factory=ExplorerInput) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 644fe9a8f5..eeed9e9c0b 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -346,6 +346,24 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 # TODO: check other fields self.enable_preview = config.trainer.enable_preview + if config.algorithm.algorithm_type == "mix": + tot_batch_size = config.buffer.read_batch_size + read_batch_size_expert = math.ceil(config.buffer.expert_data_ratio * tot_batch_size) + read_batch_size_usual = tot_batch_size - read_batch_size_expert + loss_kwargs = { + "use_dynamic_bsz": config.trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz, + "ppo_mini_batch_size": config.trainer.trainer_config.actor_rollout_ref.actor.ppo_mini_batch_size, + "gradient_accumulation": ( + config.trainer.trainer_config.actor_rollout_ref.actor.ppo_mini_batch_size + // config.trainer.trainer_config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, + ), + "read_batch_size_usual": read_batch_size_usual, + "read_batch_size_expert": read_batch_size_expert, + } + config.algorithm.policy_loss_fn_args.update(loss_kwargs) + config.algorithm.policy_loss_fn_args.update( + {"use_token_level_loss_in_sft": config.algorithm.use_token_level_loss} + ) def load_config(config_path: str) -> veRLConfig: schema = OmegaConf.structured(veRLConfig) diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index 616234d0d6..a1b1c2c636 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -294,6 +294,7 @@ def update_policy(self, data: DataProto): # noqa: C901 "ref_log_prob": "ref_logprob", "response_mask": "action_mask", "advantages": "advantages", + "is_expert_mask": "is_expert_mask", } select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()} for trinity_key in self.policy_loss_fn.select_keys: @@ -410,9 +411,11 @@ def update_policy(self, data: DataProto): # noqa: C901 if self.config.use_dynamic_bsz: # relative to the dynamic bsz + print(f"debug: use_dynamic_bsz: {len(data)} / {self.config.ppo_mini_batch_size} = ", len(data) / self.config.ppo_mini_batch_size) loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) else: loss = policy_loss / self.gradient_accumulation + print(f"debug: gradient_accumulation: 1/{self.gradient_accumulation} = ", 1.0/self.gradient_accumulation) loss.backward() append_to_dict(metrics, micro_batch_metrics) From dd19758b5e62921ec32dc2dfee1c0243aa04fb09 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 13 Jun 2025 14:26:59 +0800 Subject: [PATCH 2/9] prepare for mix algorithm (cont) --- .../algorithm/advantage_fn/mix_advantage.py | 19 +++++++++++-- .../sample_strategy/mix_sample_strategy.py | 27 ++++++++++--------- trinity/common/verl_config.py | 16 ++++++----- 3 files changed, 41 insertions(+), 21 deletions(-) diff --git a/trinity/algorithm/advantage_fn/mix_advantage.py b/trinity/algorithm/advantage_fn/mix_advantage.py index 91fcff587f..af4620f5fe 100644 --- 
a/trinity/algorithm/advantage_fn/mix_advantage.py +++ b/trinity/algorithm/advantage_fn/mix_advantage.py @@ -26,6 +26,8 @@ def __call__( **kwargs, ) -> Tuple[DataProto, Dict]: is_expert_mask = exps.batch["is_expert_mask"] + device = is_expert_mask.device + batch_size = is_expert_mask.shape[0] # Process tensors tensors = { @@ -38,13 +40,26 @@ def __call__( } # Build new DataProto - exps = DataProto.from_dict( + new_exps = DataProto.from_dict( tensors=tensors, non_tensors=non_tensors, meta_info=exps.meta_info ) - return super().__call__(exps, **kwargs) + new_exps, new_metrics = super().__call__(new_exps, **kwargs) + # Get full advantages + full_advantages = torch.zeros((batch_size, new_exps.batch["advantages"].shape[1]), device=device) + full_returns = torch.zeros((batch_size, new_exps.batch["returns"].shape[1]), device=device) + + # Fill in the non-expert parts with computed values + full_advantages[~is_expert_mask] = new_exps.batch["advantages"] + full_returns[~is_expert_mask] = new_exps.batch["returns"] + + # Write back to original exps + exps.batch["advantages"] = full_advantages + exps.batch["returns"] = full_returns + # TODO: change new_metrics + return exps, new_metrics @classmethod def default_args(cls) -> Dict: return { diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index d0776c143d..bbcc9fda3d 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -1,3 +1,4 @@ +import copy import torch import numpy as np @@ -21,21 +22,26 @@ class MixSampleStrategy(SampleStrategy): def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): super().__init__(buffer_config, trainer_type) self.expert_data_ratio = buffer_config.expert_data_ratio + tot_batch_size = buffer_config.read_batch_size + expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + + # experience buffer + usual_buffer_config = copy.deepcopy(buffer_config) + usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size self.usual_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore + buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore ) + if buffer_config.trainer_input.sft_warmup_dataset is None: - raise ValueError("`buffer_config.trainer_input.expert_dataset` is required in MIX algorithm") + raise ValueError("`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm") + # expert experience buffer + expert_buffer_config = copy.deepcopy(buffer_config) + expert_buffer_config.read_batch_size = expert_batch_size self.expert_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.sft_warmup_dataset, buffer_config + buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config ) - tot_batch_size = buffer_config.read_batch_size - expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) - self.expert_exp_buffer.read_batch_size = expert_batch_size - self.usual_exp_buffer.read_batch_size = tot_batch_size - expert_batch_size - print(f"debug: {self.usual_exp_buffer.read_batch_size=}, {self.expert_exp_buffer.read_batch_size=}") - + def sample(self, step: int) -> Tuple[Any, Dict, List]: metrics = {} with Timer(metrics, "read_time"): @@ -57,7 +63,6 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]: repr_samples = representative_sample(exp_list) is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], 
dtype=torch.bool) - print(f"debug: {len(usual_exp_list)=}, {len(expert_exp_list)=}") with Timer(metrics, "gather_time"): exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore @@ -93,8 +98,6 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> ), "is_expert_mask": is_expert_mask, } - # print(f"debug: the last one {uid=}") - print(f"debug: (to_data_proto_mix) {is_expert_mask=}") if experiences.rewards is not None: token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) eos_mask_idx = cumsum.argmax(dim=-1) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index eeed9e9c0b..b5d3ad32d8 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -105,6 +105,8 @@ class Rollout: val_kwargs: _ValKwargs = field(default_factory=_ValKwargs) temperature: float = 1.0 n: int = 1 # > 1 for grpo + log_prob_micro_batch_size: Optional[int] = None + log_prob_micro_batch_size_per_gpu: int = 1 @dataclass @@ -351,12 +353,11 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 read_batch_size_expert = math.ceil(config.buffer.expert_data_ratio * tot_batch_size) read_batch_size_usual = tot_batch_size - read_batch_size_expert loss_kwargs = { - "use_dynamic_bsz": config.trainer.trainer_config.actor_rollout_ref.actor.use_dynamic_bsz, - "ppo_mini_batch_size": config.trainer.trainer_config.actor_rollout_ref.actor.ppo_mini_batch_size, - "gradient_accumulation": ( - config.trainer.trainer_config.actor_rollout_ref.actor.ppo_mini_batch_size - // config.trainer.trainer_config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - ), + "use_dynamic_bsz": self.actor_rollout_ref.actor.use_dynamic_bsz, + "ppo_mini_batch_size": self.actor_rollout_ref.actor.ppo_mini_batch_size * self.actor_rollout_ref.rollout.n // world_size, # TODO: check + "gradient_accumulation": + self.actor_rollout_ref.actor.ppo_mini_batch_size * self.actor_rollout_ref.rollout.n + // self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, "read_batch_size_usual": read_batch_size_usual, "read_batch_size_expert": read_batch_size_expert, } @@ -364,7 +365,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 config.algorithm.policy_loss_fn_args.update( {"use_token_level_loss_in_sft": config.algorithm.use_token_level_loss} ) - + print(f"{config.buffer.read_batch_size=}, {loss_kwargs['ppo_mini_batch_size']=}") + print(f"{self.actor_rollout_ref.actor.ppo_mini_batch_size=}") def load_config(config_path: str) -> veRLConfig: schema = OmegaConf.structured(veRLConfig) yaml_config = OmegaConf.load(config_path) From 970e321a36548fc18b4ef1dc7f5d73db6cb32bef Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 13 Jun 2025 15:11:00 +0800 Subject: [PATCH 3/9] clear debug logs --- trinity/algorithm/policy_loss_fn/mix_policy_loss.py | 6 +----- trinity/trainer/verl/dp_actor.py | 2 -- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 2d7d55c470..353f6f9d1a 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -61,10 +61,6 @@ def __call__( # type: ignore else: per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert - - print(f"debug: {per_micro_batch_weight_usual=}, = {self.ppo_mini_batch_size} / 
({logprob.shape[0]} * {self.read_batch_size_usual})") - print(f"debug: {per_micro_batch_weight_expert=}, = {self.ppo_mini_batch_size} / ({logprob.shape[0]} * {self.read_batch_size_expert})") - print(f"debug: {n_usual_exp=}, {n_expert_exp=}") if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( @@ -92,7 +88,7 @@ def __call__( # type: ignore sft_loss = torch.tensor(0.0, device=logprob.device) sft_metrics = {} - loss = grpo_loss + self.mu * sft_loss + loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} sft_metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index a1b1c2c636..ae691e111a 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -411,11 +411,9 @@ def update_policy(self, data: DataProto): # noqa: C901 if self.config.use_dynamic_bsz: # relative to the dynamic bsz - print(f"debug: use_dynamic_bsz: {len(data)} / {self.config.ppo_mini_batch_size} = ", len(data) / self.config.ppo_mini_batch_size) loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size) else: loss = policy_loss / self.gradient_accumulation - print(f"debug: gradient_accumulation: 1/{self.gradient_accumulation} = ", 1.0/self.gradient_accumulation) loss.backward() append_to_dict(metrics, micro_batch_metrics) From 05b14ee0eb37644c1d8b510c821437b49de5853d Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Fri, 13 Jun 2025 15:18:54 +0800 Subject: [PATCH 4/9] add example markdown --- .../source/tutorial/example_mix_algo.md | 181 +++++++++++------- examples/grpo_math/math_mix.yaml | 76 -------- trinity/algorithm/advantage_fn/__init__.py | 2 +- .../algorithm/advantage_fn/mix_advantage.py | 14 +- trinity/algorithm/algorithm.py | 4 - trinity/algorithm/policy_loss_fn/__init__.py | 2 +- .../policy_loss_fn/mix_policy_loss.py | 41 ++-- trinity/algorithm/sample_strategy/__init__.py | 2 +- .../sample_strategy/mix_sample_strategy.py | 21 +- trinity/common/verl_config.py | 12 +- 10 files changed, 165 insertions(+), 190 deletions(-) delete mode 100644 examples/grpo_math/math_mix.yaml diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 2c17b79f77..1c50f962cf 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -1,7 +1,7 @@ # Integrate An New Algorithm -This guide introduces how to integrate a new algorithm to Trinity-RFT. +This guide introduces how to integrate a new algorithm to Trinity-RFT. 
As an example, we incorporate some "expert" data generated by a more advanced LLM and propose an algorithm named MIX , which optimizes the following policy objective: $$ @@ -42,19 +42,17 @@ class MIXAlgorithm(AlgorithmType): use_critic: bool = False use_reference: bool = True use_advantage: bool = True - use_rollout: bool = False + use_rollout: bool = True can_balance_batch: bool = True schema: type = ExperienceModel - @classmethod - def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences: - return Experiences.gather_experiences(exps, pad_token_id) - @classmethod def get_default_config(cls) -> Dict: return { "repeat_times": 8, "policy_loss_fn": "mix", + "advantage_fn": "mix", + "sample_strategy": "mix", "mu": 0.1, } ``` @@ -76,51 +74,65 @@ We need to read two kinds of experiences: usual experiences and expert experienc ```python -@SAMPLE_STRATEGY.register_module("mix") class MixSampleStrategy(SampleStrategy): """The default sample strategy.""" def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): super().__init__(buffer_config, trainer_type) self.expert_data_ratio = buffer_config.expert_data_ratio + tot_batch_size = buffer_config.read_batch_size + expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + + # experience buffer + usual_buffer_config = copy.deepcopy(buffer_config) + usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size self.usual_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore + buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore ) - if buffer_config.trainer_input.expert_dataset is None: - raise ValueError("`buffer_config.trainer_input.expert_dataset` is required in MIX algorithm") + if buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError( + "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" + ) + + # expert experience buffer + expert_buffer_config = copy.deepcopy(buffer_config) + expert_buffer_config.read_batch_size = expert_batch_size self.expert_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.expert_dataset, buffer_config + buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config ) - tot_batch_size = buffer_config.batch_size - self.expert_exp_buffer.read_batch_size = ceil(self.expert_data_ratio * tot_batch_size) - self.usual_exp_buffer.read_batch_size = tot_batch_size - self.expert_exp_buffer.read_batch_size - - - def sample(self, step: int, **kwargs) -> DataProto: - usual_exp_list = self.exp_buffer.read() - for exp in usual_exp_list: - exp.info["is_expert"] = False - - expert_exp_list = self.expert_exp_buffer.read() - for exp in expert_exp_list: - exp.info["is_expert"] = True - - exp_list = usual_exp_list + expert_exp_list + + def sample(self, step: int) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + usual_exp_list = self.usual_exp_buffer.read() + for exp in usual_exp_list: + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = False + + expert_exp_list = self.expert_exp_buffer.read() + for exp in expert_exp_list: + exp.reward = 0.0 + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = True + + exp_list = usual_exp_list + expert_exp_list + repr_samples = representative_sample(exp_list) + is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) - + + with 
Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + if self.trainer_type == "verl": with Timer(metrics, "convert_time"): data = to_data_proto_mix(exps, is_expert_mask) return data, metrics, repr_samples else: raise NotImplementedError(f"backend {self.trainer_type} is not supported") - - @classmethod - def get_default_config(cls) -> Dict: - return { - "expert_data_ratio": 0.5, - } ``` We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type. @@ -182,11 +194,11 @@ class MIXAdvantageFn(GRPOAdvantageFn): **kwargs, ) -> Tuple[DataProto, Dict]: is_expert_mask = exps.batch["is_expert_mask"] + device = is_expert_mask.device + batch_size = is_expert_mask.shape[0] # Process tensors - tensors = { - k: tensor[~is_expert_mask] for k, tensor in exps.batch.items() - } + tensors = {k: tensor[~is_expert_mask] for k, tensor in exps.batch.items()} # Process non-tensors non_tensors = { @@ -194,12 +206,26 @@ class MIXAdvantageFn(GRPOAdvantageFn): } # Build new DataProto - exps = DataProto.from_dict( - tensors=tensors, - non_tensors=non_tensors, - meta_info=exps.meta_info + new_exps = DataProto.from_dict( + tensors=tensors, non_tensors=non_tensors, meta_info=exps.meta_info + ) + new_exps, new_metrics = super().__call__(new_exps, **kwargs) + + # Get full advantages + full_advantages = torch.zeros( + (batch_size, new_exps.batch["advantages"].shape[1]), device=device ) - return super().__call__(exps, **kwargs) + full_returns = torch.zeros((batch_size, new_exps.batch["returns"].shape[1]), device=device) + + # Fill in the non-expert parts with computed values + full_advantages[~is_expert_mask] = new_exps.batch["advantages"] + full_returns[~is_expert_mask] = new_exps.batch["returns"] + + # Write back to original exps + exps.batch["advantages"] = full_advantages + exps.batch["returns"] = full_returns + # TODO: change new_metrics + return exps, new_metrics @classmethod def default_args(cls) -> Dict: @@ -214,7 +240,6 @@ class MIXAdvantageFn(GRPOAdvantageFn): We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively. 
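Before the full implementation, here is a tiny standalone sketch of the weighting used in the code below; the loss values are made up. With $\mu$ close to 0 the objective is dominated by the GRPO term on usual experiences, and with $\mu$ close to 1 by the SFT term on expert experiences.

```python
import torch

mu = 0.1
grpo_loss = torch.tensor(0.40)  # made-up loss on usual (rollout) experiences
sft_loss = torch.tensor(1.20)   # made-up NLL loss on expert experiences

loss = (1 - mu) * grpo_loss + mu * sft_loss
print(loss)  # tensor(0.4800)
```
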
```python - @POLICY_LOSS_FN.register_module("mix") class MIXPolicyLossFn(PolicyLossFn): def __init__( @@ -223,18 +248,26 @@ class MIXPolicyLossFn(PolicyLossFn): clip_range: Optional[float] = None, clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, - use_token_level_loss_in_sft: Optional[bool] = True + use_dynamic_bsz: Optional[int] = None, + ppo_mini_batch_size: Optional[int] = None, + gradient_accumulation: Optional[int] = None, + read_batch_size_usual: Optional[int] = None, + read_batch_size_expert: Optional[int] = None, + use_token_level_loss_in_sft: Optional[bool] = True, ) -> None: self.mu = mu + self.use_dynamic_bsz = use_dynamic_bsz + self.ppo_mini_batch_size = ppo_mini_batch_size + self.gradient_accumulation = gradient_accumulation + self.read_batch_size_usual = read_batch_size_usual + self.read_batch_size_expert = read_batch_size_expert self.grpo_loss_fn = PPOPolicyLossFn( clip_range=clip_range, clip_range_low=clip_range_low, clip_range_high=clip_range_high, ) - self.sft_loss_fn = SFTLossFn( - use_token_level_loss=use_token_level_loss_in_sft - ) - + self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft) + def __call__( # type: ignore self, logprob: torch.Tensor, @@ -244,14 +277,26 @@ class MIXPolicyLossFn(PolicyLossFn): **kwargs, ) -> Tuple[torch.Tensor, Dict]: is_expert_mask = kwargs.get("is_expert_mask", None) - per_micro_batch_weight = kwargs.get("per_micro_batch_weight", None) if is_expert_mask is None: raise ValueError("is_expert_mask is required in MIX") - assert len(is_expert_mask) == logprob.shape[0], f"{len(is_expert_mask)=} != {logprob.shape[0]=}" - - n_usual_exp = torch.sum(~is_expert_mask) - n_expert_exp = torch.sum(is_expert_mask) - + assert ( + len(is_expert_mask) == logprob.shape[0] + ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" + + n_usual_exp = torch.sum(~is_expert_mask).item() + n_expert_exp = torch.sum(is_expert_mask).item() + + if self.use_dynamic_bsz: + per_micro_batch_weight_usual = self.ppo_mini_batch_size / ( + logprob.shape[0] * self.read_batch_size_usual + ) + per_micro_batch_weight_expert = self.ppo_mini_batch_size / ( + logprob.shape[0] * self.read_batch_size_expert + ) + else: + per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual + per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert + if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( logprob[~is_expert_mask], @@ -260,8 +305,10 @@ class MIXPolicyLossFn(PolicyLossFn): advantages[~is_expert_mask], **kwargs, ) - grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight - grpo_metrics = {k: v * n_usual_exp * per_micro_batch_weight for k, v in grpo_metrics.items()} + grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual + grpo_metrics = { + k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items() + } else: grpo_loss = torch.tensor(0.0, device=logprob.device) grpo_metrics = {} @@ -272,17 +319,19 @@ class MIXPolicyLossFn(PolicyLossFn): logprob[is_expert_mask], action_mask[is_expert_mask], ) - sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight - sft_metrics = {k: v * n_expert_exp * per_micro_batch_weight for k, v in sft_metrics.items()} + sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert + sft_metrics = { + k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items() + } else: sft_loss = torch.tensor(0.0, device=logprob.device) sft_metrics = {} - loss = grpo_loss + 
self.mu * sft_loss + loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} sft_metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) - + return loss, metrics @classmethod @@ -294,11 +343,7 @@ class MIXPolicyLossFn(PolicyLossFn): @property def select_keys(self) -> List[str]: - return [ - "old_logprob", - "action_mask", - "advantages", - ] + return ["old_logprob", "action_mask", "advantages", "is_expert_mask"] ``` ## Step 5: Run the Experiment @@ -307,5 +352,9 @@ With the above newly-defined classes, we can run the experiments without modifyi An example showing some important configurations is shown below. ```yaml -TODO -``` \ No newline at end of file +algorithm: + algorithm_type: mix + mu: 0.5 +buffer: + expert_data_ratio: 0.25 +``` diff --git a/examples/grpo_math/math_mix.yaml b/examples/grpo_math/math_mix.yaml deleted file mode 100644 index 8a34a1a235..0000000000 --- a/examples/grpo_math/math_mix.yaml +++ /dev/null @@ -1,76 +0,0 @@ -project: "rft_sft_mixed" -name: "test" -checkpoint_root_dir: /mnt/yuchang/checkpoints/ -algorithm: - algorithm_type: mix - mu: 0.5 # NEW - repeat_times: 8 -model: - model_path: /mnt/checkpoint/qwen25/Qwen2.5-1.5B-Instruct - max_prompt_tokens: 1024 - max_response_tokens: 16392 -cluster: - node_num: 1 - gpu_per_node: 4 -buffer: - total_epochs: 1 - batch_size: 32 - expert_data_ratio: 0.5 # NEW - max_retry_times: 3 - max_retry_interval: 1 - explorer_input: - taskset: - name: openr1 - storage_type: file - path: /mnt/yuchang/datasets/openr1_data/openr1_data_filtered_int - split: 'train' - format: - prompt_key: 'problem' - response_key: 'answer' - rollout_args: - temperature: 1.0 - logprobs: 0 - eval_tasksets: - - name: openr1 - storage_type: file - path: /mnt/yuchang/datasets/openr1_data/openr1_data_filtered_int - split: 'test' - format: - prompt_key: 'problem' - response_key: 'answer' - default_workflow_type: 'math_workflow' - trainer_input: - experience_buffer: - name: openr1_buffer - storage_type: queue - path: 'sqlite:////mnt/yuchang/checkpoints/${project}/${name}/openr1.db' - sft_warmup_dataset: - name: openr1_sft - storage_type: file - algorithm_type: sft - path: /mnt/yuchang/datasets/openr1_data_sft - split: 'train' - format: - prompt_type: messages - messages_key: 'messages' -explorer: - eval_interval: 10 - runner_num: 32 - rollout_model: - engine_type: vllm_async - engine_num: 2 - tensor_parallel_size: 1 - enable_prefix_caching: false - enforce_eager: true - dtype: bfloat16 - seed: 42 -synchronizer: - sync_method: 'nccl' - sync_interval: 1 - sync_timeout: 1200 -trainer: - trainer_type: 'verl' - trainer_config_path: 'examples/grpo_math/train_math.yaml' - save_interval: 50 -monitor: - monitor_type: wandb \ No newline at end of file diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 7e6b076843..5e5aad8c4f 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -1,5 +1,6 @@ from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn +from trinity.algorithm.advantage_fn.mix_advantage import MIXAdvantageFn from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import ( @@ -7,7 +8,6 @@ ) from 
trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn -from trinity.algorithm.advantage_fn.mix_advantage import MIXAdvantageFn __all__ = [ "ADVANTAGE_FN", diff --git a/trinity/algorithm/advantage_fn/mix_advantage.py b/trinity/algorithm/advantage_fn/mix_advantage.py index af4620f5fe..41db245ea2 100644 --- a/trinity/algorithm/advantage_fn/mix_advantage.py +++ b/trinity/algorithm/advantage_fn/mix_advantage.py @@ -3,7 +3,6 @@ from typing import Dict, Tuple import torch - from verl import DataProto from trinity.algorithm.advantage_fn import ADVANTAGE_FN @@ -30,9 +29,7 @@ def __call__( batch_size = is_expert_mask.shape[0] # Process tensors - tensors = { - k: tensor[~is_expert_mask] for k, tensor in exps.batch.items() - } + tensors = {k: tensor[~is_expert_mask] for k, tensor in exps.batch.items()} # Process non-tensors non_tensors = { @@ -41,14 +38,14 @@ def __call__( # Build new DataProto new_exps = DataProto.from_dict( - tensors=tensors, - non_tensors=non_tensors, - meta_info=exps.meta_info + tensors=tensors, non_tensors=non_tensors, meta_info=exps.meta_info ) new_exps, new_metrics = super().__call__(new_exps, **kwargs) # Get full advantages - full_advantages = torch.zeros((batch_size, new_exps.batch["advantages"].shape[1]), device=device) + full_advantages = torch.zeros( + (batch_size, new_exps.batch["advantages"].shape[1]), device=device + ) full_returns = torch.zeros((batch_size, new_exps.batch["returns"].shape[1]), device=device) # Fill in the non-expert parts with computed values @@ -60,6 +57,7 @@ def __call__( exps.batch["returns"] = full_returns # TODO: change new_metrics return exps, new_metrics + @classmethod def default_args(cls) -> Dict: return { diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 8f62890211..93c19879c9 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -193,10 +193,6 @@ class MIXAlgorithm(AlgorithmType): can_balance_batch: bool = True schema: type = ExperienceModel - @classmethod - def check_config(cls, config: Config) -> None: - pass - @classmethod def get_default_config(cls) -> Dict: return { diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py index 200a2dbe4f..705fb2525a 100644 --- a/trinity/algorithm/policy_loss_fn/__init__.py +++ b/trinity/algorithm/policy_loss_fn/__init__.py @@ -1,9 +1,9 @@ from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn +from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn -from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn __all__ = [ "POLICY_LOSS_FN", diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 353f6f9d1a..82b1923147 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -22,7 +22,7 @@ def __init__( gradient_accumulation: Optional[int] = None, read_batch_size_usual: Optional[int] = None, read_batch_size_expert: Optional[int] = None, - use_token_level_loss_in_sft: Optional[bool] = True + use_token_level_loss_in_sft: 
Optional[bool] = True, ) -> None: self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz @@ -35,10 +35,8 @@ def __init__( clip_range_low=clip_range_low, clip_range_high=clip_range_high, ) - self.sft_loss_fn = SFTLossFn( - use_token_level_loss=use_token_level_loss_in_sft - ) - + self.sft_loss_fn = SFTLossFn(use_token_level_loss=use_token_level_loss_in_sft) + def __call__( # type: ignore self, logprob: torch.Tensor, @@ -50,17 +48,23 @@ def __call__( # type: ignore is_expert_mask = kwargs.get("is_expert_mask", None) if is_expert_mask is None: raise ValueError("is_expert_mask is required in MIX") - assert len(is_expert_mask) == logprob.shape[0], f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" - + assert ( + len(is_expert_mask) == logprob.shape[0] + ), f"Error: {len(is_expert_mask)=} != {logprob.shape[0]=}" + n_usual_exp = torch.sum(~is_expert_mask).item() n_expert_exp = torch.sum(is_expert_mask).item() if self.use_dynamic_bsz: - per_micro_batch_weight_usual = self.ppo_mini_batch_size / (logprob.shape[0] * self.read_batch_size_usual) - per_micro_batch_weight_expert = self.ppo_mini_batch_size / (logprob.shape[0] * self.read_batch_size_expert) + per_micro_batch_weight_usual = self.ppo_mini_batch_size / ( + logprob.shape[0] * self.read_batch_size_usual + ) + per_micro_batch_weight_expert = self.ppo_mini_batch_size / ( + logprob.shape[0] * self.read_batch_size_expert + ) else: per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual - per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert + per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( @@ -71,7 +75,9 @@ def __call__( # type: ignore **kwargs, ) grpo_loss = grpo_loss * n_usual_exp * per_micro_batch_weight_usual - grpo_metrics = {k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items()} + grpo_metrics = { + k: v * n_usual_exp * per_micro_batch_weight_usual for k, v in grpo_metrics.items() + } else: grpo_loss = torch.tensor(0.0, device=logprob.device) grpo_metrics = {} @@ -83,7 +89,9 @@ def __call__( # type: ignore action_mask[is_expert_mask], ) sft_loss = sft_loss * n_expert_exp * per_micro_batch_weight_expert - sft_metrics = {k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items()} + sft_metrics = { + k: v * n_expert_exp * per_micro_batch_weight_expert for k, v in sft_metrics.items() + } else: sft_loss = torch.tensor(0.0, device=logprob.device) sft_metrics = {} @@ -92,7 +100,7 @@ def __call__( # type: ignore metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} sft_metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) - + return loss, metrics @classmethod @@ -104,9 +112,4 @@ def default_args(cls) -> Dict: @property def select_keys(self) -> List[str]: - return [ - "old_logprob", - "action_mask", - "advantages", - "is_expert_mask" - ] + return ["old_logprob", "action_mask", "advantages", "is_expert_mask"] diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py index 1155a2467c..cd4b9e0d66 100644 --- a/trinity/algorithm/sample_strategy/__init__.py +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -1,10 +1,10 @@ +from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy from trinity.algorithm.sample_strategy.sample_strategy import ( SAMPLE_STRATEGY, DefaultSampleStrategy, SampleStrategy, WarmupSampleStrategy, ) -from 
trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy __all__ = [ "SAMPLE_STRATEGY", diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index bbcc9fda3d..e2a6cd4b59 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -1,14 +1,13 @@ import copy -import torch -import numpy as np - from math import ceil from typing import Any, Dict, List, Tuple +import numpy as np +import torch from verl.trainer.ppo.ray_trainer import DataProto -from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy +from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig from trinity.common.experience import Experiences @@ -24,16 +23,18 @@ def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): self.expert_data_ratio = buffer_config.expert_data_ratio tot_batch_size = buffer_config.read_batch_size expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) - + # experience buffer usual_buffer_config = copy.deepcopy(buffer_config) usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size self.usual_exp_buffer = get_buffer_reader( buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore ) - + if buffer_config.trainer_input.sft_warmup_dataset is None: - raise ValueError("`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm") + raise ValueError( + "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" + ) # expert experience buffer expert_buffer_config = copy.deepcopy(buffer_config) @@ -61,19 +62,19 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]: exp_list = usual_exp_list + expert_exp_list repr_samples = representative_sample(exp_list) - + is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) with Timer(metrics, "gather_time"): exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - + if self.trainer_type == "verl": with Timer(metrics, "convert_time"): data = to_data_proto_mix(exps, is_expert_mask) return data, metrics, repr_samples else: raise NotImplementedError(f"backend {self.trainer_type} is not supported") - + @classmethod def get_default_config(cls) -> Dict: return { diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index b5d3ad32d8..9f7a2c4b7c 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -354,10 +354,12 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 read_batch_size_usual = tot_batch_size - read_batch_size_expert loss_kwargs = { "use_dynamic_bsz": self.actor_rollout_ref.actor.use_dynamic_bsz, - "ppo_mini_batch_size": self.actor_rollout_ref.actor.ppo_mini_batch_size * self.actor_rollout_ref.rollout.n // world_size, # TODO: check - "gradient_accumulation": - self.actor_rollout_ref.actor.ppo_mini_batch_size * self.actor_rollout_ref.rollout.n - // self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, + "ppo_mini_batch_size": self.actor_rollout_ref.actor.ppo_mini_batch_size + * self.actor_rollout_ref.rollout.n + // world_size, + "gradient_accumulation": self.actor_rollout_ref.actor.ppo_mini_batch_size + * self.actor_rollout_ref.rollout.n + // 
self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, "read_batch_size_usual": read_batch_size_usual, "read_batch_size_expert": read_batch_size_expert, } @@ -367,6 +369,8 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 ) print(f"{config.buffer.read_batch_size=}, {loss_kwargs['ppo_mini_batch_size']=}") print(f"{self.actor_rollout_ref.actor.ppo_mini_batch_size=}") + + def load_config(config_path: str) -> veRLConfig: schema = OmegaConf.structured(veRLConfig) yaml_config = OmegaConf.load(config_path) From 05ad9757872735c6c552f697d75623a1efa1b39f Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 16 Jun 2025 15:36:10 +0800 Subject: [PATCH 5/9] fix typos --- .../source/tutorial/example_mix_algo.md | 11 +- .../policy_loss_fn/mix_policy_loss.py | 9 +- trinity/algorithm/sample_strategy/__init__.py | 1 - .../sample_strategy/mix_sample_strategy.py | 115 ------------------ .../sample_strategy/sample_strategy.py | 78 +++++++++++- trinity/algorithm/sample_strategy/utils.py | 33 +++++ trinity/common/verl_config.py | 3 +- 7 files changed, 123 insertions(+), 127 deletions(-) delete mode 100644 trinity/algorithm/sample_strategy/mix_sample_strategy.py diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 1c50f962cf..0ec34912ec 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -14,7 +14,7 @@ $$ \log \pi_\theta(o'_{b,t} \mid q'_b, o'_{b, None: self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz @@ -330,7 +330,8 @@ class MIXPolicyLossFn(PolicyLossFn): loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} - sft_metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({"loss": loss.item()}) return loss, metrics @@ -355,6 +356,8 @@ An example showing some important configurations is shown below. 
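For a quick sanity check of the corrected metrics handling, the weighted combination can be exercised on its own. The following is a minimal, self-contained sketch; the function name `combine_mix_loss` and the sample numbers are assumptions for illustration, not code from the repository:

```python
import torch


def combine_mix_loss(grpo_loss: torch.Tensor, sft_loss: torch.Tensor, mu: float,
                     grpo_metrics: dict, sft_metrics: dict):
    # Weighted sum of the on-policy (GRPO) term and the expert (SFT) term.
    loss = (1 - mu) * grpo_loss + mu * sft_loss
    # Prefix each metric by its data source; expert metrics now go into the shared
    # `metrics` dict, and the scalar loss is recorded alongside them.
    metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()}
    metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()})
    metrics.update({"loss": loss.item()})
    return loss, metrics


loss, metrics = combine_mix_loss(
    grpo_loss=torch.tensor(0.8), sft_loss=torch.tensor(1.2), mu=0.1,
    grpo_metrics={"pg_loss": 0.8}, sft_metrics={"sft_loss": 1.2},
)
# metrics -> {"usual/pg_loss": 0.8, "expert/sft_loss": 1.2, "loss": 0.84}
```

With `mu = 0.1`, the expert term contributes a small fraction of the total loss while the GRPO term carries weight `1 - mu`.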
algorithm: algorithm_type: mix mu: 0.5 + repeat_times: 8 + use_token_level_loss: False buffer: expert_data_ratio: 0.25 ``` diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 82b1923147..6e44d8f931 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -22,7 +22,7 @@ def __init__( gradient_accumulation: Optional[int] = None, read_batch_size_usual: Optional[int] = None, read_batch_size_expert: Optional[int] = None, - use_token_level_loss_in_sft: Optional[bool] = True, + use_token_level_loss_in_sft: bool = False, ) -> None: self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz @@ -63,8 +63,8 @@ def __call__( # type: ignore logprob.shape[0] * self.read_batch_size_expert ) else: - per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual - per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert + per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( @@ -99,7 +99,8 @@ def __call__( # type: ignore loss = (1 - self.mu) * grpo_loss + self.mu * sft_loss metrics = {f"usual/{k}": v for k, v in grpo_metrics.items()} - sft_metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({f"expert/{k}": v for k, v in sft_metrics.items()}) + metrics.update({"loss": loss.item()}) return loss, metrics diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py index cd4b9e0d66..385ff58147 100644 --- a/trinity/algorithm/sample_strategy/__init__.py +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -1,4 +1,3 @@ -from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy from trinity.algorithm.sample_strategy.sample_strategy import ( SAMPLE_STRATEGY, DefaultSampleStrategy, diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py deleted file mode 100644 index e2a6cd4b59..0000000000 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ /dev/null @@ -1,115 +0,0 @@ -import copy -from math import ceil -from typing import Any, Dict, List, Tuple - -import numpy as np -import torch -from verl.trainer.ppo.ray_trainer import DataProto - -from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy -from trinity.algorithm.sample_strategy.utils import representative_sample -from trinity.buffer import get_buffer_reader -from trinity.common.config import BufferConfig -from trinity.common.experience import Experiences -from trinity.utils.timer import Timer - - -@SAMPLE_STRATEGY.register_module("mix") -class MixSampleStrategy(SampleStrategy): - """The default sample strategy.""" - - def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): - super().__init__(buffer_config, trainer_type) - self.expert_data_ratio = buffer_config.expert_data_ratio - tot_batch_size = buffer_config.read_batch_size - expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) - - # experience buffer - usual_buffer_config = copy.deepcopy(buffer_config) - usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size - self.usual_exp_buffer = get_buffer_reader( - 
buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore - ) - - if buffer_config.trainer_input.sft_warmup_dataset is None: - raise ValueError( - "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" - ) - - # expert experience buffer - expert_buffer_config = copy.deepcopy(buffer_config) - expert_buffer_config.read_batch_size = expert_batch_size - self.expert_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config - ) - - def sample(self, step: int) -> Tuple[Any, Dict, List]: - metrics = {} - with Timer(metrics, "read_time"): - usual_exp_list = self.usual_exp_buffer.read() - for exp in usual_exp_list: - if exp.info is None: - exp.info = {} - exp.info["is_expert"] = False - - expert_exp_list = self.expert_exp_buffer.read() - for exp in expert_exp_list: - exp.reward = 0.0 - exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) - if exp.info is None: - exp.info = {} - exp.info["is_expert"] = True - - exp_list = usual_exp_list + expert_exp_list - repr_samples = representative_sample(exp_list) - - is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) - - with Timer(metrics, "gather_time"): - exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - - if self.trainer_type == "verl": - with Timer(metrics, "convert_time"): - data = to_data_proto_mix(exps, is_expert_mask) - return data, metrics, repr_samples - else: - raise NotImplementedError(f"backend {self.trainer_type} is not supported") - - @classmethod - def get_default_config(cls) -> Dict: - return { - "expert_data_ratio": 0.5, - } - - -def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - batch_dict = { - "uid": np.array(experiences.run_ids), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - "is_expert_mask": is_expert_mask, - } - if experiences.rewards is not None: - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) - eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] - batch_dict.update( - { - "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore - } - ) - return DataProto.from_single_dict(batch_dict) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 8686a0d497..f2fc079003 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -1,7 +1,15 @@ +import copy from abc import ABC, abstractmethod +from math import ceil from typing import Any, Dict, List, Tuple -from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto +import torch + +from 
trinity.algorithm.sample_strategy.utils import ( + representative_sample, + to_data_proto, + to_data_proto_mix, +) from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig from trinity.common.experience import Experiences @@ -112,3 +120,71 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: return data, metrics, repr_samples else: raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + +@SAMPLE_STRATEGY.register_module("mix") +class MixSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.expert_data_ratio = buffer_config.expert_data_ratio + tot_batch_size = buffer_config.read_batch_size + expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + + # experience buffer + usual_buffer_config = copy.deepcopy(buffer_config) + usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size + self.usual_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore + ) + + if buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError( + "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" + ) + + # expert experience buffer + expert_buffer_config = copy.deepcopy(buffer_config) + expert_buffer_config.read_batch_size = expert_batch_size + self.expert_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config + ) + + def sample(self, step: int) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + usual_exp_list = self.usual_exp_buffer.read() + for exp in usual_exp_list: + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = False + + expert_exp_list = self.expert_exp_buffer.read() + for exp in expert_exp_list: + exp.reward = 0.0 + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = True + + exp_list = usual_exp_list + expert_exp_list + repr_samples = representative_sample(exp_list) + + is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) + + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto_mix(exps, is_expert_mask) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + @classmethod + def get_default_config(cls) -> Dict: + return { + "expert_data_ratio": 0.5, + } diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py index 8c443a20b1..be2c06bd34 100644 --- a/trinity/algorithm/sample_strategy/utils.py +++ b/trinity/algorithm/sample_strategy/utils.py @@ -40,6 +40,39 @@ def to_data_proto(experiences: Experiences) -> DataProto: return DataProto.from_single_dict(batch_dict) +def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, 
experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + "is_expert_mask": is_expert_mask, + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) + + def representative_sample(experiences: List[Experience]) -> List[dict]: if experiences[0].reward is None: sample = random.choice(experiences) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 9f7a2c4b7c..73a2369bd6 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -353,6 +353,7 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 read_batch_size_expert = math.ceil(config.buffer.expert_data_ratio * tot_batch_size) read_batch_size_usual = tot_batch_size - read_batch_size_expert loss_kwargs = { + "mu": config.algorithm.mu, "use_dynamic_bsz": self.actor_rollout_ref.actor.use_dynamic_bsz, "ppo_mini_batch_size": self.actor_rollout_ref.actor.ppo_mini_batch_size * self.actor_rollout_ref.rollout.n @@ -367,8 +368,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 config.algorithm.policy_loss_fn_args.update( {"use_token_level_loss_in_sft": config.algorithm.use_token_level_loss} ) - print(f"{config.buffer.read_batch_size=}, {loss_kwargs['ppo_mini_batch_size']=}") - print(f"{self.actor_rollout_ref.actor.ppo_mini_batch_size=}") def load_config(config_path: str) -> veRLConfig: From 23edc18141d5fa4ce23a35f5924786054f553493 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 16 Jun 2025 17:35:07 +0800 Subject: [PATCH 6/9] remove unnecessary adv compute --- .../source/tutorial/example_mix_algo.md | 84 +++---------------- trinity/algorithm/advantage_fn/__init__.py | 2 - .../algorithm/advantage_fn/mix_advantage.py | 65 -------------- trinity/algorithm/algorithm.py | 2 +- .../policy_loss_fn/mix_policy_loss.py | 2 +- trinity/common/config.py | 1 - trinity/common/verl_config.py | 4 - 7 files changed, 13 insertions(+), 147 deletions(-) delete mode 100644 trinity/algorithm/advantage_fn/mix_advantage.py diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 0ec34912ec..105745ae1c 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -51,20 +51,19 @@ class MIXAlgorithm(AlgorithmType): return { "repeat_times": 8, "policy_loss_fn": "mix", - "advantage_fn": "mix", + "advantage_fn": "grpo", "sample_strategy": "mix", "mu": 0.1, } ``` -We also define some necessary configuration parameters for later use, including the weighting factor $\mu$ and the batch size of expert experiences $B'$, calculated as the product of `buffer.expert_data_ratio`, `buffer.expert_data_ratio` and `algorithm.repeat_times`. 
+We also define some necessary configuration parameters for later use, including the weighting factor $\mu$ and the batch size of expert experiences $B'$, calculated as the product of `buffer.batch_size`, `buffer.expert_data_ratio` and `algorithm.repeat_times`. ```python -class AlgorithmConfig: +class BufferConfig: """Config for algorithm.""" ... - mu: float = 0.1 expert_data_ratio: float = 0.5 ``` @@ -172,70 +171,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> ``` - -## Step 3: Define the Avantage Function - -We define a `MIXAdvantageFn` class in `trinity/algorithm/advantage_fn/mix_advantage.py`, which computes the advantage function for only usual experiences. - -```python -@ADVANTAGE_FN.register_module("mix") -class MIXAdvantageFn(GRPOAdvantageFn): - """MIX advantage computation""" - - def __init__( - self, - epsilon: float = 1e-6, - ) -> None: - super().__init__(epsilon) - - def __call__( - self, - exps: DataProto, - **kwargs, - ) -> Tuple[DataProto, Dict]: - is_expert_mask = exps.batch["is_expert_mask"] - device = is_expert_mask.device - batch_size = is_expert_mask.shape[0] - - # Process tensors - tensors = {k: tensor[~is_expert_mask] for k, tensor in exps.batch.items()} - - # Process non-tensors - non_tensors = { - k: v[~is_expert_mask.detach().cpu().numpy()] for k, v in exps.non_tensor_batch.items() - } - - # Build new DataProto - new_exps = DataProto.from_dict( - tensors=tensors, non_tensors=non_tensors, meta_info=exps.meta_info - ) - new_exps, new_metrics = super().__call__(new_exps, **kwargs) - - # Get full advantages - full_advantages = torch.zeros( - (batch_size, new_exps.batch["advantages"].shape[1]), device=device - ) - full_returns = torch.zeros((batch_size, new_exps.batch["returns"].shape[1]), device=device) - - # Fill in the non-expert parts with computed values - full_advantages[~is_expert_mask] = new_exps.batch["advantages"] - full_returns[~is_expert_mask] = new_exps.batch["returns"] - - # Write back to original exps - exps.batch["advantages"] = full_advantages - exps.batch["returns"] = full_returns - # TODO: change new_metrics - return exps, new_metrics - - @classmethod - def default_args(cls) -> Dict: - return { - "epsilon": 1e-6, - } -``` - - -## Step 4: Define the Policy Loss Function +## Step 3: Define the Policy Loss Function We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively. @@ -248,7 +184,7 @@ class MIXPolicyLossFn(PolicyLossFn): clip_range: Optional[float] = None, clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, - use_dynamic_bsz: Optional[int] = None, + use_dynamic_bsz: Optional[bool] = None, ppo_mini_batch_size: Optional[int] = None, gradient_accumulation: Optional[int] = None, read_batch_size_usual: Optional[int] = None, @@ -347,17 +283,19 @@ class MIXPolicyLossFn(PolicyLossFn): return ["old_logprob", "action_mask", "advantages", "is_expert_mask"] ``` -## Step 5: Run the Experiment +## Step 4: Run the Experiment -With the above newly-defined classes, we can run the experiments without modifying other process. +With the above newly-defined classes and functions, we can run the experiments without modifying other process. An example showing some important configurations is shown below. 
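Before the YAML example, it helps to see how each read batch is split between the two buffers. The short calculation below assumes `buffer.read_batch_size` equals `batch_size * repeat_times`, as the surrounding text implies; the concrete numbers are illustrative only, not taken from the project configs:

```python
from math import ceil

# Illustrative values (assumed for this example).
batch_size = 32           # buffer.batch_size: tasks per training step
repeat_times = 8          # algorithm.repeat_times: rollouts per task
expert_data_ratio = 0.25  # fraction of each read batch drawn from the expert dataset

# The mix strategy reads batch_size * repeat_times experiences per step and splits them.
tot_read_batch_size = batch_size * repeat_times                          # 256
read_batch_size_expert = ceil(expert_data_ratio * tot_read_batch_size)   # 64  (B')
read_batch_size_usual = tot_read_batch_size - read_batch_size_expert     # 192 (B)
print(read_batch_size_expert, read_batch_size_usual)
```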
```yaml algorithm: algorithm_type: mix - mu: 0.5 repeat_times: 8 - use_token_level_loss: False + policy_loss_fn_args: + mu: 0.5 + clip_range: 0.2 + use_token_level_loss_in_sft: False buffer: expert_data_ratio: 0.25 ``` diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py index 5e5aad8c4f..7bcf682e4b 100644 --- a/trinity/algorithm/advantage_fn/__init__.py +++ b/trinity/algorithm/advantage_fn/__init__.py @@ -1,6 +1,5 @@ from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn -from trinity.algorithm.advantage_fn.mix_advantage import MIXAdvantageFn from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import ( @@ -18,5 +17,4 @@ "REMAXAdvantageFn", "RLOOAdvantageFn", "OPMDAdvantageFn", - "MIXAdvantageFn", ] diff --git a/trinity/algorithm/advantage_fn/mix_advantage.py b/trinity/algorithm/advantage_fn/mix_advantage.py deleted file mode 100644 index 41db245ea2..0000000000 --- a/trinity/algorithm/advantage_fn/mix_advantage.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Mix advantage computation""" - -from typing import Dict, Tuple - -import torch -from verl import DataProto - -from trinity.algorithm.advantage_fn import ADVANTAGE_FN -from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn - - -@ADVANTAGE_FN.register_module("mix") -class MIXAdvantageFn(GRPOAdvantageFn): - """MIX advantage computation""" - - def __init__( - self, - epsilon: float = 1e-6, - ) -> None: - super().__init__(epsilon) - - def __call__( - self, - exps: DataProto, - **kwargs, - ) -> Tuple[DataProto, Dict]: - is_expert_mask = exps.batch["is_expert_mask"] - device = is_expert_mask.device - batch_size = is_expert_mask.shape[0] - - # Process tensors - tensors = {k: tensor[~is_expert_mask] for k, tensor in exps.batch.items()} - - # Process non-tensors - non_tensors = { - k: v[~is_expert_mask.detach().cpu().numpy()] for k, v in exps.non_tensor_batch.items() - } - - # Build new DataProto - new_exps = DataProto.from_dict( - tensors=tensors, non_tensors=non_tensors, meta_info=exps.meta_info - ) - new_exps, new_metrics = super().__call__(new_exps, **kwargs) - - # Get full advantages - full_advantages = torch.zeros( - (batch_size, new_exps.batch["advantages"].shape[1]), device=device - ) - full_returns = torch.zeros((batch_size, new_exps.batch["returns"].shape[1]), device=device) - - # Fill in the non-expert parts with computed values - full_advantages[~is_expert_mask] = new_exps.batch["advantages"] - full_returns[~is_expert_mask] = new_exps.batch["returns"] - - # Write back to original exps - exps.batch["advantages"] = full_advantages - exps.batch["returns"] = full_returns - # TODO: change new_metrics - return exps, new_metrics - - @classmethod - def default_args(cls) -> Dict: - return { - "epsilon": 1e-6, - } diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 93c19879c9..0237402e1b 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -198,7 +198,7 @@ def get_default_config(cls) -> Dict: return { "repeat_times": 8, "policy_loss_fn": "mix", - "advantage_fn": "mix", + "advantage_fn": "grpo", "sample_strategy": "mix", "mu": 0.1, } diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 
6e44d8f931..932ea3738f 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -17,7 +17,7 @@ def __init__( clip_range: Optional[float] = None, clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, - use_dynamic_bsz: Optional[int] = None, + use_dynamic_bsz: Optional[bool] = None, ppo_mini_batch_size: Optional[int] = None, gradient_accumulation: Optional[int] = None, read_batch_size_usual: Optional[int] = None, diff --git a/trinity/common/config.py b/trinity/common/config.py index d31df1dae2..c48f659326 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -175,7 +175,6 @@ class AlgorithmConfig: algorithm_type: str = "ppo" # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 - mu: float = 0.1 # for mix training sample_strategy: Optional[str] = None sample_strategy_args: Optional[dict] = None diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index 73a2369bd6..f4f974102d 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -353,7 +353,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 read_batch_size_expert = math.ceil(config.buffer.expert_data_ratio * tot_batch_size) read_batch_size_usual = tot_batch_size - read_batch_size_expert loss_kwargs = { - "mu": config.algorithm.mu, "use_dynamic_bsz": self.actor_rollout_ref.actor.use_dynamic_bsz, "ppo_mini_batch_size": self.actor_rollout_ref.actor.ppo_mini_batch_size * self.actor_rollout_ref.rollout.n @@ -365,9 +364,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 "read_batch_size_expert": read_batch_size_expert, } config.algorithm.policy_loss_fn_args.update(loss_kwargs) - config.algorithm.policy_loss_fn_args.update( - {"use_token_level_loss_in_sft": config.algorithm.use_token_level_loss} - ) def load_config(config_path: str) -> veRLConfig: From c7d1e50af64ac2b7aeeb9bbd0bd476eb8fa6fd20 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Mon, 16 Jun 2025 19:23:24 +0800 Subject: [PATCH 7/9] fix config passing --- .../source/tutorial/example_mix_algo.md | 20 +++++-------------- trinity/algorithm/algorithm.py | 1 - .../sample_strategy/sample_strategy.py | 2 +- trinity/common/config.py | 1 - trinity/common/verl_config.py | 4 +++- 5 files changed, 9 insertions(+), 19 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 105745ae1c..39abd09e3c 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -53,19 +53,9 @@ class MIXAlgorithm(AlgorithmType): "policy_loss_fn": "mix", "advantage_fn": "grpo", "sample_strategy": "mix", - "mu": 0.1, } ``` -We also define some necessary configuration parameters for later use, including the weighting factor $\mu$ and the batch size of expert experiences $B'$, calculated as the product of `buffer.batch_size`, `buffer.expert_data_ratio` and `algorithm.repeat_times`. - - -```python -class BufferConfig: - """Config for algorithm.""" - ... 
- expert_data_ratio: float = 0.5 -``` ## Step 2: Define the Sampling Strategy @@ -78,7 +68,7 @@ class MixSampleStrategy(SampleStrategy): def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): super().__init__(buffer_config, trainer_type) - self.expert_data_ratio = buffer_config.expert_data_ratio + self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) tot_batch_size = buffer_config.read_batch_size expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) @@ -286,16 +276,16 @@ class MIXPolicyLossFn(PolicyLossFn): ## Step 4: Run the Experiment With the above newly-defined classes and functions, we can run the experiments without modifying other process. -An example showing some important configurations is shown below. +An example showing some important configurations is shown below, including the weighting factor $\mu$ as `algorithm.policy_loss_fn_args['mu']` and the batch size of expert experiences $B'$, calculated as the product of `buffer.batch_size`, `algorithm.sample_strategy_args['expert_data_ratio']` and `algorithm.repeat_times`. ```yaml algorithm: algorithm_type: mix repeat_times: 8 + sample_strategy_args: + expert_data_ratio: 0.25 policy_loss_fn_args: - mu: 0.5 + mu: 0.5 # NEW clip_range: 0.2 use_token_level_loss_in_sft: False -buffer: - expert_data_ratio: 0.25 ``` diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 0237402e1b..6f0a2d19a7 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -200,5 +200,4 @@ def get_default_config(cls) -> Dict: "policy_loss_fn": "mix", "advantage_fn": "grpo", "sample_strategy": "mix", - "mu": 0.1, } diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index f2fc079003..b6e8f550ae 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -128,7 +128,7 @@ class MixSampleStrategy(SampleStrategy): def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): super().__init__(buffer_config, trainer_type) - self.expert_data_ratio = buffer_config.expert_data_ratio + self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) tot_batch_size = buffer_config.read_batch_size expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) diff --git a/trinity/common/config.py b/trinity/common/config.py index c48f659326..7c371f4bcb 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -250,7 +250,6 @@ class BufferConfig: batch_size: int = 1 total_epochs: int = 1 - expert_data_ratio: float = 0.0 # for explorer explorer_input: ExplorerInput = field(default_factory=ExplorerInput) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index f4f974102d..b12975af9a 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -350,7 +350,9 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 if config.algorithm.algorithm_type == "mix": tot_batch_size = config.buffer.read_batch_size - read_batch_size_expert = math.ceil(config.buffer.expert_data_ratio * tot_batch_size) + read_batch_size_expert = math.ceil( + config.algorithm.sample_strategy_args["expert_data_ratio"] * tot_batch_size + ) read_batch_size_usual = tot_batch_size - read_batch_size_expert loss_kwargs = { "use_dynamic_bsz": self.actor_rollout_ref.actor.use_dynamic_bsz, From c23d0338539a10257e5781019016edb8be8d27d1 Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 17 Jun 2025 
10:29:57 +0800 Subject: [PATCH 8/9] fix comments --- .../source/tutorial/example_mix_algo.md | 31 +++++++++++++------ .../policy_loss_fn/mix_policy_loss.py | 29 +++++++++++++---- trinity/common/verl_config.py | 19 ------------ trinity/trainer/verl/dp_actor.py | 8 +++-- 4 files changed, 50 insertions(+), 37 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 39abd09e3c..84a4852053 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -127,7 +127,7 @@ class MixSampleStrategy(SampleStrategy): We also need to add an `is_expert_mask` field when transforming to DataProto to indicate the data type. ```diff -def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: ++ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: attention_mask = experiences.attention_masks cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() @@ -175,16 +175,20 @@ class MIXPolicyLossFn(PolicyLossFn): clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, use_dynamic_bsz: Optional[bool] = None, + repeat_times: Optional[int] = None, ppo_mini_batch_size: Optional[int] = None, - gradient_accumulation: Optional[int] = None, + ppo_micro_batch_size_per_gpu: Optional[int] = None, + ngpus_trainer: Optional[int] = None, read_batch_size_usual: Optional[int] = None, read_batch_size_expert: Optional[int] = None, - use_token_level_loss_in_sft: Optional[bool] = False, + use_token_level_loss_in_sft: bool = True, ) -> None: self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz - self.ppo_mini_batch_size = ppo_mini_batch_size - self.gradient_accumulation = gradient_accumulation + self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore + self.gradient_accumulation = ( + ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore + ) self.read_batch_size_usual = read_batch_size_usual self.read_batch_size_expert = read_batch_size_expert self.grpo_loss_fn = PPOPolicyLossFn( @@ -213,15 +217,15 @@ class MIXPolicyLossFn(PolicyLossFn): n_expert_exp = torch.sum(is_expert_mask).item() if self.use_dynamic_bsz: - per_micro_batch_weight_usual = self.ppo_mini_batch_size / ( + per_micro_batch_weight_usual = self.experience_per_gpu / ( logprob.shape[0] * self.read_batch_size_usual ) - per_micro_batch_weight_expert = self.ppo_mini_batch_size / ( + per_micro_batch_weight_expert = self.experience_per_gpu / ( logprob.shape[0] * self.read_batch_size_expert ) else: - per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual - per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert + per_micro_batch_weight_usual = self.gradient_accumulation / self.read_batch_size_usual # type: ignore + per_micro_batch_weight_expert = self.gradient_accumulation / self.read_batch_size_expert # type: ignore if n_usual_exp > 0: grpo_loss, grpo_metrics = self.grpo_loss_fn( @@ -285,7 +289,14 @@ algorithm: sample_strategy_args: expert_data_ratio: 0.25 policy_loss_fn_args: - mu: 0.5 # NEW + mu: 0.5 clip_range: 0.2 use_token_level_loss_in_sft: False + use_dynamic_bsz: False + repeat_times: 8 + ppo_mini_batch_size: 32 + ppo_micro_batch_size_per_gpu: 4 + ngpus_trainer: 4 + read_batch_size_expert: 64 + read_batch_size_usual: 192 ``` diff --git 
a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index 932ea3738f..84679b0ea8 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -11,6 +11,19 @@ @POLICY_LOSS_FN.register_module("mix") class MIXPolicyLossFn(PolicyLossFn): + """Implements a mixed policy loss combining GRPO and SFT losses. + + This loss function applies different loss components to data based on whether + it comes from an expert or not, as indicated by `is_expert_mask`. It combines: + - GRPO loss (self.grpo_loss_fn) for non-expert data + - SFT loss (self.sft_loss_fn) for expert data + - Weighting parameter `mu` + + The per-sample weights are normalized using either `experience_per_gpu` or + `gradient_accumulation`, depending on whether dynamic batch sizing is enabled, + to ensure consistent weighting across different batches of the same type experiences. + """ + def __init__( self, mu: float = 0.1, @@ -18,16 +31,20 @@ def __init__( clip_range_low: Optional[float] = None, clip_range_high: Optional[float] = None, use_dynamic_bsz: Optional[bool] = None, + repeat_times: Optional[int] = None, ppo_mini_batch_size: Optional[int] = None, - gradient_accumulation: Optional[int] = None, + ppo_micro_batch_size_per_gpu: Optional[int] = None, + ngpus_trainer: Optional[int] = None, read_batch_size_usual: Optional[int] = None, read_batch_size_expert: Optional[int] = None, - use_token_level_loss_in_sft: bool = False, + use_token_level_loss_in_sft: bool = True, ) -> None: self.mu = mu self.use_dynamic_bsz = use_dynamic_bsz - self.ppo_mini_batch_size = ppo_mini_batch_size - self.gradient_accumulation = gradient_accumulation + self.experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer # type: ignore + self.gradient_accumulation = ( + ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu # type: ignore + ) self.read_batch_size_usual = read_batch_size_usual self.read_batch_size_expert = read_batch_size_expert self.grpo_loss_fn = PPOPolicyLossFn( @@ -56,10 +73,10 @@ def __call__( # type: ignore n_expert_exp = torch.sum(is_expert_mask).item() if self.use_dynamic_bsz: - per_micro_batch_weight_usual = self.ppo_mini_batch_size / ( + per_micro_batch_weight_usual = self.experience_per_gpu / ( logprob.shape[0] * self.read_batch_size_usual ) - per_micro_batch_weight_expert = self.ppo_mini_batch_size / ( + per_micro_batch_weight_expert = self.experience_per_gpu / ( logprob.shape[0] * self.read_batch_size_expert ) else: diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index b12975af9a..74fa419db9 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -348,25 +348,6 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 # TODO: check other fields self.enable_preview = config.trainer.enable_preview - if config.algorithm.algorithm_type == "mix": - tot_batch_size = config.buffer.read_batch_size - read_batch_size_expert = math.ceil( - config.algorithm.sample_strategy_args["expert_data_ratio"] * tot_batch_size - ) - read_batch_size_usual = tot_batch_size - read_batch_size_expert - loss_kwargs = { - "use_dynamic_bsz": self.actor_rollout_ref.actor.use_dynamic_bsz, - "ppo_mini_batch_size": self.actor_rollout_ref.actor.ppo_mini_batch_size - * self.actor_rollout_ref.rollout.n - // world_size, - "gradient_accumulation": self.actor_rollout_ref.actor.ppo_mini_batch_size - * self.actor_rollout_ref.rollout.n - // 
self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, - "read_batch_size_usual": read_batch_size_usual, - "read_batch_size_expert": read_batch_size_expert, - } - config.algorithm.policy_loss_fn_args.update(loss_kwargs) - def load_config(config_path: str) -> veRLConfig: schema = OmegaConf.structured(veRLConfig) diff --git a/trinity/trainer/verl/dp_actor.py b/trinity/trainer/verl/dp_actor.py index ae691e111a..9c13c528d7 100644 --- a/trinity/trainer/verl/dp_actor.py +++ b/trinity/trainer/verl/dp_actor.py @@ -294,11 +294,15 @@ def update_policy(self, data: DataProto): # noqa: C901 "ref_log_prob": "ref_logprob", "response_mask": "action_mask", "advantages": "advantages", - "is_expert_mask": "is_expert_mask", } select_keys_trinity2verl = {value: key for key, value in select_keys_verl2trinity.items()} for trinity_key in self.policy_loss_fn.select_keys: - verl_key = select_keys_trinity2verl[trinity_key] + if trinity_key in select_keys_trinity2verl: + verl_key = select_keys_trinity2verl[trinity_key] + else: + verl_key = trinity_key + select_keys_verl2trinity.update({verl_key: trinity_key}) + select_keys_trinity2verl.update({trinity_key: verl_key}) select_keys.append(verl_key) if not isinstance(self.kl_loss_fn, DummyKLFn): select_keys.append("ref_log_prob") From 5d9b004fad043c16f34768f5c6a6a5cd90d7551e Mon Sep 17 00:00:00 2001 From: hiyuchang Date: Tue, 17 Jun 2025 14:32:32 +0800 Subject: [PATCH 9/9] add MIX example and refactor mix_sample_strategy --- .../source/tutorial/example_mix_algo.md | 3 +- examples/grpo_math/train_math.yaml | 4 +- examples/mix_math/README.md | 7 ++ examples/mix_math/mix_math.yaml | 88 +++++++++++++ examples/mix_math/train_mix_math.yaml | 70 +++++++++++ trinity/algorithm/sample_strategy/__init__.py | 1 + .../sample_strategy/mix_sample_strategy.py | 118 ++++++++++++++++++ .../sample_strategy/sample_strategy.py | 78 +----------- trinity/algorithm/sample_strategy/utils.py | 33 ----- 9 files changed, 289 insertions(+), 113 deletions(-) create mode 100644 examples/mix_math/README.md create mode 100644 examples/mix_math/mix_math.yaml create mode 100644 examples/mix_math/train_mix_math.yaml create mode 100644 trinity/algorithm/sample_strategy/mix_sample_strategy.py diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md index 84a4852053..9dadc76b40 100644 --- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md +++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md @@ -281,6 +281,7 @@ class MIXPolicyLossFn(PolicyLossFn): With the above newly-defined classes and functions, we can run the experiments without modifying other process. An example showing some important configurations is shown below, including the weighting factor $\mu$ as `algorithm.policy_loss_fn_args['mu']` and the batch size of expert experiences $B'$, calculated as the product of `buffer.batch_size`, `algorithm.sample_strategy_args['expert_data_ratio']` and `algorithm.repeat_times`. +For the full configuration, please refer to [`mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/mix_math.yaml) and [`train_mix_math.yaml`](https://github.com/modelscope/Trinity-RFT/tree/main/examples/mix_math/train_mix_math.yaml). 
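The loss-normalization constants used by `MIXPolicyLossFn` can be derived directly from the trainer settings shown in the example configuration that follows. The sketch below is illustrative only; the helper name `derive_loss_norms` is an assumption, not part of the codebase, and the numbers mirror the example values `ppo_mini_batch_size=32`, `repeat_times=8`, `ngpus_trainer=4`, `ppo_micro_batch_size_per_gpu=4`:

```python
def derive_loss_norms(ppo_mini_batch_size: int, repeat_times: int,
                      ngpus_trainer: int, ppo_micro_batch_size_per_gpu: int):
    # Experiences handled per GPU in one PPO mini-batch, and the number of
    # micro-batches accumulated per optimizer step.
    experience_per_gpu = ppo_mini_batch_size * repeat_times // ngpus_trainer
    gradient_accumulation = ppo_mini_batch_size * repeat_times // ppo_micro_batch_size_per_gpu
    return experience_per_gpu, gradient_accumulation


experience_per_gpu, gradient_accumulation = derive_loss_norms(
    ppo_mini_batch_size=32, repeat_times=8, ngpus_trainer=4,
    ppo_micro_batch_size_per_gpu=4,
)
print(experience_per_gpu, gradient_accumulation)  # 64 64

# With use_dynamic_bsz: False, each micro-batch term is reweighted by
# gradient_accumulation / read_batch_size_{usual,expert}:
per_micro_batch_weight_usual = gradient_accumulation / 192    # = 1/3
per_micro_batch_weight_expert = gradient_accumulation / 64    # = 1.0
```

This keeps the usual and expert contributions consistently weighted even though the two experience types arrive in different proportions within each micro-batch.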
```yaml algorithm: @@ -289,7 +290,7 @@ algorithm: sample_strategy_args: expert_data_ratio: 0.25 policy_loss_fn_args: - mu: 0.5 + mu: 0.1 clip_range: 0.2 use_token_level_loss_in_sft: False use_dynamic_bsz: False diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml index 7b14a87fad..78bcb862c6 100644 --- a/examples/grpo_math/train_math.yaml +++ b/examples/grpo_math/train_math.yaml @@ -10,7 +10,7 @@ actor_rollout_ref: ppo_mini_batch_size: 128 ppo_micro_batch_size_per_gpu: 4 use_dynamic_bsz: True # False - ppo_max_token_len_per_gpu: 25600 # n * ${data.max_prompt_length} + ${data.max_response_length} + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 clip_ratio: 0.2 entropy_coeff: 0.001 @@ -21,7 +21,7 @@ actor_rollout_ref: shuffle: False ulysses_sequence_parallel_size: 1 # sp size optim: - lr: 1e-6 + lr: 5e-7 lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime # min_lr_ratio: null # only useful for warmup with cosine warmup_style: constant # select from constant/cosine diff --git a/examples/mix_math/README.md b/examples/mix_math/README.md new file mode 100644 index 0000000000..8e84f233bc --- /dev/null +++ b/examples/mix_math/README.md @@ -0,0 +1,7 @@ +# Example: MIX on MATH dataset + +This example shows the usage of a new algorithm MIX on the MATH dataset. + +For more detailed information, please refer to the [documentation](../../docs/sphinx_doc/source/tutorial/example_mix_algo.md). + +The config files are located in [`mix_math.yaml`](mix.yaml) and [`train_mix_math.yaml`](train_mix_math.yaml). diff --git a/examples/mix_math/mix_math.yaml b/examples/mix_math/mix_math.yaml new file mode 100644 index 0000000000..339d8df394 --- /dev/null +++ b/examples/mix_math/mix_math.yaml @@ -0,0 +1,88 @@ +project: "mix_math" +name: "expert0.25_mu0.1" +checkpoint_root_dir: /PATH/TO/CHECKPOINT/ +algorithm: + algorithm_type: mix + repeat_times: 8 + sample_strategy_args: + expert_data_ratio: 0.25 + policy_loss_fn_args: + mu: 0.1 + clip_range: 0.2 + use_token_level_loss_in_sft: False + use_dynamic_bsz: False + repeat_times: 8 + ppo_mini_batch_size: 32 + ppo_micro_batch_size_per_gpu: 4 + ngpus_trainer: 4 + read_batch_size_expert: 64 + read_batch_size_usual: 192 +model: + model_path: /PATH/TO/MODEL/ + max_prompt_tokens: 1024 + max_response_tokens: 10240 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + total_epochs: 1 + batch_size: 40 + explore_batch_size: 36 + max_retry_times: 3 + max_retry_interval: 1 + explorer_input: + taskset: + name: math_train + storage_type: file + path: /PATH/TO/DATASET/ + split: 'train' + format: + prompt_key: 'problem' + response_key: 'answer' + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: + - name: math_eval + storage_type: file + path: /PATH/TO/DATASET/ + split: 'test' + format: + prompt_key: 'problem' + response_key: 'answer' + default_workflow_type: 'math_workflow' + trainer_input: + experience_buffer: + name: math_buffer + storage_type: queue + path: /PATH/TO/BUFFER/ + sft_warmup_dataset: + name: math_sft + storage_type: file + algorithm_type: sft + path: /PATH/TO/EXPERT_DATA/ + split: 'train' + format: + prompt_type: messages + messages_key: 'messages' +explorer: + eval_interval: 10 + runner_num: 16 + rollout_model: + engine_type: vllm_async + engine_num: 4 + tensor_parallel_size: 1 + enable_prefix_caching: false + enforce_eager: true + dtype: bfloat16 + seed: 42 +synchronizer: + sync_method: 'nccl' + sync_interval: 1 + sync_timeout: 1200 
+trainer: + trainer_type: 'verl' + trainer_config_path: 'examples/mix_math/train_math.yaml' + save_interval: 50 +monitor: + monitor_type: wandb diff --git a/examples/mix_math/train_mix_math.yaml b/examples/mix_math/train_mix_math.yaml new file mode 100644 index 0000000000..7b14a87fad --- /dev/null +++ b/examples/mix_math/train_mix_math.yaml @@ -0,0 +1,70 @@ +actor_rollout_ref: + hybrid_engine: True + model: + external_lib: null + override_config: { } + enable_gradient_checkpointing: True + use_remove_padding: True # False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 128 + ppo_micro_batch_size_per_gpu: 4 + use_dynamic_bsz: True # False + ppo_max_token_len_per_gpu: 25600 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: True # True for GRPO + kl_loss_coef: 0.0001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime + # min_lr_ratio: null # only useful for warmup with cosine + warmup_style: constant # select from constant/cosine + total_training_steps: -1 # must be override by program + fsdp_config: + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + param_offload: False + optimizer_offload: False + fsdp_size: -1 + # --- below: opmd --- + tau: 0.000 # strength of regularization w.r.t. old / ref policy + opmd_baseline: mean # mean / logavgexp, applicable to opmd + use_uid: False # True / False, applicable to pairwise_opmd + ref: + fsdp_config: + param_offload: False + wrap_policy: + # transformer_layer_cls_to_wrap: None + min_num_params: 0 + log_prob_micro_batch_size_per_gpu: 16 + log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} + log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + +custom_reward_function: + path: null + name: compute_score + +algorithm: + gamma: 1.0 + lam: 1.0 + kl_penalty: kl # how to estimate kl divergence + kl_ctrl: + type: fixed + kl_coef: 0.0001 + +trainer: + balance_batch: True + # auto: find the last ckpt to resume. 
If can't find, start from scratch + resume_mode: auto # or auto or resume_path if + default_hdfs_dir: null + remove_previous_ckpt_in_save: False + del_local_ckpt_after_load: False + val_before_train: False diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py index 385ff58147..cd4b9e0d66 100644 --- a/trinity/algorithm/sample_strategy/__init__.py +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -1,3 +1,4 @@ +from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy from trinity.algorithm.sample_strategy.sample_strategy import ( SAMPLE_STRATEGY, DefaultSampleStrategy, diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py new file mode 100644 index 0000000000..acdd340b24 --- /dev/null +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -0,0 +1,118 @@ +import copy +from math import ceil +from typing import Any, Dict, List, Tuple + +import numpy as np +import torch +from verl.trainer.ppo.ray_trainer import DataProto + +from trinity.algorithm.sample_strategy.sample_strategy import ( + SAMPLE_STRATEGY, + SampleStrategy, +) +from trinity.algorithm.sample_strategy.utils import representative_sample +from trinity.buffer import get_buffer_reader +from trinity.common.config import BufferConfig +from trinity.common.experience import Experiences +from trinity.utils.timer import Timer + + +@SAMPLE_STRATEGY.register_module("mix") +class MixSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) + tot_batch_size = buffer_config.read_batch_size + expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) + + # experience buffer + usual_buffer_config = copy.deepcopy(buffer_config) + usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size + self.usual_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore + ) + + if buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError( + "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" + ) + + # expert experience buffer + expert_buffer_config = copy.deepcopy(buffer_config) + expert_buffer_config.read_batch_size = expert_batch_size + self.expert_exp_buffer = get_buffer_reader( + buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config + ) + + def sample(self, step: int) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + usual_exp_list = self.usual_exp_buffer.read() + for exp in usual_exp_list: + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = False + + expert_exp_list = self.expert_exp_buffer.read() + for exp in expert_exp_list: + exp.reward = 0.0 + exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) + if exp.info is None: + exp.info = {} + exp.info["is_expert"] = True + + exp_list = usual_exp_list + expert_exp_list + repr_samples = representative_sample(exp_list) + + is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) + + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto_mix(exps, 
is_expert_mask) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + @classmethod + def get_default_config(cls) -> Dict: + return { + "expert_data_ratio": 0.5, + } + + +def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + "is_expert_mask": is_expert_mask, + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index b6e8f550ae..8686a0d497 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -1,15 +1,7 @@ -import copy from abc import ABC, abstractmethod -from math import ceil from typing import Any, Dict, List, Tuple -import torch - -from trinity.algorithm.sample_strategy.utils import ( - representative_sample, - to_data_proto, - to_data_proto_mix, -) +from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig from trinity.common.experience import Experiences @@ -120,71 +112,3 @@ def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: return data, metrics, repr_samples else: raise NotImplementedError(f"backend {self.trainer_type} is not supported") - - -@SAMPLE_STRATEGY.register_module("mix") -class MixSampleStrategy(SampleStrategy): - """The default sample strategy.""" - - def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): - super().__init__(buffer_config, trainer_type) - self.expert_data_ratio = kwargs.get("expert_data_ratio", 0.5) - tot_batch_size = buffer_config.read_batch_size - expert_batch_size = ceil(self.expert_data_ratio * tot_batch_size) - - # experience buffer - usual_buffer_config = copy.deepcopy(buffer_config) - usual_buffer_config.read_batch_size = tot_batch_size - expert_batch_size - self.usual_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.experience_buffer, usual_buffer_config # type: ignore - ) - - if buffer_config.trainer_input.sft_warmup_dataset is None: - raise ValueError( - "`buffer_config.trainer_input.sft_warmup_dataset` is required in MIX algorithm" - ) - - # expert experience buffer - expert_buffer_config = copy.deepcopy(buffer_config) - 
expert_buffer_config.read_batch_size = expert_batch_size - self.expert_exp_buffer = get_buffer_reader( - buffer_config.trainer_input.sft_warmup_dataset, expert_buffer_config - ) - - def sample(self, step: int) -> Tuple[Any, Dict, List]: - metrics = {} - with Timer(metrics, "read_time"): - usual_exp_list = self.usual_exp_buffer.read() - for exp in usual_exp_list: - if exp.info is None: - exp.info = {} - exp.info["is_expert"] = False - - expert_exp_list = self.expert_exp_buffer.read() - for exp in expert_exp_list: - exp.reward = 0.0 - exp.logprobs = torch.zeros_like(exp.tokens, dtype=torch.float32) - if exp.info is None: - exp.info = {} - exp.info["is_expert"] = True - - exp_list = usual_exp_list + expert_exp_list - repr_samples = representative_sample(exp_list) - - is_expert_mask = torch.tensor([exp.info["is_expert"] for exp in exp_list], dtype=torch.bool) - - with Timer(metrics, "gather_time"): - exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - - if self.trainer_type == "verl": - with Timer(metrics, "convert_time"): - data = to_data_proto_mix(exps, is_expert_mask) - return data, metrics, repr_samples - else: - raise NotImplementedError(f"backend {self.trainer_type} is not supported") - - @classmethod - def get_default_config(cls) -> Dict: - return { - "expert_data_ratio": 0.5, - } diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py index be2c06bd34..8c443a20b1 100644 --- a/trinity/algorithm/sample_strategy/utils.py +++ b/trinity/algorithm/sample_strategy/utils.py @@ -40,39 +40,6 @@ def to_data_proto(experiences: Experiences) -> DataProto: return DataProto.from_single_dict(batch_dict) -def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) -> DataProto: - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - batch_dict = { - "uid": np.array(experiences.run_ids), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - "is_expert_mask": is_expert_mask, - } - if experiences.rewards is not None: - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) - eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] - batch_dict.update( - { - "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore - } - ) - return DataProto.from_single_dict(batch_dict) - - def representative_sample(experiences: List[Experience]) -> List[dict]: if experiences[0].reward is None: sample = random.choice(experiences)