diff --git a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md index 1f21f2abfe..6dc03118e2 100644 --- a/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md +++ b/docs/source/Instruction/GRPO/AdvancedResearch/GSPO.md @@ -54,7 +54,7 @@ importance_weights = torch.exp(log_importance_weights) - `importance_sampling_level sequence` (GSPO) - `importance_sampling_level sequence_token` (GSPO-token) -其中 sequence_token 要求 ms-swift > 3.7 (源码安装) +其中 sequence_token 要求 ms-swift >= 3.8 论文其他超参 ```bash diff --git "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index ba6256e2d0..bcc7325f87 100644 --- "a/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Instruction/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -505,6 +505,15 @@ reward模型参数将在PPO、GRPO中使用。 #### GRPO参数 - beta: KL正则系数,默认为0.04,设置为0时不加载ref model。 +- epsilon: clip 系数,默认为0.2。 +- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 +- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 +- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 +- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 +- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 +- top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md) +- log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics) +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 、 `sequence` 和 `sequence_token`,默认为`token`。具体参考[GSPO文档](./GRPO/AdvancedResearch/GSPO.md) - per_device_train_batch_size: 每个设备训练批量大小,在GRPO中,指 completion 的批次大小。 - per_device_eval_batch_size: 每个设备评估批量大小,在GRPO中,指 completion 的批次大小。 - generation_batch_size: 采样completion批量大小,需要是 num_processes * per_device_train_batch_size 的倍数,默认等于 per_device_batch_size * gradient_accumulation_steps * num_processes @@ -542,22 +551,15 @@ reward模型参数将在PPO、GRPO中使用。 - completion_length_limit_scope: 在多轮对话中,`max_completion_length` 的限制范围。 `total`限制所有对话轮次的总输出长度不超过`max_completion_length`, `per_round`限制每一轮的输出长度。 - num_iterations: 每个批次代更新次数,默认为1。 -- epsilon: clip 系数,默认为0.2。 -- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 -- delta: [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291)中双侧 GRPO 上界裁剪值。若设置,建议大于 1 + epsilon。默认为None。 - sync_ref_model: 是否定期同步ref_model,默认为False。 - ref_model_mixup_alpha: 控制在更新过程中model和先前ref_model之间的混合。更新公式为 $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$。默认为0.6。 - ref_model_sync_steps:同步频率,默认为512。 - move_model_batches: 在模型向vLLM等快速推理框架移动参数时,将layers分为多少个batch. 
默认为None, 代表整个模型不进行拆分,否则拆分为move_model_batches+1(非layer参数)+1(多模态部分参数)个。注意:该参数仅对LoRA(PEFT)训练有意义。 - multi_turn_scheduler: 多轮GRPO参数, 传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现。 - max_turns: 多轮GRPO的轮数上限。默认为None,不做限制。 -- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。 -- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。 -- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 -- top_entropy_quantile: 仅对熵值处于前指定分位的 token 参与损失计算,默认为1.0,即不过滤低熵 token,具体参考[文档](./GRPO/AdvancedResearch/entropy_mask.md) -- log_entropy: 记录训练中的熵值变化动态,默认为False,具体参考[文档](./GRPO/GetStarted/GRPO.md#logged-metrics) -- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 和 `sequence`,`token` 模式下保留原始的每个 token 的对数概率比,`sequence` 模式下则会对序列中所有有效 token 的对数概率比进行平均。[GSPO论文](https://www.arxiv.org/abs/2507.18071)中使用sequence级别计算来稳定训练,默认为`token`。 +##### 奖励函数参数 +内置的奖励函数参考[文档](./GRPO/DeveloperGuide/奖励函数.md) cosine 奖励参数 - cosine_min_len_value_wrong:cosine 奖励函数参数,生成错误答案时,最小长度对应的奖励值。默认值为-0.5。 - cosine_max_len_value_wrong:生成错误答案时,最大长度对应的奖励值。默认值为0.0。 diff --git "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" index 9892d74035..bae33fa5fe 100644 --- "a/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" +++ "b/docs/source/Megatron-SWIFT/\345\221\275\344\273\244\350\241\214\345\217\202\346\225\260.md" @@ -236,7 +236,7 @@ lora训练: **DPO参数**: -- ref_load: ref_model的加载路径。采用DPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。 +- ref_load: ref_model的加载路径。采用DPO/GRPO/KTO算法且使用全参数训练时需要传入。默认为None,即设置为`load`。 - ref_adapter_load: 加载ref_adapter的权重路径,默认为None。若你要使用SFT产生的LoRA权重进行DPO,请使用"ms-swift>=3.8",并在训练时设置`--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true`。若是此场景的断点续训,则设置`--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`。 - beta: 含义与[TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig)相同。控制与参考模型偏差程度的参数。beta值越高,表示与参考模型的偏差越小。对于 IPO 损失函数 (loss_type="ipo"),beta是[论文](https://huggingface.co/papers/2310.12036)中所指的正则化参数。默认为0.1。 - 🔥rpo_alpha: 来自[RPO 论文](https://huggingface.co/papers/2404.19733)中的参数,用于控制损失函数中NLL项的权重(即SFT损失),`loss = dpo_loss + rpo_alpha * sft_loss`,论文中推荐设置为`1.`。默认为`None`,即默认不引入sft_loss。 @@ -254,6 +254,35 @@ lora训练: - desirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 desirable 损失按该系数进行加权,默认为`1.`。 - undesirable_weight: 抵消 desirable 和 undesirable 数量不均衡的影响,对 undesirable 损失按该系数进行加权,默认为`1.`。 +**GRPO参数** +- ref_load: 含义同DPO。 +- ref_adapter_load: 含义同DPO。 +- beta: KL正则系数,默认为0.04,设置为0时不加载ref model。 +- epsilon: clip 系数,默认为0.2。 +- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围。 +- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。 +- importance_sampling_level: 控制重要性采样比计算,可选项为 `token` 、 `sequence` 和 `sequence_token`,默认为`token`。具体参考[GSPO文档](../Instruction/GRPO/AdvancedResearch/GSPO.md) +- batch size 相关参数(注意以下均为 completion-level) + - micro_batch_size: 每个device的批次大小,默认为1。 + - global_batch_size: 总批次大小,等价于`micro_batch_size*数据并行大小*梯度累加步数`。默认为16。对应每次更新权重的训练数据大小(mini_batch_size) + - generation_batch_size: 采样批量大小,需要是global_batch_size的倍数,默认等于global_batch_size + - steps_per_generation:每轮生成的优化步数,即采样批量大小相对global_batch_size的倍数,默认为1。 + - num_generations:每个prompt采样的数量,论文中的G值。采样批量大小需被num_generations 整除。默认为 8。 +- reward_funcs: GRPO算法奖励函数,可选项为`accuracy`、`format`、`cosine`、`repetition`和`soft_overlong`,见swift/plugin/orm.py。你也可以在plugin中自定义自己的奖励函数。默认为`[]`。 +- reward_weights: 每个奖励函数的权重。必须与奖励函数和奖励模型的总数量匹配。如果为 None,则所有奖励的权重都相等,为`1.0`。 +- 
loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo'], 默认为'grpo', 具体查看该[pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)。 +- vllm_mode 参数 + - vllm_gpu_memory_utilization: vllm透传参数,默认为0.9。 + - vllm_max_model_len: vllm透传参数,默认为None。 + - vllm_enforce_eager: vllm透传参数,默认为False。 + - vllm_limit_mm_per_prompt: vllm透传参数,默认为None。 + - vllm_enable_prefix_caching: vllm透传参数,默认为True。 + - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放 + - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。 + - offload_model: 是否在vLLM推理时 offload 模型,默认为False。 + +内置奖励函数参数参考[文档](../Instruction/命令行参数.md#奖励函数参数) + ## 训练参数 Megatron训练参数继承自Megatron参数和基本参数(**与ms-swift共用dataset、template等参数,也支持ms-swift中的特定模型参数**)。基本参数的内容可以参考[这里](../Instruction/命令行参数.md#基本参数)。此外还包括以下参数: diff --git a/docs/source_en/Instruction/Command-line-parameters.md b/docs/source_en/Instruction/Command-line-parameters.md index 5776c9ccf1..fda80fbbd1 100644 --- a/docs/source_en/Instruction/Command-line-parameters.md +++ b/docs/source_en/Instruction/Command-line-parameters.md @@ -515,6 +515,15 @@ The meanings of the following parameters can be referenced [here](https://huggin #### GRPO Arguments - beta: KL regularization coefficient; default 0.04. Setting it to 0 disables the reference model. +- epsilon: epsilon value for clipping. Default is 0.2. +- epsilon_high: Upper clip coefficient, default is None. When set, it forms a clipping range of [epsilon, epsilon_high] together with epsilon. +- delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). +- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False. +- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False. +- max_resample_times: Under the dynamic_sample setting, limit the number of resampling attempts to a maximum of 3. Default is 3 times. +- top_entropy_quantile: Only tokens whose entropy ranks within the specified top quantile are included in the loss calculation. The default is 1.0, which means low-entropy tokens are not filtered. For details, refer to the [documentation](./GRPO/AdvancedResearch/entropy_mask.md). +- log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics). +- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. - per_device_train_batch_size: The training batch size per device. In GRPO, this refers to the batch size of completions during training. - per_device_eval_batch_size: The evaluation batch size per device. In GRPO, this refers to the batch size of completions during evaluation. - generation_batch_size: Batch size to use for generation. 
It defaults to the effective training batch size: per_device_train_batch_size * num_processes * gradient_accumulation_steps` @@ -556,23 +565,16 @@ The meanings of the following parameters can be referenced [here](https://huggin - top_p: Default is 0.9. - repetition_penalty: Repetition penalty term. Default is 1. - num_iterations: number of iterations per batch. Default is 1. -- epsilon: epsilon value for clipping. Default is 0.2. -- epsilon_high: Upper clip coefficient, default is None. When set, it forms a clipping range of [epsilon, epsilon_high] together with epsilon. -- delta: Delta value for the upper clipping bound in two-sided GRPO. Recommended to be > 1 + epsilon. This method was introduced in the [INTELLECT-2 tech report](https://huggingface.co/papers/2505.07291). + - sync_ref_model: Whether to synchronize the reference model. Default is False。 - ref_model_mixup_alpha: The Parameter controls the mix between the current policy and the previous reference policy during updates. The reference policy is updated according to the equation: $π_{ref} = α * π_θ + (1 - α) * π_{ref_{prev}}$. Default is 0.6. - ref_model_sync_steps:The parameter determines how frequently the current policy is synchronized with the reference policy. Default is 512. - move_model_batches: When moving model parameters to fast inference frameworks such as vLLM/LMDeploy, determines how many batches to divide the layers into. The default is `None`, which means the entire model is not split. Otherwise, the model is split into `move_model_batches + 1` (non-layer parameters) + `1` (multi-modal component parameters) batches. This parameter is only meaningful for LoRA (PEFT). - multi_turn_scheduler: Multi-turn GRPO parameter; pass the corresponding plugin name, and make sure to implement it in plugin/multi_turn.py. - max_turns: Maximum number of rounds for multi-turn GRPO. The default is None, which means there is no limit. -- dynamic_sample: Exclude data within the group where the reward standard deviation is 0, and additionally sample new data. Default is False. -- max_resample_times: Under the dynamic_sample setting, limit the number of resampling attempts to a maximum of 3. Default is 3 times. -- overlong_filter: Skip overlong truncated samples, which will not be included in loss calculation. Default is False. -The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions). -- top_entropy_quantile: Only tokens whose entropy ranks within the specified top quantile are included in the loss calculation. The default is 1.0, which means low-entropy tokens are not filtered. For details, refer to the [documentation](./GRPO/AdvancedResearch/entropy_mask.md). -- log_entropy: Logs the entropy values during training. The default is False. For more information, refer to the [documentation](./GRPO/GetStarted/GRPO.md#logged-metrics). -- importance_sampling_level: Controls how the importance sampling ratio is computed. Options are `token` and `sequence`. In `token` mode, the raw per-token log-probability ratios are used. In `sequence` mode, the log-probability ratios of all valid tokens in the sequence are averaged to produce a single ratio per sequence. The [GSPO paper](https://www.arxiv.org/abs/2507.18071) uses sequence-level importance sampling to stabilize training. The default is `token`. +##### Reward function parameters +Refer to the [documentation](./GRPO/DeveloperGuide/reward_function.md) for built-in reward functions. 
cosine reward function arguments - cosine_min_len_value_wrong (default: -0.5): Reward value corresponding to the minimum length when the answer is incorrect. diff --git a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md index 4deaf2fc98..defacd0735 100644 --- a/docs/source_en/Megatron-SWIFT/Command-line-parameters.md +++ b/docs/source_en/Megatron-SWIFT/Command-line-parameters.md @@ -250,7 +250,7 @@ LoRA Training: - use_rslora: Default is `False`. Whether to use `RS-LoRA`. **DPO Parameters** -- ref_load: The loading path for the reference model. This must be provided when using DPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`. +- ref_load: The loading path for the reference model. This must be provided when using DPO/GRPO/KTO algorithms with full-parameter training. Defaults to `None`, which means it will be set to the same value as `load`. - ref_adapter_load: The path to load the ref_adapter weights, default is `None`. If you want to use LoRA weights generated from SFT for DPO, please use "ms-swift>=3.8" and set `--adapter_load sft_ckpt --ref_adapter_load sft_ckpt --finetune true` during training. For resuming training from a checkpoint in this scenario, set `--adapter_load rlhf_ckpt --ref_adapter_load sft_ckpt --finetune false`. - beta: Has the same meaning as in [TRL](https://huggingface.co/docs/trl/main/en/dpo_trainer#trl.DPOConfig). It controls the degree of deviation from the reference model. A higher beta value indicates less deviation from the reference model. For the IPO loss function (`loss_type="ipo"`), beta is the regularization parameter as mentioned in the [paper](https://huggingface.co/papers/2310.12036). Default is 0.1. - 🔥rpo_alpha: A parameter from the [RPO paper](https://huggingface.co/papers/2404.19733) that controls the weight of the NLL term (i.e., the SFT loss) in the loss function, where `loss = dpo_loss + rpo_alpha * sft_loss`. The paper recommends setting it to `1.`. The default value is `None`, meaning the SFT loss is not included by default. @@ -268,6 +268,36 @@ LoRA Training: - desirable_weight: factor to weight desirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. - undesirable_weight: factor to weight undesirable losses to counter imbalance between desirable and undesirable pairs. Default is `1.`. +**GRPO Parameters** +- ref_load: Same meaning as in DPO. +- ref_adapter_load: Same meaning as in DPO. +- beta: KL regularization coefficient, default is 0.04. When set to 0, the reference model is not loaded. +- epsilon: Clip coefficient, default is 0.2. +- epsilon_high: Upper clip coefficient, default is None. When set, forms a clipping range [epsilon, epsilon_high] together with epsilon. +- overlong_filter: Skips samples that are truncated due to excessive length and excludes them from loss computation. Default is False. +- importance_sampling_level: Controls the level at which importance sampling ratios are computed. Options are `token`, `sequence`, and `sequence_token`. Default is `token`. See [GSPO Documentation](../Instruction/GRPO/AdvancedResearch/GSPO.md) for details. +- Batch Size Related Parameters (Note: all are completion-level) + - micro_batch_size: Batch size per device, default is 1. + - global_batch_size: Total batch size, equivalent to `micro_batch_size * data parallelism size * gradient accumulation steps`. Default is 16. 
Corresponds to the mini_batch_size (number of training samples per weight update). + - generation_batch_size: Sampling batch size, must be a multiple of global_batch_size. Default equals global_batch_size. + - steps_per_generation: Number of optimization steps per generation round, i.e., the ratio of generation_batch_size to global_batch_size. Default is 1. + - num_generations: Number of samples generated per prompt (the "G" value in the paper). generation_batch_size must be divisible by num_generations. Default is 8. +- reward_funcs: Reward functions used in GRPO algorithm. Options include `accuracy`, `format`, `cosine`, `repetition`, and `soft_overlong`, defined in swift/plugin/orm.py. You can also customize your own reward functions in the plugin. Default is `[]`. +- reward_weights: Weights assigned to each reward function. Must match the total number of reward functions and reward models. If None, all rewards are equally weighted with `1.0`. +- loss_type: Type of loss normalization. Options are ['grpo', 'bnpo', 'dr_grpo']. Default is 'grpo'. See this [PR](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348) for details. + +- vLLM Parameters + - vllm_gpu_memory_utilization: Pass-through parameter to vLLM, default is 0.9. + - vllm_max_model_len: Pass-through parameter to vLLM, default is None. + - vllm_enforce_eager: Pass-through parameter to vLLM, default is False. + - vllm_limit_mm_per_prompt: Pass-through parameter to vLLM, default is None. + - vllm_enable_prefix_caching: Pass-through parameter to vLLM, default is True. + - sleep_level: Release vLLM GPU memory during training. Options are [0, 1], default is 0 (no release). + - offload_optimizer: Whether to offload optimizer states during vLLM inference. Default is False. + - offload_model: Whether to offload model weights during vLLM inference. Default is False. + +For built-in reward function parameters, refer to the [documentation](../Instruction/GRPO/DeveloperGuide/reward_function.md). + ## Training Parameters Megatron training parameters are inherited from Megatron parameters and basic parameters (**sharing dataset, template, etc. with ms-swift, and also supporting model-specific parameters from ms-swift**). For details on basic parameters, please refer to [here](../Instruction/Command-line-parameters.md#base-arguments). 
Additionally, the following parameters are included: diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py index 78291888d2..f09c208afe 100644 --- a/swift/llm/template/base.py +++ b/swift/llm/template/base.py @@ -1266,6 +1266,8 @@ def _handle_megatron_cp(self, encoded: Dict[str, Any]) -> None: cp_size = self.sequence_parallel_size if not self.use_megatron or cp_size == 1: return + if self.mode == 'vllm': # skip for megatron grpo rollout + return input_ids = encoded['input_ids'] padding_len = math.ceil(len(input_ids) / (cp_size * 2)) * (cp_size * 2) - len(input_ids) input_ids += [self.tokenizer.pad_token_id] * padding_len diff --git a/swift/megatron/argument/megatron_args.py b/swift/megatron/argument/megatron_args.py index a9847db55e..edc2807275 100644 --- a/swift/megatron/argument/megatron_args.py +++ b/swift/megatron/argument/megatron_args.py @@ -10,14 +10,15 @@ from transformers.utils.versions import require_version from swift.llm.argument.base_args import to_abspath -from swift.utils import get_dist_setting, get_logger, json_parse_to_dict +from swift.utils import get_current_device, get_dist_setting, get_logger, is_master, json_parse_to_dict logger = get_logger() @dataclass class RLHFMegatronArgumentsMixin: - rlhf_type: Literal['dpo', 'kto'] = None + rlhf_type: Literal['dpo', 'kto', 'grpo'] = None + perform_initialization: bool = True ref_load: Optional[str] = None ref_adapter_load: Optional[str] = None @@ -33,6 +34,100 @@ class RLHFMegatronArgumentsMixin: undesirable_weight: float = 1. calculate_KL: Optional[bool] = None + # =========================== GRPO =========================== + generation_batch_size: Optional[int] = None + steps_per_generation: Optional[int] = None + num_generations: int = 8 + max_completion_length: int = 512 + # GSPO https://www.arxiv.org/abs/2507.18071 + importance_sampling_level: Literal['token', 'sequence', 'sequence_token'] = 'token' + + # ─────────────────────────── Sampling ─────────────────────────── + epsilon: float = 0.2 + epsilon_high: Optional[float] = None + delta: Optional[float] = None + top_k: int = 50 + top_p: float = 0.9 + repetition_penalty: float = 1. + # ─────────────────────────── VLLM ─────────────────────────── + use_vllm: bool = False + vllm_mode: Literal['server', 'colocate'] = 'colocate' + # ────────────── Internal VLLM (colocate) ────────────── + vllm_enable_prefix_caching: bool = True + vllm_gpu_memory_utilization: float = 0.9 + vllm_tensor_parallel_size: int = 1 + vllm_max_model_len: Optional[int] = None + vllm_enforce_eager: bool = False + vllm_limit_mm_per_prompt: Optional[Union[dict, str]] = None # '{"image": 5, "video": 2}' + vllm_disable_cascade_attn: bool = False + sleep_level: Literal[0, 1, 2] = 0 + + # ────────────── External VLLM (server, not supported yet) ────────────── + vllm_server_base_url: Optional[List[str]] = None + vllm_server_host: Optional[List[str]] = None + vllm_server_port: List[int] = field(default_factory=lambda: [8000]) + vllm_server_timeout: float = 240.0 + vllm_client: Optional[object] = field(init=False, default=None) + + # ─────────────────────────── Reward ─────────────────────────── + reward_funcs: List[str] = field(default_factory=list) + reward_weights: List[float] = None + # see details in swift/plugin/orm.py + # cosine reward, https://arxiv.org/abs/2502.03373 + cosine_min_len_value_wrong: float = -0.5 # r^w_0 in paper, Reward for wrong answers with zero completion length. 
+ cosine_max_len_value_wrong: float = 0.0 # r^w_L in paper, Reward for wrong answers with max completion length. + cosine_min_len_value_correct: float = 1.0 # r^c_0 in paper, Reward for correct answers with zero completion length. + cosine_max_len_value_correct: float = 0.5 # r^c_L in paper, Reward for correct answers with max completion length. + cosine_max_len: Optional[int] = None # Lmax in paper, default equal to max_completion_length + # repetition penalty, https://arxiv.org/abs/2502.03373 + repetition_n_grams: int = 3 + repetition_max_penalty: float = -1.0 + # soft_overlong, https://arxiv.org/abs/2503.14476 + soft_max_length: Optional[int] = None + soft_cache_length: Optional[int] = None + + # ─────────────────────────── Not Supported Yet ─────────────────────────── + # reward model + reward_model: Optional[List[str]] = None + reward_model_plugin: Optional[List[str]] = None + # sync ref model + sync_ref_model: bool = False + ref_model_sync_steps: int = 512 + ref_model_mixup_alpha: float = 0.6 + + async_generate: bool = False + + move_model_batches: Optional[int] = None + offload_optimizer: bool = False + offload_model: bool = False + gc_collect_after_offload: bool = False # deprecated + + # multi turn + multi_turn_func: Optional[str] = None # deprecated + multi_turn_scheduler: Optional[str] = None + max_turns: Optional[int] = None + completion_length_limit_scope: Literal['total', 'per_round'] = 'per_round' + vllm_server_pass_dataset: bool = False + + # DAPO, https://arxiv.org/abs/2503.14476 + dynamic_sample: bool = False + max_resample_times: int = 3 + overlong_filter: bool = False + + # Dr. GRPO, https://arxiv.org/abs/2503.20783 + scale_rewards: bool = True + + # entropy + log_entropy: bool = False + # Beyond the 80/20 Rule, https://arxiv.org/abs/2506.01939 + top_entropy_quantile: float = 1.0 + + wandb_log_unique_prompts: Optional[bool] = None + num_iterations: int = 1 + + # dataset + dataset_shuffle: Optional[bool] = True + def _init_kto(self): if self.calculate_KL is None: # Not all losses require a KL calculation @@ -43,11 +138,93 @@ def _init_kto(self): def __post_init__(self): if self.rlhf_type is None: return - default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid'} + default_loss_type = {'kto': 'kto', 'dpo': 'sigmoid', 'grpo': 'grpo'} if self.loss_type is None: self.loss_type = default_loss_type[self.rlhf_type] if self.rlhf_type == 'kto': self._init_kto() + if self.rlhf_type == 'grpo': + self._init_grpo() + + def _init_grpo(self): + + def _init_external_vllm(): + if self.rlhf_type != 'grpo' or (self.vllm_server_host is None and self.vllm_server_base_url is None): + return + from swift.trainers.rlhf_trainer.vllm_client import VLLMClient + if is_master(): + logger.info('Start connecting to vLLM server') + self.vllm_client = VLLMClient( + base_urls=self.vllm_server_base_url, + hosts=self.vllm_server_host, + server_ports=self.vllm_server_port, + connection_timeout=self.vllm_server_timeout) + self.vllm_client.init_communicator(device=get_current_device()) + logger.info('Connected to vLLM server') + + def _check_not_supported(): + pass + + def _check_batch_params(): + if self.generation_batch_size is None and self.steps_per_generation is None: + self.steps_per_generation = 1 + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + elif self.generation_batch_size is not None and self.steps_per_generation is None: + # Just ensure the value is divisible by the global batch size + if self.generation_batch_size % self.global_batch_size != 0: + raise 
ValueError(f'generation_batch_size ({self.generation_batch_size}) ' + f'must be divisible by the global batch size ({self.global_batch_size}).') + self.steps_per_generation = self.generation_batch_size // self.global_batch_size + elif self.generation_batch_size is None and self.steps_per_generation is not None: + self.generation_batch_size = self.global_batch_size * self.steps_per_generation + else: + raise ValueError( + "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time") + world_size = torch.distributed.get_world_size() + assert self.generation_batch_size % world_size == 0, \ + f'generation_batch_size ({self.generation_batch_size}) ' \ + f'must be divisible by the world size ({world_size})' + self.per_device_generation_batch_size = self.generation_batch_size // world_size + + _init_external_vllm() + _check_not_supported() + _check_batch_params() + # default loss_type if no loss_type is provided + assert self.loss_type in ['grpo', 'bnpo', 'dr_grpo'], \ + f'loss_type must be one of [grpo, bnpo, dr_grpo], but got {self.loss_type}' + if self.async_generate or not self.use_vllm: + self.sleep_level = 0 + self.remove_unused_columns = False + logger.info(f'Setting args.remove_unused_columns: {self.remove_unused_columns}') + if self.truncation_strategy is None: + self.truncation_strategy = 'left' + assert self.truncation_strategy in ['left', 'delete' + ], ("GRPO requires `truncation_strategy 'left' or 'delete'`, " + f"Current value: `truncation_strategy='{self.truncation_strategy}'`." + ) # noqa + if self.beta is None: + self.beta = 0.04 # https://arxiv.org/abs/2402.03300 + if self.async_generate: + logger.info('Using async mode. This is a approximate version which ' + 'will use the old weights to generate responses to accelerate. ' + 'This will ignore the `CLIP` of advantages, if you found the training ' + 'is unstable, you may consider using --async_generate false.') + if 'soft_overlong' in self.reward_funcs: + assert self.soft_cache_length is not None, \ + 'The soft_cache_length must be set when using soft overlong rewards.' 
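+            # Soft overlong punishment (DAPO): the penalty grows linearly once the completion length
+            # exceeds soft_max_length - soft_cache_length; see swift/plugin/orm.py for the implementation.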
+ if self.soft_max_length is None: + self.soft_max_length = self.max_completion_length + logger.info(f'Auto-configured soft_max_length = max_completion_length {self.max_completion_length}') + if self.use_vllm: + # set vllm mode + if self.vllm_server_host is not None or self.vllm_server_base_url is not None: + if self.vllm_mode != 'server': + self.vllm_mode = 'server' + logger.warning('set vllm_mode to `server` since vllm server host/base_url is provided') + else: + if self.vllm_mode != 'colocate': + self.vllm_mode = 'colocate' + logger.warning('set vllm_mode to `colocate` since vllm_server_host is not provided') @dataclass @@ -177,6 +354,7 @@ class MegatronArguments(ExtraMegatronArguments): dataloader_type: Literal['single', 'cyclic', 'external'] = 'cyclic' manual_gc: bool = False manual_gc_interval: int = 0 + use_mbridge: bool = False # learning rate lr: Optional[float] = None @@ -205,7 +383,7 @@ class MegatronArguments(ExtraMegatronArguments): no_load_rng: bool = False finetune: bool = False ckpt_format: Literal['torch', 'torch_dist', 'zarr'] = 'torch_dist' - no_initialization: bool = True + no_initialization: bool = False auto_detect_ckpt_format: bool = True exit_on_missing_checkpoint: bool = True diff --git a/swift/megatron/argument/rlhf_args.py b/swift/megatron/argument/rlhf_args.py index 74c8c29c1b..a0cc0b2f4a 100644 --- a/swift/megatron/argument/rlhf_args.py +++ b/swift/megatron/argument/rlhf_args.py @@ -7,7 +7,7 @@ @dataclass class MegatronRLHFArguments(MegatronTrainArguments): - rlhf_type: Literal['dpo', 'kto'] = 'dpo' + rlhf_type: Literal['dpo', 'kto', 'grpo'] = 'dpo' loss_scale: str = 'last_round' calculate_per_token_loss: bool = False diff --git a/swift/megatron/argument/train_args.py b/swift/megatron/argument/train_args.py index 124740a2f6..7552e65a00 100644 --- a/swift/megatron/argument/train_args.py +++ b/swift/megatron/argument/train_args.py @@ -9,7 +9,7 @@ from swift.llm.argument.base_args import to_abspath from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master from ..model import get_megatron_model_meta -from .megatron_args import MegatronArguments +from .megatron_args import MegatronArguments, RLHFMegatronArgumentsMixin logger = get_logger() diff --git a/swift/megatron/train/rlhf.py b/swift/megatron/train/rlhf.py index da964950dc..5b133bdcde 100644 --- a/swift/megatron/train/rlhf.py +++ b/swift/megatron/train/rlhf.py @@ -2,9 +2,10 @@ from typing import List, Optional, Union from swift.llm.train.kto import prepare_kto_dataset +from swift.trainers.rlhf_trainer.utils import identity_data_collator from swift.utils import get_logger from ..argument import MegatronRLHFArguments -from ..trainers import MegatronDPOTrainer, MegatronKTOTrainer +from ..trainers import MegatronDPOTrainer, MegatronGRPOTrainer, MegatronKTOTrainer from .sft import MegatronSft logger = get_logger() @@ -18,6 +19,8 @@ def prepare_trainer(self): args = self.args if args.rlhf_type == 'dpo': trainer_cls = MegatronDPOTrainer + elif args.rlhf_type == 'grpo': + trainer_cls = MegatronGRPOTrainer elif args.rlhf_type == 'kto': trainer_cls = MegatronKTOTrainer else: @@ -26,10 +29,13 @@ def prepare_trainer(self): def _prepare_template(self) -> None: super()._prepare_template() - if self.args.rlhf_type == 'kto': - self.template.set_mode('kto') - else: - self.template.set_mode('rlhf') + model_mapping = {'grpo': 'train', 'kto': 'kto'} + self.template.set_mode(model_mapping.get(self.args.rlhf_type, 'rlhf')) + + def _get_data_collator(self): + if self.args.rlhf_type == 'grpo': + return 
identity_data_collator + return super()._get_data_collator() def _get_dataset(self): args = self.args diff --git a/swift/megatron/trainers/__init__.py b/swift/megatron/trainers/__init__.py index 4f95226ebc..d875bb2b60 100644 --- a/swift/megatron/trainers/__init__.py +++ b/swift/megatron/trainers/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from .dpo_trainer import MegatronDPOTrainer +from .grpo_trainer import MegatronGRPOTrainer from .kto_trainer import MegatronKTOTrainer from .trainer import MegatronTrainer diff --git a/swift/megatron/trainers/base.py b/swift/megatron/trainers/base.py index 1ecd6cd3c0..05f8152b08 100644 --- a/swift/megatron/trainers/base.py +++ b/swift/megatron/trainers/base.py @@ -3,7 +3,7 @@ import os import time from abc import ABC, abstractmethod -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from datetime import datetime from typing import Dict, Literal @@ -27,8 +27,10 @@ from megatron.training.training import num_floating_point_operations from megatron.training.utils import reduce_max_stat_across_model_parallel_group, report_memory, unwrap_model from packaging import version +from torch.distributed.nn import all_reduce +from transformers.utils import ContextManagers -from swift.llm import dynamic_gradient_checkpointing +from swift.llm import Template, dynamic_gradient_checkpointing from swift.plugin import MeanMetric from swift.trainers import SwiftMixin from swift.utils import JsonlWriter, deep_getattr, format_time, get_logger @@ -41,11 +43,12 @@ class BaseMegatronTrainer(ABC): - def __init__(self, args, template): + def __init__(self, args, template: Template): self.args = args self.template = template self.stimer = StragglerDetector() self.unwrapped_models = [] + self.wrapped_models = [] self.peft_models = [] logging_path = os.path.join(args.save, 'logging.jsonl') logger.info(f'logging_path: {logging_path}') @@ -70,9 +73,11 @@ def initialize_megatron(*_args, **kwargs): args = get_args() data_parallel_size = mpu.get_data_parallel_world_size() step_batch_size = args.micro_batch_size * data_parallel_size + num_generations = args.num_generations if hasattr(args, 'num_generations') else 1 if args.train_iters is None and args.max_epochs is not None: if hasattr(train_dataset, '__len__'): dataset_sample = len(train_dataset) // step_batch_size * step_batch_size + dataset_sample = dataset_sample * num_generations args.train_iters = dataset_sample * args.max_epochs // args.global_batch_size else: raise ValueError( @@ -82,6 +87,7 @@ def initialize_megatron(*_args, **kwargs): args.eval_iters = 0 elif hasattr(val_dataset, '__len__'): dataset_sample = len(val_dataset) // step_batch_size * step_batch_size + dataset_sample = dataset_sample * num_generations args.eval_iters = max(dataset_sample // args.global_batch_size, 1) else: raise ValueError( @@ -261,6 +267,7 @@ def new_model_provider_func(*args, **kwargs): with self._patch_load_state_dict(self._load_base_checkpoint): model, optimizer, opt_param_scheduler = self._origin_setup_model_and_optimizer( new_model_provider_func, model_type, *_args, **kwargs) + self.wrapped_models = model if args.initialize_embedding: for m in self.unwrapped_models: self._initialize_embedding(m) diff --git a/swift/megatron/trainers/grpo_trainer.py b/swift/megatron/trainers/grpo_trainer.py new file mode 100644 index 0000000000..f9685556db --- /dev/null +++ b/swift/megatron/trainers/grpo_trainer.py @@ -0,0 +1,941 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
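+# Overview of the trainer below (colocate vLLM rollout only for now):
+#   1. Every steps_per_generation optimizer steps, roll out completions for a prompt batch with vLLM.
+#   2. Score them with the configured reward functions and compute group-relative advantages.
+#   3. Re-encode the completions, compute ref/old logps, and buffer the mini-batches used for the updates.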
+import gc +import inspect +from collections import defaultdict +from contextlib import contextmanager, nullcontext +from copy import copy, deepcopy +from functools import partial +from typing import Any, Dict, List, Union + +import torch +import torch.nn as nn +from megatron.core import mpu +from megatron.training import get_args, training +from trl.trainer.grpo_trainer import nanstd +from vllm.distributed import parallel_state as vllm_ps + +from swift.llm import RequestConfig, RowPreprocessor, Template, to_device +from swift.llm.infer.protocol import RolloutOutput +from swift.plugin import orms +from swift.trainers.rlhf_trainer.grpo_trainer import DataType +from swift.trainers.rlhf_trainer.utils import replace_assistant_response_with_ids +from swift.utils import get_current_device, get_logger, is_vllm_available, remove_response +from ..argument import MegatronArguments, MegatronRLHFArguments +from .rlhf_mixin import MegatronRLHFTrainer +from .utils import (gather, gather_object, load_megatron_model_to_gpu, load_megatron_optimizer, log_gpu_memory, + offload_megatron_model_to_cpu, offload_megatron_optimizer, patch_model_for_lora_export, + profiling_context) + +try: + from mbridge import AutoBridge +except ImportError: + pass + +logger = get_logger() + + +class MegatronGRPOTrainer(MegatronRLHFTrainer): + + def __init__(self, args: MegatronRLHFArguments, template: Template): + super().__init__(args, template) + self.args = args + self.hf_model_dir = args.model_info.model_dir + self.processing_class = self.template.processor + # TODO: multi turn scheduler(colocate multi turn) + self._prepare_template_data_collator() + self._init_grpo_params() + self._prepare_rewards() + self._prepare_rollout_engine() + # debug: use mbridge to convert mcore to hf + self.bridge = None + self._metrics = {'train': defaultdict(list), 'eval': defaultdict(list)} + + def _prepare_template_data_collator(self): + template = self.template + args = self.args + data_collator = template.data_collator + padding_to = None + if args.tensor_model_parallel_size > 1 and args.sequence_parallel: + padding_to = args.tensor_model_parallel_size + if args.context_parallel_size > 1: + padding_to = (padding_to or 1) * args.context_parallel_size + if args.fp8_format: + padding_to = max((padding_to or 1) * 8, 16) + logger.info(f'padding_to: {padding_to}') + data_collator = partial(data_collator, padding_to=padding_to) + template.data_collator = data_collator + + def _init_grpo_params(self): + args: MegatronArguments = self.args + # distributed params + self.world_size = torch.distributed.get_world_size() + self.process_index = torch.distributed.get_rank() + self.is_main_process = self.process_index == 0 + self.device = get_current_device() + # algorithm params + self.num_generations = args.num_generations # G in the GRPO paper + self.beta = args.beta + self.temperature = args.temperature + self.loss_type = args.loss_type + self.max_completion_length = args.max_completion_length + self.epsilon_low = args.epsilon + self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon + self.top_entropy_quantile = args.top_entropy_quantile + self.importance_sampling_level = args.importance_sampling_level + self.enable_offload = False + # batch size (completion-level) + self.generation_batch_size = args.generation_batch_size + self.steps_per_generation = args.steps_per_generation + self.global_batch_size = args.global_batch_size + self.micro_batch_size = args.micro_batch_size + self.per_device_generation_batch_size = 
args.per_device_generation_batch_size + + # sampling params + self.request_config = RequestConfig( + n=1, + max_tokens=args.max_completion_length, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + repetition_penalty=args.repetition_penalty, + stop=args.stop_words, + return_details=True) + + self._step = 0 + self._rollout_group = None # Will be lazily initialized + + def _prepare_rollout_engine(self): + args = self.args + self.vllm_mode = args.vllm_mode + self.vllm_gpu_memory_utilization = args.vllm_gpu_memory_utilization # only applies to colocation mode + self.vllm_tensor_parallel_size = args.vllm_tensor_parallel_size # only applies to colocation mode + self.use_vllm = args.use_vllm + self.async_generate = args.async_generate + self.use_fast_infer = self.use_vllm # whether to use the PT backend + self.vllm_use_async_engine = False + self.enable_offload = False + self.use_gym_env = False + self.enable_server_multi_turn = False # TODO + # for multi-turn server, maybe the num of rollout outputs is not equal to the num of rollout inputs + self.dynamic_num_samples = False + if self.use_vllm: + if not is_vllm_available(): + raise ImportError('vLLM is not available and `use_vllm` is set to True. ' + 'Please install vLLM with `pip install vllm -U` to use it.') + assert self.vllm_mode == 'colocate' # TODO: server mode + + if not self.world_size % self.vllm_tensor_parallel_size == 0: + raise ValueError(f'vllm_tensor_parallel_size ({self.vllm_tensor_parallel_size}) must divide world size ' + f'({self.world_size}) evenly.') + + self.enable_offload = self.args.offload_model or self.args.offload_optimizer + context = self.offload_context if self.enable_offload else nullcontext + + with context(): + self.engine = self.prepare_vllm() + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + log_gpu_memory('after sleep vLLM engine') + + def prepare_vllm(self): + from swift.llm.infer.infer_engine import GRPOVllmEngine + args = self.args + max_num_seqs = self.per_device_generation_batch_size * self.vllm_tensor_parallel_size + vllm_template = copy(self.template) + vllm_template.padding_free = False + engine = GRPOVllmEngine( + self.hf_model_dir, + args.torch_dtype, + model_type=args.model_type, + use_async_engine=False, + tensor_parallel_size=self.vllm_tensor_parallel_size, + gpu_memory_utilization=self.vllm_gpu_memory_utilization, + enable_prefix_caching=self.args.vllm_enable_prefix_caching, + max_num_seqs=max_num_seqs, + enforce_eager=self.args.vllm_enforce_eager, + limit_mm_per_prompt=self.args.vllm_limit_mm_per_prompt, + enable_sleep_mode=self.args.sleep_level > 0, + max_model_len=self.args.vllm_max_model_len, + seed=self.process_index // self.vllm_tensor_parallel_size, + disable_cascade_attn=self.args.vllm_disable_cascade_attn, + load_format='dummy', + template=vllm_template, + distributed_executor_backend='external_launcher', + ) + if self.vllm_tensor_parallel_size > 1: + self.vllm_tp_group = vllm_ps.get_tp_group().device_group + self._buffered_inputs = None + return engine + + def _move_model_to_vllm(self): + # TODO: server + if self.bridge is None: + self.bridge = AutoBridge.from_pretrained(self.hf_model_dir) + self._patch_mbridge(self.bridge) + + # Handle LoRA: merge adapters before exporting weights + is_lora_training = self.args.train_type == 'lora' + restore_funcs = [] + + try: + if is_lora_training: + self._merge_lora_adapters() + for model in self.unwrapped_models: + restore_func = patch_model_for_lora_export(model) + 
restore_funcs.append(restore_func) + + per_tensor_params = self.bridge.export_weights(self.unwrapped_models) + self.engine.inner_model.load_weights(per_tensor_params) + finally: + for restore_func in restore_funcs: + restore_func() + + # Unmerge adapters to restore training state + if is_lora_training: + logger.info('Unmerging LoRA adapters to restore training state...') + self._unmerge_lora_adapters() + + def _prepare_rewards(self): + # TODO: reward model + args = self.args + reward_funcs = args.reward_funcs + if not isinstance(reward_funcs, list): + reward_funcs = [reward_funcs] + + # initilize reward functions + if reward_funcs: + for i, reward_func in enumerate(reward_funcs): + if reward_func in orms: + reward_func_class = orms[reward_func] + reward_func_args = list(inspect.signature(reward_func_class.__init__).parameters) + reward_func_kwargs = { + key: getattr(args, key) + for key in reward_func_args if key not in ['self', 'args', 'kwargs'] and hasattr(args, key) + } + if 'tokenizer' in reward_func_args: + reward_func_kwargs['tokenizer'] = self.processing_class + reward_funcs[i] = reward_func_class(**reward_func_kwargs) + elif not callable(reward_func): + raise ValueError(f'reward_function {reward_func} is not implemented in swift.plugin') + + # get reward name for logging + self.reward_funcs = reward_funcs + self.reward_func_names = [] + for reward_func in reward_funcs: + if inspect.isfunction(reward_func): + reward_func_name = reward_func.__name__ + else: + reward_func_name = reward_func.__class__.__name__ + self.reward_func_names.append(reward_func_name) + + # set reward weights + if args.reward_weights is not None: + if len(args.reward_weights) != len(reward_funcs): + raise ValueError(f'Number of reward weights ({len(args.reward_weights)}) must match number of reward ' + f'functions ({len(reward_funcs)})') + self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32).to(self.device) + else: + self.reward_weights = torch.ones(len(self.reward_func_names), dtype=torch.float32).to(self.device) + + # TODO: reward models + self.reward_model_plugins = [None] * len(self.reward_funcs) + + assert self.reward_funcs, 'reward_funcs is not set' + + def _merge_lora_adapters(self): + """Merge LoRA adapters into base model weights for vLLM inference.""" + from ..tuners import LoraParallelLinear + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Merge all active adapters + module.merge() + + def _unmerge_lora_adapters(self): + """Unmerge LoRA adapters to restore training state.""" + from ..tuners import LoraParallelLinear + for model in self.unwrapped_models: + for module in model.modules(): + if isinstance(module, LoraParallelLinear): + # Unmerge to restore separate LoRA weights for training + module.unmerge() + + def _patch_mbridge(self, bridge): + original_method = bridge._weight_to_hf_format + + def _weight_to_hf_format_patched(mcore_weights_name, mcore_weights): + # skip ViT weights + if 'visual' in mcore_weights_name: + if 'visual.visual' in mcore_weights_name: + mcore_weights_name = mcore_weights_name.replace('visual.visual', 'visual') + return [mcore_weights_name], [mcore_weights] + + if '.base_layer.' in mcore_weights_name: + mcore_weights_name = mcore_weights_name.replace('.base_layer.', '.') + + if '.modules_to_save.default.' 
in mcore_weights_name: + mcore_weights_name = mcore_weights_name.replace('.modules_to_save.default.', '.') + return original_method(mcore_weights_name, mcore_weights) + + bridge._weight_to_hf_format = _weight_to_hf_format_patched + + def _get_rollout_group(self): + """ + Get or create the rollout process group (TP×PP×CP). + + The rollout group is used for: + 1. Data slicing: distributing rollout data across all model parallel ranks (including CP) + 2. Gather operations: collecting results from all model parallel ranks (including CP) + + Note: MODEL_PARALLEL_GROUP only includes TP×PP, but we need TP×PP×CP for correct + data distribution during rollout phase. + + Key insight: ranks with the same DP index but different TP/PP/CP indices should be + in the same rollout group. These ranks will: + - During rollout: each process different data slices + - During training: TP/PP ranks process same data (model split), CP ranks process same data (sequence split) + - During gather: collect all data from TP×PP×CP ranks for training + """ + if self._rollout_group is not None: + return self._rollout_group + + cp_size = mpu.get_context_parallel_world_size() + if cp_size == 1: + # No CP, use the standard MODEL_PARALLEL_GROUP + self._rollout_group = mpu.get_model_parallel_group() + return self._rollout_group + + # Get parallel dimensions + tp_size = mpu.get_tensor_model_parallel_world_size() + pp_size = mpu.get_pipeline_model_parallel_world_size() + dp_size = mpu.get_data_parallel_world_size() + global_rank = torch.distributed.get_rank() + + # Calculate rollout group size + rollout_group_size = tp_size * pp_size * cp_size + + # Simple and reliable method: assume ranks are organized in contiguous blocks per DP group + # This is typically true for the default order (tp-cp-ep-dp-pp) + # Each DP group has rollout_group_size consecutive ranks + ranks_per_dp_group = rollout_group_size + my_dp_block_index = global_rank // ranks_per_dp_group + + # Calculate the rank range for my rollout group + group_start = my_dp_block_index * ranks_per_dp_group + + # Create all rollout groups (must be done on all ranks) + if not hasattr(self, '_rollout_groups_created'): + for dp_idx in range(dp_size): + group_start = dp_idx * ranks_per_dp_group + group_ranks = list(range(group_start, min(group_start + ranks_per_dp_group, self.world_size))) + group = torch.distributed.new_group(ranks=group_ranks, group_desc='ROLLOUT_GROUP') + if global_rank in group_ranks: + self._rollout_group = group + self._rollout_groups_created = True + + return self._rollout_group + + def _replace_data_iterator(self, data_iterator, model): + + if self._step % self.steps_per_generation == 0: + # each rollout DP group will generate generation_batch_size / world_size completions + completions_to_rollout = self.generation_batch_size // mpu.get_data_parallel_world_size() + # completions will be repeated num_generations times after + # so we need to divide num_iters_per_step by num_generations to get prompt batch size + prompts_to_rollout = completions_to_rollout // self.num_generations + # every iter will generate micro_batch_size prompts + num_iters_per_step = prompts_to_rollout // self.micro_batch_size + assert num_iters_per_step > 0, ( + f'num_iters_per_step={num_iters_per_step} <= 0. ' + f'This means no prompts will be generated' + f'generation_batch_size={self.generation_batch_size}, ' + f'data_parallel_world_size={mpu.get_data_parallel_world_size()}, ' + f'num_generations={self.num_generations}, ' + f'micro_batch_size={self.micro_batch_size}. 
' + 'Please adjust these parameters so that ' + 'generation_batch_size // data_parallel_world_size // num_generations // micro_batch_size >= 1.') + rollout_batch = [] + for _ in range(num_iters_per_step): + rollout_batch.extend(next(data_iterator)) + micro_batch_data = self._generate_and_score_completions(rollout_batch) + num_mini_batch = self.global_batch_size // (self.micro_batch_size * mpu.get_data_parallel_world_size()) + mini_batch_data = [ + micro_batch_data[i:i + num_mini_batch] for i in range(0, len(micro_batch_data), num_mini_batch) + ] + assert len(mini_batch_data) == self.steps_per_generation + self._buffered_inputs = mini_batch_data + self._step += 1 + inputs = self._buffered_inputs[self._step % self.steps_per_generation] + return iter(inputs) + + def _generate_and_score_completions(self, batch): + # Get or create the rollout group (TP×PP×CP) + rollout_group = self._get_rollout_group() + + # batch : same across DP groups + def get_local_rollout_batch(batch): + # repeat num_generations times + global_rollout_batch = [deepcopy(item) for item in batch for _ in range(self.num_generations)] + # get local rollout data + rollout_rank = torch.distributed.get_rank(group=rollout_group) + rollout_group_size = torch.distributed.get_world_size(group=rollout_group) + + per_device_batch_size = self.per_device_generation_batch_size + assert rollout_group_size * per_device_batch_size == len(global_rollout_batch) + data_slice = slice(rollout_rank * per_device_batch_size, (rollout_rank + 1) * per_device_batch_size) + rollout_batch = global_rollout_batch[data_slice] + return rollout_batch + + # Step1: Rollout / Reward / Advantage + + rollout_batch = get_local_rollout_batch(batch) + + rollout_batch = self._generate_completions(rollout_batch) + + rewards_per_func = self._score_completions(rollout_batch) + + advantages = self._compute_advantages(rollout_batch, rewards_per_func) + + def _get_encoded_batch(rollout_batch, advantages): + template = self.template + encoded_batch = [template.encode(data, return_length=True) for data in rollout_batch] + encoded_batch = to_device(template.data_collator(encoded_batch), self.device) + labels = encoded_batch['labels'] + assert self.template.padding_free + position_ids = encoded_batch.get('text_position_ids') + if position_ids is None: + position_ids = encoded_batch.get('position_ids') + squeezed_position_ids = position_ids.squeeze() + assert squeezed_position_ids is not None + # Remove trailing padding zeros from position_ids to avoid interference + # Find the last non-zero position + last_nonzero_idx = (squeezed_position_ids != 0).nonzero(as_tuple=True)[0] + if len(last_nonzero_idx) > 0: + # Keep only up to the last non-zero position + 1 to include the last valid position + squeezed_position_ids = squeezed_position_ids[:last_nonzero_idx[-1] + 1] + + # Calculate lengths based on sequence boundaries (position_ids == 0) + lengths = torch.diff( + torch.cat([(squeezed_position_ids == 0).nonzero(as_tuple=True)[0], + torch.tensor([len(squeezed_position_ids)]).to(squeezed_position_ids.device)])) + advantages = torch.repeat_interleave(advantages, lengths) + truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch], + dtype=torch.bool, + device=self.device) + truncated_mask = torch.repeat_interleave(truncated_mask, lengths).unsqueeze(0) + padding_length = labels.shape[1] - truncated_mask.shape[1] + if padding_length > 0: + padding = torch.zeros((1, padding_length), device=truncated_mask.device, dtype=truncated_mask.dtype) + truncated_mask = 
torch.cat([truncated_mask, padding], dim=1) + # Pad advantages to match the original position_ids length + original_length = position_ids.shape[1] + if advantages.shape[0] < original_length: + padding_length = original_length - advantages.shape[0] + padding = torch.zeros(padding_length, device=advantages.device, dtype=advantages.dtype) + advantages = torch.cat([advantages, padding]) + + encoded_batch.update({ + 'completion_mask': labels != -100, + 'truncated_mask': truncated_mask, + 'advantages': advantages, + }) + + return encoded_batch + + # Step2: ref/old logps + total_batch = gather_object(rollout_batch, group=rollout_group) + total_advantages = gather(advantages, group=rollout_group) + mini_batch_data = [] + for idx in range(0, len(total_batch), self.micro_batch_size): + micro_batch_data = total_batch[idx:idx + self.micro_batch_size] + micro_batch_data = self._maybe_replace_response_token(micro_batch_data) + micro_batch_advantages = total_advantages[idx:idx + self.micro_batch_size] + micro_batch_data = _get_encoded_batch(micro_batch_data, micro_batch_advantages) + micro_batch_data = self._maybe_compute_logps(micro_batch_data) + mini_batch_data.append(micro_batch_data) + + return mini_batch_data + + def _generate_completions(self, batch): + """ + Generate completions for a batch of rollout data using vLLM engine. + + This method processes rollout data for the current process, generates completions + using the vLLM engine, and merges the results back into the original batch. + + Args: + batch: Rollout data assigned to the current process. + + Returns: + batch: The input batch with rollout completion results merged in. + + Note: + Currently only supports colocate mode. Server mode support is planned for future implementation. + """ + # TODO: server mode + assert self.vllm_mode == 'colocate' + # Step 1: Wake up the engine if it's sleeping (vLLM colocate mode) + if self.engine.inner_model_executor.is_sleeping: + wake_up_params = inspect.signature(self.engine.engine.wake_up).parameters + # Load weights only (faster and reduces memory peak) + kwargs = {'tags': ['weights']} if 'tags' in wake_up_params else {} + self.engine.engine.wake_up(**kwargs) + log_gpu_memory(f'after wake up vLLM engine with {kwargs}') + + # Step 2: Load model weights + self._move_model_to_vllm() + + context = self.offload_context if self.enable_offload else nullcontext + with context(): + if (self.engine.inner_model_executor.is_sleeping + and 'tags' in inspect.signature(self.engine.engine.wake_up).parameters): + self.engine.engine.wake_up(tags=['kv_cache']) + log_gpu_memory('after wake up vLLM engine with kv_cache') + + # Step3: Rollout + batch = self.preprocess_rollout_data(batch) + outputs: List[RolloutOutput] = self._rollout(batch) + + # Step4: Sleep to release memory + if self.args.sleep_level > 0: + self.engine.engine.sleep(self.args.sleep_level) + log_gpu_memory('after sleep vLLM engine') + batch = self.postprocess_rollout_data(batch, outputs) + + return batch + + def preprocess_rollout_data(self, batch): + """ + Gather rollout trajectories across the vLLM tensor-parallel (TP) group. + + This method collect the full batch on every rank, then flattens + the nested lists into a single list of samples. + + Args: + batch (list): List of rollout samples local to this TP rank. + + Returns: + list: Flattened list containing all rollout samples from every + rank in the TP group. 
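+
+        Note:
+            When vllm_tensor_parallel_size == 1, the batch is returned unchanged.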
+ """ + if self.vllm_tensor_parallel_size == 1: + return batch + + gathered_batch = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_batch, batch, group=self.vllm_tp_group) + flattened_batch = [p for sublist in gathered_batch for p in sublist] + return flattened_batch + + def _rollout(self, batch) -> List[RolloutOutput]: + request_config = self._get_request_config() + # TODO: server mode + rollout_outputs = self._colocate_rollout(batch, request_config) + return rollout_outputs + + def postprocess_rollout_data(self, batch, outputs): + """ + Post-process the raw vLLM generation outputs and merge them back into the + original input batch. + + Args: + batch (List[Dict[str, Any]]): + Original rollout samples. + outputs (List[RolloutOutput]): + outputs from vLLM from vLLM TP group + + Returns: + List[Dict[str, Any]]: + Updated samples with rollout results merged in. + """ + + if self.vllm_tensor_parallel_size > 1: + local_rank_in_group = torch.distributed.get_rank(group=self.vllm_tp_group) + orig_size = len(outputs) // self.vllm_tensor_parallel_size + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + outputs = outputs[tp_slice] + + def merge_output_input_data(input_data: Dict[str, Union[torch.Tensor, Any]], output: RolloutOutput): + response = output.response + choice = response.choices[0] + + # Step 1: Update or append assistant message + if output.messages: + input_data['messages'] = output.messages # Override full message history + else: + # not provided, append + messages = input_data['messages'] + remove_response(messages) + messages.append({'role': 'assistant', 'content': choice.message.content}) + # Step 2: Add token IDs and loss mask + if output.response_token_ids: + input_data['response_token_ids'] = output.response_token_ids + if output.response_loss_mask: + input_data['response_loss_mask'] = output.response_loss_mask + else: + # for single turn, skip tokenizer response + input_data['response_token_ids'] = output.response.choices[0].token_ids + + # Step 3: Attach rollout extra info + if output.rollout_infos: + input_data['rollout_infos'] = output.rollout_infos + + # Step 4: Store finish reason (used for truncation filters etc.) + input_data['finish_reason'] = choice.finish_reason + input_data['is_truncated'] = choice.finish_reason == 'length' + + return input_data + + assert len(batch) == len(outputs) + return [merge_output_input_data(input_data, output) for input_data, output in zip(batch, outputs)] + + def _get_request_config(self) -> RequestConfig: + request_config = copy(self.request_config) + if self.args.vllm_mode == 'colocate' and self.vllm_tensor_parallel_size > 1: + # Set request_config.seed + # 1. Ensure that the seed for vLLM Engines within each TP (Tensor Parallelism) group is the same; + # otherwise, the program may hang. + # 2. Ensure that the seed for vLLM Engines across different TP groups is different; + # otherwise, identical completions will be generated. + batch_size = self.per_device_generation_batch_size + batch_size *= self.vllm_tensor_parallel_size + # Since the TP (Tensor Parallelism) group gathers the inputs, + # multiply the batch size by the TP parallel size. 
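+            # Worked example with hypothetical sizes: per_device_generation_batch_size=4 and
+            # vllm_tensor_parallel_size=2 give batch_size=8, so ranks 0-1 (TP group 0) get seed 0
+            # and ranks 2-3 (TP group 1) get seed 8: equal within a TP group, distinct across groups.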
+ request_config.seed = batch_size * (self.process_index // self.vllm_tensor_parallel_size) + + return request_config + + def _colocate_rollout(self, batch, request_config: RequestConfig): + outputs: List[RolloutOutput] = self.engine.infer(infer_requests=batch, request_config=request_config) + return outputs + + def _score_completions(self, inputs: DataType) -> torch.Tensor: + """Score completions using all reward functions. + + Args: + inputs: List of input examples, each containing a 'messages' list with conversation history + + Returns: + rewards_per_func: Tensor of shape (num_examples, num_reward_funcs) with local reward values + """ + # Compute rewards using reward functions + local_rewards_per_func = self._compute_rewards_per_func(inputs) + + return local_rewards_per_func + + def _compute_rewards_per_func(self, batch: DataType) -> torch.Tensor: + """Compute rewards using all reward functions""" + device = self.device + rewards_per_func = torch.zeros((len(batch), len(self.reward_funcs)), device=device) + completions = [inp['messages'][-1]['content'] for inp in batch] + reward_kwargs = {} # TODO: training step info + for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate( + zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)): + with profiling_context(self, reward_func_name): + # reward model + if isinstance(reward_func, nn.Module): + output_reward_func = reward_model_plugin(inputs=batch, **reward_kwargs) + # reward function + else: + # Repeat all input columns (but "messages" and "completion") to match the number of generations + reward_kwargs.update(RowPreprocessor.rows_to_batched(batch)) + output_reward_func = reward_func(completions, **reward_kwargs) + output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func] + rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) + + # If all reward functions return None for a given row, issue a detailed warning + if torch.isnan(rewards_per_func).all(dim=1).any(): + nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0] + row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()} + row_reward_kwargs['completion'] = completions[nan_row_idx] + logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. 
' + 'Please ensure that at least one reward function returns a valid reward.') + + return rewards_per_func + + def _compute_advantages(self, batch: DataType, rewards_per_func: torch.Tensor) -> torch.Tensor: + """Compute advantages for RL training.""" + + def maybe_normalize_advantages(advantages: torch.Tensor, rewards_std: torch.Tensor) -> torch.Tensor: + """Normalize advantages if configured; otherwise, return as-is.""" + if self.args.scale_rewards: + return advantages / (rewards_std + 1e-4) + return advantages + + assert len(batch) == rewards_per_func.shape[0] + total_rewards_per_func = gather(rewards_per_func) + rewards = (total_rewards_per_func * self.reward_weights.unsqueeze(0)).nansum(dim=1) + grouped_rewards = rewards.view(-1, self.num_generations) + group_rewards_mean = grouped_rewards.mean(dim=1) + group_rewards_std = grouped_rewards.std(dim=1) + + # Broadcast stats back to the original shape + group_rewards_mean = group_rewards_mean.repeat_interleave(self.num_generations) + group_rewards_std = group_rewards_std.repeat_interleave(self.num_generations) + + # Compute advantages relative to group mean + advantages = rewards - group_rewards_mean + advantages = maybe_normalize_advantages(advantages, group_rewards_std) + + def log_rewards_metrics(rewards: torch.Tensor, rewards_per_func_for_metrics: torch.Tensor): + """Log reward statistics for monitoring. Only log once per unique request_id.""" + # rewards: [prompt_batch_size, self.num_generations] + # rewards_per_func_for_metrics: [prompt_batch_size*self.num_generations, self.num_reward_funcs] + mode = 'train' if self.unwrapped_models[0].training else 'eval' + group_rewards = rewards.view(-1, self.num_generations) + rewards_mean = group_rewards.mean(-1).mean().item() + rewards_std = group_rewards.std(-1).mean().item() + is_std_zero = torch.isclose(group_rewards.std(dim=1), torch.zeros_like(group_rewards.std(dim=1))) + + self._metrics[mode]['reward'].append(rewards_mean) + self._metrics[mode]['reward_std'].append(rewards_std) + self._metrics[mode]['frac_reward_zero_std'].append(is_std_zero.float().mean().item()) + + # Log per-reward-function statistics using deduplicated rewards_per_func + for i, name in enumerate(self.reward_func_names): + col = rewards_per_func_for_metrics[:, i] + self._metrics[mode][f'rewards/{name}/mean'].append(torch.nanmean(col).item()) + self._metrics[mode][f'rewards/{name}/std'].append(nanstd(col).item()) + + log_rewards_metrics(rewards=grouped_rewards, rewards_per_func_for_metrics=rewards_per_func) + + slice_start = self.process_index * len(batch) + slice_end = slice_start + len(batch) + advantages = advantages[slice_start:slice_end] + + return advantages + + def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]: + # TODO: entropy + inputs = {k: v for k, v in batch.items() if k not in ['completion_mask', 'advantages', 'truncated_mask']} + if self.beta != 0.0: + with torch.no_grad(), self.null_ref_context() as ref_models: + assert len(ref_models) == 1, 'GRPO currently does not support VPP.' 
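+                # The reference logps computed here are only used for the KL penalty (beta != 0),
+                # which loss_func adds to the per-token loss as beta * per_token_kl.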
+ ref_model = ref_models[0] + batch['ref_per_token_logps'] = self.model_forward( + ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps'] + + if not self.on_policy: + batch['old_per_token_logps'] = self.model_forward( + self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps'] + return batch + + @contextmanager + def _disable_maxlength_template_context(self, template: Template): + # The max_length for prompt and completion has already been restricted, so there is no need for max_length here. + max_length = template.max_length + template.max_length = None + try: + yield + finally: + template.max_length = max_length + + def _maybe_replace_response_token(self, batch): + # maybe replace the response token with the response token ids to avoid repetitive tokenize + + for data in batch: + if 'response_token_ids' in data and data['response_token_ids']: + loss_mask = None + if 'response_loss_mask' in data and data['response_loss_mask']: + loss_mask = data['response_loss_mask'] + # token in token out + data['messages'] = replace_assistant_response_with_ids(data['messages'], data['response_token_ids'], + loss_mask) + return batch + + @property + def on_policy(self): + return self.steps_per_generation == 1 + + @contextmanager + def patch_megatron_data_collator(self, data_collator): + """ + Context manager that temporarily patches Megatron's data-loader factory so each + prompt-level micro-batch size equals (original micro-batch size // num_generations), + required by GRPO. Restores the original size and loader on exit. + """ + origin_build_pretraining_data_loader = training.build_pretraining_data_loader + + def build_pretraining_data_loader(*_args, **kwargs): + args = get_args() + org_micro_batch_size = args.micro_batch_size + # args.micro_batch_size = org_micro_batch_size // self.num_generations + res = origin_build_pretraining_data_loader(*_args, **kwargs) + args.micro_batch_size = org_micro_batch_size + if res is not None and args.dataloader_type != 'external': + res.collate_fn = data_collator + return res + + training.build_pretraining_data_loader = build_pretraining_data_loader + try: + yield + finally: + training.build_pretraining_data_loader = origin_build_pretraining_data_loader + + def forward_step(self, data_iterator, model): + # train_batch_size + # return: output_tensor, loss_func + data = self.get_batch(data_iterator) + data.pop('loss_scale', None) + inputs = { + k: v + for k, v in data.items() if k not in + ['completion_mask', 'ref_per_token_logps', 'advantages', 'old_per_token_logps', 'truncated_mask'] + } + + with self.stimer: + output_tensor = model(**inputs) + return output_tensor, partial(self.loss_func, data=data) + + def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]): + advantages = data['advantages'] + labels = data['labels'] + completion_mask = data['completion_mask'] + packed_seq_params = data['packed_seq_params'] + truncated_mask = data['truncated_mask'] + micro_batch_size = self.micro_batch_size + lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size + + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size] + lengths_with_padding = packed_seq_params.cu_seqlens_q[1:] - packed_seq_params.cu_seqlens_q[:-1] + if mpu.get_context_parallel_world_size() > 1: + # When using Context Parallel, each rank only processes a portion of the sequence + # So we need to divide the lengths by CP size + cp_size = mpu.get_context_parallel_world_size() + cu_seqlens_cp = packed_seq_params.cu_seqlens_q // cp_size + 
lengths_with_padding = cu_seqlens_cp[1:] - cu_seqlens_cp[:-1] + lengths = cu_seqlens_cp[1:micro_batch_size + 1] - cu_seqlens_cp[:micro_batch_size] + + per_token_logps = self.get_logps( + output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True) + + if self.args.overlong_filter and truncated_mask.any(): + completion_mask = completion_mask & (~truncated_mask) + + if self.beta != 0.0: + ref_per_token_logps = data.get('ref_per_token_logps') + per_token_kl = ( + torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1) + + old_per_token_logps = ( + per_token_logps.detach() if data.get('old_per_token_logps') is None else data['old_per_token_logps']) + log_ratio = per_token_logps - old_per_token_logps + + if self.importance_sampling_level == 'token': + log_importance_weights = log_ratio + elif self.importance_sampling_level in ['sequence', 'sequence_token']: + log_ratio_list = torch.split(log_ratio.squeeze(0), lengths_with_padding.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) + seq_weights = [(lr * m).sum() / m.sum().clamp(min=1.0) for lr, m in zip(log_ratio_list, mask_list)] + seq_level_log_weights = torch.stack(seq_weights).to(log_ratio.dtype).unsqueeze(-1) + if self.importance_sampling_level == 'sequence': + log_importance_weights = seq_level_log_weights + else: + seq_level_log_weight = seq_level_log_weights.detach() + seq_level_log_weight = torch.repeat_interleave(seq_level_log_weight, lengths).unsqueeze(0) + log_importance_weights = per_token_logps - per_token_logps.detach() + seq_level_log_weight + else: + raise ValueError( + f"Unknown importance sampling level: {self.importance_sampling_level}. Possible values are 'token' " + "and 'sequence'.") + + coef_1 = torch.exp(log_importance_weights) + coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high) + if self.args.delta is not None: + coef_1 = torch.clamp(coef_1, max=self.args.delta) + + if self.template.padding_free: + advantages = advantages[-coef_1.shape[1]:] + per_token_loss1 = coef_1 * advantages.unsqueeze(0) + per_token_loss2 = coef_2 * advantages.unsqueeze(0) + else: + per_token_loss1 = coef_1 * advantages.unsqueeze(1) + per_token_loss2 = coef_2 * advantages.unsqueeze(1) + per_token_loss = -torch.min(per_token_loss1, per_token_loss2) + if self.beta != 0.0: + per_token_loss = per_token_loss + self.beta * per_token_kl + + if self.loss_type == 'grpo': + loss_list = torch.split(per_token_loss.squeeze(0), lengths_with_padding.tolist()) + mask_list = torch.split(completion_mask.squeeze(0), lengths_with_padding.tolist()) + sample_loss = [(loss * mask).sum() / mask.sum().clamp(min=1.0) for loss, mask in zip(loss_list, mask_list)] + loss = torch.stack(sample_loss[:micro_batch_size]).mean() + elif self.loss_type == 'bnpo': + loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + elif self.loss_type == 'dr_grpo': + loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + else: + raise ValueError(f'Unknown loss type: {self.loss_type}') + + # loss = loss.mean() + avg_metric = { + 'loss': loss.clone().detach(), + 'completions/mean_length': lengths.float().mean(), + } + max_metric = { + 'completions/max_length': lengths.float().max(), + } + min_metric = { + 'completions/min_length': lengths.float().min(), + } + if self.beta != 0.0: + avg_metric['kl'] = per_token_kl.mean().item() + avg_reporting_metric = 
loss.new_tensor(list(avg_metric.values())) + max_reporting_metric = loss.new_tensor(list(max_metric.values())) + min_reporting_metric = loss.new_tensor(list(min_metric.values())) + torch.distributed.all_reduce( + avg_reporting_metric, torch.distributed.ReduceOp.AVG, group=mpu.get_data_parallel_group()) + + torch.distributed.all_reduce( + max_reporting_metric, torch.distributed.ReduceOp.MAX, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce( + min_reporting_metric, torch.distributed.ReduceOp.MIN, group=mpu.get_data_parallel_group()) + avg_reporting_metric = {k: avg_reporting_metric[i] for i, k in enumerate(avg_metric.keys())} + max_reporting_metric = {k: max_reporting_metric[i] for i, k in enumerate(max_metric.keys())} + min_reporting_metric = {k: min_reporting_metric[i] for i, k in enumerate(min_metric.keys())} + addition_metrics = { + key: torch.tensor(sum(val) / len(val), device=loss.device) + for key, val in self._metrics['train'].items() + } + + reporting_metric = {**avg_reporting_metric, **max_reporting_metric, **min_reporting_metric, **addition_metrics} + # fix megatron-lm bug + # https://github.com/NVIDIA/Megatron-LM/blob/core_r0.12.0/megatron/core/pipeline_parallel/schedules.py#L291 + loss = loss / mpu.get_context_parallel_world_size() + return loss, reporting_metric + + def model_forward(self, model, data_iterator, no_grad=True, per_token=False): + # used to calculate model forward (logps) in GRPO + with self.stimer(bdata=True): + data = self.get_batch(data_iterator) + data.pop('loss_scale', None) + labels = data.get('labels') + context = torch.no_grad() if no_grad else nullcontext() + with context: + output_tensor = self._forward_step_helper(model, data) + packed_seq_params = data['packed_seq_params'] + data['logps'] = None if labels is None else self.get_logps( + output_tensor, labels, data['packed_seq_params'], packed_seq_params.num_samples, per_token=per_token) + return data + + @contextmanager + def offload_context(self): + if self.args.offload_model: + offload_megatron_model_to_cpu(self.wrapped_models) + if hasattr(self, 'ref_models') and self.ref_models: + offload_megatron_model_to_cpu(self.ref_models) + log_gpu_memory('after offload model to cpu') + if getattr(self, 'optimizer', None) and self.args.offload_optimizer: + offload_megatron_optimizer(self.optimizer) + log_gpu_memory('after offload optimizer to cpu') + + try: + yield + finally: + # reload (load back) model when exiting context + if self.args.offload_model: + load_megatron_model_to_gpu(self.wrapped_models) + if hasattr(self, 'ref_models') and self.ref_models: + load_megatron_model_to_gpu(self.ref_models) + log_gpu_memory('after load model to gpu') + if getattr(self, 'optimizer', None) and self.args.offload_optimizer: + load_megatron_optimizer(self.optimizer) + log_gpu_memory('after load optimizer to gpu') diff --git a/swift/megatron/trainers/rlhf_mixin.py b/swift/megatron/trainers/rlhf_mixin.py index ead111435e..55e4ae6b42 100644 --- a/swift/megatron/trainers/rlhf_mixin.py +++ b/swift/megatron/trainers/rlhf_mixin.py @@ -59,6 +59,8 @@ def _forward_step_helper(model, inputs): if mpu.is_pipeline_first_stage(): micro_batch_size = 1 # use qkv_format 'thd' seq_length = inputs['input_ids'].shape[1] + if 'position_ids' in inputs: + seq_length = inputs['position_ids'].shape[-1] if args.sequence_parallel: seq_length //= mpu.get_tensor_model_parallel_world_size() recv_shape_buffer = torch.tensor([seq_length, micro_batch_size, args.hidden_size], @@ -82,11 +84,16 @@ def _forward_step_helper(model, inputs): 
return output_tensor - def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None): + def get_logps(self, output_tensor, labels, packed_seq_params, num_samples=None, per_token=False): args = get_args() per_token_logps = -output_tensor loss_mask = labels != -100 per_token_logps = per_token_logps * loss_mask + if per_token: + if args.context_parallel_size > 1: + per_token_logps = all_reduce(per_token_logps, group=mpu.get_context_parallel_group()) + return per_token_logps + if num_samples is None: num_samples = packed_seq_params.num_samples * 2 cu_seqlens = packed_seq_params.cu_seqlens_q[:num_samples + 1] // args.context_parallel_size diff --git a/swift/megatron/trainers/utils.py b/swift/megatron/trainers/utils.py index 35dd538f0d..77cce7ec7f 100644 --- a/swift/megatron/trainers/utils.py +++ b/swift/megatron/trainers/utils.py @@ -1,14 +1,23 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Any, Dict +import gc +import time +from contextlib import contextmanager +from typing import Any, Dict, List, Optional import torch +from accelerate.utils import gather as hf_gather +from accelerate.utils import gather_object as hf_gather_object from megatron.core import mpu +from megatron.core.distributed import DistributedDataParallel as DDP +from megatron.core.optimizer import ChainedOptimizer from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_batch_on_this_cp_rank as mcore_get_batch_on_this_cp_rank -from megatron.training import get_args +from megatron.training import get_args, get_wandb_writer from swift.llm import get_packed_seq_params as _get_packed_seq_params from swift.llm import to_device +from swift.utils import get_logger +from swift.utils.torch_utils import empty_cache, get_current_device def get_swift_datasets_provider(train_dataset, val_dataset): @@ -65,7 +74,7 @@ def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: torch.Tensor, dim: int): for i in range(cu_seqlens.shape[0] - 1): slices = [slice(None)] * inputs.ndim slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1]) - val = inputs[slices] + val = inputs[tuple(slices)] view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:]) val = val.view(view_shape) index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu', @@ -95,6 +104,11 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): keys.append('decoder_input') else: keys.append('input_ids') + if hasattr(args, 'rlhf_type') and args.rlhf_type == 'grpo': + keys.append('truncated_mask') + keys.append('advantages') + keys.append('completion_mask') + packed_seq_params = batch.get('packed_seq_params') if packed_seq_params is None: return mcore_get_batch_on_this_cp_rank(batch) @@ -107,3 +121,270 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]): batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1) return batch + + +@contextmanager +def profiling_context(trainer, name: str): + start_time = time.perf_counter() + yield + end_time = time.perf_counter() + duration = end_time - start_time + + profiling_metrics = {f'profiling/Time taken: {trainer.__class__.__name__}.{name}': duration} + wandb_writer = get_wandb_writer() + if wandb_writer and trainer.is_main_process: + wandb_writer.log(profiling_metrics) + + # TODO: add swanlab support + + +def gather(tensor, group: Optional[torch.distributed.ProcessGroup] = None): + if group is None: + return hf_gather(tensor) + size = torch.distributed.get_world_size(group=group) + output = 
[torch.empty_like(tensor) for _ in range(size)] + torch.distributed.all_gather(output, tensor, group=group, async_op=False) + + return torch.cat(output, dim=0) + + +def gather_object(object: Any, group: Optional[torch.distributed.ProcessGroup] = None): + if group is None: + return hf_gather_object(object) + size = torch.distributed.get_world_size(group=group) + output_objects = [None for _ in range(size)] + torch.distributed.all_gather_object(output_objects, object, group=group) + return [x for y in output_objects for x in y] + + +# code borrowed from verl +@torch.no_grad() +def load_megatron_model_to_gpu(models, load_grad=True): + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # sometimes, we don't want to load grad for pure inference + if load_grad: + buffer.grad_data.storage().resize_(buffer.grad_data_size) + buffer.grad_data.zero_() + + if buffer.param_data.storage().size() == 0: + buffer.param_data.storage().resize_(buffer.param_data_size) + # copy data from cpu to cuda + buffer.param_data.copy_(buffer.param_data.cpu_data, non_blocking=True) + else: + # we need this for ref module + device_id = get_current_device() + for _, param in model_chunk.named_parameters(): + param.data = param.data.to(device_id, non_blocking=True) + if param.grad is not None: + param.grad = param.grad.to(device_id, non_blocking=True) + gc.collect() + empty_cache() + + +@torch.no_grad() +def offload_megatron_model_to_cpu(models): + """ + In megatron, the model and optimizer storage are: + - bf16 parameter data chunked in model parallel group + - fp32 grad chunked in model parallel group + - fp32 main_parameter chunked in model and dp group + - fp32 optimizer state chunked in model and dp group + """ + for model_chunk in models: + if isinstance(model_chunk, DDP): + model_chunk_all_buffers = [model_chunk.buffers, model_chunk.expert_parallel_buffers] + for buffers in model_chunk_all_buffers: + for buffer in buffers: + # offload parameters + if buffer.param_data.storage().size() > 0: + buffer.param_data.cpu_data = buffer.param_data.data.cpu().pin_memory() + buffer.param_data_size = buffer.param_data.storage().size() + buffer.param_data.storage().resize_(0) + + assert buffer.param_data_size == buffer.param_data.cpu_data.storage().size() + + if buffer.grad_data.storage().size() > 0: + # if the grad_data size is already zero, we assume that it is already offloaded + buffer.grad_data_size = buffer.grad_data.storage().size() + buffer.grad_data.storage().resize_(0) + else: + # we need this for ref module + for _, param in model_chunk.named_parameters(): + param.data = param.data.to('cpu', non_blocking=True) + if param.grad is not None: + param.grad = param.grad.to('cpu', non_blocking=True) + gc.collect() + empty_cache() + + +@torch.no_grad() +def load_megatron_copy_params(optimizers): + """ + Load optimizer parameters back to GPU. Handles ChainedOptimizer. + + Args: + optimizers: Optimizer or ChainedOptimizer instance. 
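+
+    Note:
+        Only tensors in `shard_fp32_from_float16_groups` are moved here; optimizer state
+        (exp_avg / exp_avg_sq) is moved back by `load_megatron_optimizer`.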
+ """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def load_tensor_to_gpu(tensor): + if tensor is None: + return + device_id = get_current_device() + tensor.data = tensor.data.to(device_id, non_blocking=True) + + def load_group_to_gpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + load_tensor_to_gpu(param) + else: + load_tensor_to_gpu(param_group) + else: + load_tensor_to_gpu(group) + + # Load all parameter groups to GPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, 'shard_fp32_from_float16_groups'): + load_group_to_gpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def offload_megatron_copy_params(optimizers): + """ + Offload optimizer parameters to CPU. Supports both Megatron optimizers + and `ChainedOptimizer`, which wraps a list of underlying optimizers. + + Args: + optimizers: The optimizer or ChainedOptimizer instance. + """ + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + def offload_tensor_to_cpu(tensor): + if tensor is None: + return + tensor.data = tensor.data.to('cpu', non_blocking=True) + + def offload_group_to_cpu(group): + if group is None: + return + + if isinstance(group, list): + for param_group in group: + if isinstance(param_group, list): + for param in param_group: + offload_tensor_to_cpu(param) + else: + offload_tensor_to_cpu(param_group) + else: + offload_tensor_to_cpu(group) + + # Offload all parameter groups to CPU for each underlying optimizer + + for _opt in _iter_opts(optimizers): + if hasattr(_opt, 'shard_fp32_from_float16_groups'): + offload_group_to_cpu(_opt.shard_fp32_from_float16_groups) + + +@torch.no_grad() +def load_megatron_optimizer(optimizers): + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + load_megatron_copy_params(_opt) + # if we are using HybridDeviceOptimizer, we need to only move gpu optimizer state to gpu + if hasattr(_opt.optimizer, '_move_new_state_to_right_device'): + _opt.optimizer._move_new_state_to_right_device() + else: + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if 'exp_avg' in v: + v['exp_avg'] = v['exp_avg'].to(get_current_device(), non_blocking=True) + if 'exp_avg_sq' in v: + v['exp_avg_sq'] = v['exp_avg_sq'].to(get_current_device(), non_blocking=True) + gc.collect() + empty_cache() + + +@torch.no_grad() +def offload_megatron_optimizer(optimizers): + + def _iter_opts(opt): + if isinstance(opt, ChainedOptimizer): + return opt.chained_optimizers + return [opt] + + for _opt in _iter_opts(optimizers): + offload_megatron_copy_params(_opt) + opt_state_dict_values = _opt.optimizer.state.values() + for v in opt_state_dict_values: + if 'exp_avg' in v: + v['exp_avg'] = v['exp_avg'].to('cpu', non_blocking=True) + if 'exp_avg_sq' in v: + v['exp_avg_sq'] = v['exp_avg_sq'].to('cpu', non_blocking=True) + gc.collect() + empty_cache() + + +def log_gpu_memory(prefix: str = ''): + logger = get_logger() + + logger.info(f'{prefix} GPU memory: {torch.cuda.memory_allocated()/1024**3:.2f}GB allocated, ' + f'{torch.cuda.memory_reserved()/1024**3:.2f}GB reserved') + + +def should_filter_lora_parameter(name: str) -> bool: + if 'lora_' in name: + return True + + if 'original_module' in name: + return True + 
return False + + +def patch_model_for_lora_export(model): + original_named_parameters = model.named_parameters + original_state_dict = model.state_dict + + def filtered_named_parameters(*args, **kwargs): + for name, param in original_named_parameters(*args, **kwargs): + if not should_filter_lora_parameter(name): + yield name, param + + def filtered_state_dict(*args, **kwargs): + state_dict = original_state_dict(*args, **kwargs) + filtered = {} + for name, param in state_dict.items(): + if not should_filter_lora_parameter(name): + filtered[name] = param + return filtered + + model.named_parameters = filtered_named_parameters + model.state_dict = filtered_state_dict + + def restore(): + model.named_parameters = original_named_parameters + model.state_dict = original_state_dict + + return restore diff --git a/swift/megatron/tuners/lora.py b/swift/megatron/tuners/lora.py index f9ad78ef50..1ef284a3d7 100644 --- a/swift/megatron/tuners/lora.py +++ b/swift/megatron/tuners/lora.py @@ -422,6 +422,40 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = N if origin_device.type == 'cpu': self.to(device=origin_device) + def unmerge(self) -> None: + """ + Unmerge all merged adapter weights from the base weights. + + This method reverses the merge operation by subtracting the LoRA delta weights + from the base layer weights, restoring the original base weights. + """ + if not self.merged: + # No adapters to unmerge + return + + base_layer = self.get_base_layer() + origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device + if origin_device.type == 'cpu': + self.to(device=get_current_device()) + + for active_adapter in self.merged_adapters: + if active_adapter in self.lora_A.keys(): + if self.is_grouped: + orig_weights = [getattr(base_layer, f'weight{i}') for i in range(base_layer.num_gemms)] + else: + orig_weights = [base_layer.weight] + + delta_weights = self.get_delta_weights(active_adapter) + for orig_weight, delta_weight in zip(orig_weights, delta_weights): + # Subtract the delta weight to unmerge + orig_weight.data -= delta_weight + + # Clear the merged adapters list + self.merged_adapters = [] + + if origin_device.type == 'cpu': + self.to(device=origin_device) + def dispatch_megatron( target: torch.nn.Module, diff --git a/swift/trainers/rlhf_trainer/__init__.py b/swift/trainers/rlhf_trainer/__init__.py index 8830dbac20..829dba091b 100644 --- a/swift/trainers/rlhf_trainer/__init__.py +++ b/swift/trainers/rlhf_trainer/__init__.py @@ -14,6 +14,7 @@ from .gkd_trainer import GKDTrainer from .rlhf_mixin import RLHFTrainerMixin from .utils import patch_lora_merge, patch_lora_unmerge, round_robin, _ForwardRedirection + from .vllm_client import VLLMClient else: _import_structure = { 'cpo_trainer': ['CPOTrainer'], @@ -26,6 +27,7 @@ 'gkd_trainer': ['GKDTrainer'], 'rlhf_mixin': ['RLHFTrainerMixin'], 'utils': ['patch_lora_merge', 'patch_lora_unmerge', 'round_robin', '_ForwardRedirection'], + 'vllm_client': ['VLLMClient'], } import sys