
Commit e8c7b2c

[grpo] support None reward & multi-task doc & more profiling (#4459)
* wip: doc todo
* doc
* doc polish
* clean
* doc
1 parent f9e72aa commit e8c7b2c

File tree

- docs/source/Instruction/GRPO.md
- docs/source_en/Instruction/GRPO.md
- swift/trainers/rlhf_trainer/grpo_trainer.py

3 files changed: +167 -22 lines changed

docs/source/Instruction/GRPO.md

Lines changed: 64 additions & 1 deletion
@@ -287,6 +287,55 @@ swift rlhf \
1. In GRPOTrainer, reward_model instances are appended to reward_funcs in order, so the order of reward_weights corresponds to [reward_funcs, reward_model].
2. reward_model_plugin defaults to default, which uses the ORM processing logic.

## Multi-task Training

We can add a column to the dataset that identifies the task type and branch on it inside the reward function / reward model plugin, which enables multi-task training. Suppose the dataset contains math and coding tasks, for example:

```
{"query": "Solve the equation x + 2 = 5", "solution": "3", "task": "math"},
{"query": "Write a function to calculate the Fibonacci sequence", "solution": "xxx", "task": "code"},
{"query": "What is the integral of x^2?", "solution": "xxx", "task": "math"},
{"query": "Implement a sorting algorithm in Python", "solution": "xxx", "task": "code"},
```

Below is an example of task-specific reward functions:

```python
import random

from swift.plugin import ORM, orms


# Math-specific reward function
class MathRandomReward(ORM):

    def __call__(self, completions, task, **kwargs):
        rewards = []
        for completion, t in zip(completions, task):
            if t == "math":
                # implement the math accuracy logic here; random.random() is a placeholder
                rewards.append(random.random())
            else:
                # return None for non-math tasks
                rewards.append(None)
        return rewards


# Code-specific reward function
class CodeRandomReward(ORM):

    def __call__(self, completions, task, **kwargs):
        rewards = []
        for completion, t in zip(completions, task):
            if t == "code":
                # implement the code accuracy logic here; random.random() is a placeholder
                rewards.append(random.random())
            else:
                # return None for non-code tasks
                rewards.append(None)
        return rewards


orms['math_reward'] = MathRandomReward
orms['code_reward'] = CodeRandomReward
```

For data that does not belong to the current task, the reward function returns None, so reward statistics are computed only over the samples that belong to the task.
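In this commit, the trainer converts a returned None into NaN before aggregation and then uses NaN-aware reductions, so untouched samples simply drop out of the statistics. A minimal sketch of that behavior (simplified from `_score_completions`; the tensors and weights below are illustrative):

```python
import torch

# Rewards from two reward functions over four samples; None means "not my task".
raw = [
    [0.7, None],   # math sample: only math_reward scored it
    [None, 0.2],   # code sample: only code_reward scored it
    [0.9, None],
    [None, 0.4],
]
rewards_per_func = torch.tensor(
    [[torch.nan if r is None else r for r in row] for row in raw], dtype=torch.float32)

reward_weights = torch.tensor([1.0, 1.0])
# nansum skips the NaN entries, so each sample is scored only by its own task's reward.
total_rewards = (rewards_per_func * reward_weights.unsqueeze(0)).nansum(dim=1)
# Per-function logging likewise ignores samples the function did not score.
mean_math_reward = torch.nanmean(rewards_per_func[:, 0])
```

If every reward function returns None for the same sample, the trainer logs a warning asking that at least one reward function return a valid reward.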

## DAPO

[Decoupled Clip and Dynamic sAmpling Policy Optimization (DAPO)](https://arxiv.org/abs/2503.14476) adds several tricks on top of GRPO, namely:
@@ -363,7 +412,21 @@ num_generations = 64

**5. Why is clip_ratio always 1?**

With num_iterations = 1 and async_generate = False, training is on-policy RL, and old_policy is identical to policy.

The core purpose of the clip mechanism is to limit the magnitude of policy updates and prevent a single overly large update from collapsing policy performance (i.e., a sharp drop in performance right after the update).

The clip operation is defined as follows:

$$
L_{\text{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min\left(r_{t}(\theta) \hat{A}_{t}, \text{clip}(r_{t}(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_{t} \right) \right]
$$

where $r_{t}(\theta) = \frac{\pi_{\theta}(a_{t} \mid s_{t})}{\pi_{\text{old}}(a_{t} \mid s_{t})}$ is the importance sampling ratio, which measures the difference between the new and old policies; $\hat{A}_{t}$ is the advantage function, representing the relative benefit of an action; and $\epsilon$ limits how far $r_{t}(\theta)$ may deviate from 1.

During on-policy training, every update uses data generated by the latest policy, so the new and old policies coincide, i.e., $\pi_{\theta} = \pi_{\text{old}}$.

The importance sampling ratio is therefore always 1 and the clip operation never takes effect (a numeric sketch follows the list below).

The algorithm becomes off-policy (near-on-policy) under the following settings:
1. num_iterations > 1
2. steps_per_generation > gradient_accumulation_steps
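To make the on-policy case above concrete, here is a minimal numeric sketch of the clipped surrogate (illustrative only, not the trainer's actual loss code), showing that when the new and old log-probs coincide the ratio is exactly 1 and clamping changes nothing:

```python
import torch

# Toy per-token log-probs under the current and old (behaviour) policies.
logp_new = torch.tensor([-1.20, -0.45, -2.10])
logp_old = logp_new.clone()          # on-policy: the same policy generated the data
advantages = torch.tensor([0.5, -1.0, 2.0])
epsilon = 0.2

ratio = torch.exp(logp_new - logp_old)                       # r_t(theta), all ones here
clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

# ratio == clipped_ratio == 1 everywhere, so the clipped and unclipped surrogates
# coincide and clipping never activates. Once the policy is updated more than once
# per generation batch (num_iterations > 1), logp_new drifts from logp_old and the
# ratio moves away from 1.
```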

See reference: [issue](https://github.com/huggingface/open-r1/issues/239#issuecomment-2646297851)

docs/source_en/Instruction/GRPO.md

Lines changed: 69 additions & 1 deletion
@@ -301,6 +301,60 @@ Notes:
1. In the GRPOTrainer, reward_model instances are appended sequentially to reward_funcs. Therefore, the order of reward_weights corresponds to [reward_funcs, reward_model].
2. The default value for reward_model_plugin is default, which uses the ORM processing logic.

## Multi-task training

We can add a column to the dataset that identifies the task type and branch on it in the reward function / reward model plugin, thereby enabling multi-task training. Suppose the dataset contains math and programming tasks, such as:

```
{"query": "Solve the equation x + 2 = 5", "solution": "3", "task": "math"},
{"query": "Write a function to calculate the Fibonacci sequence", "solution": "xxx", "task": "code"},
{"query": "What is the integral of x^2?", "solution": "xxx", "task": "math"},
{"query": "Implement a sorting algorithm in Python", "solution": "xxx", "task": "code"},
```

Below are examples of reward functions for the different tasks:

```python
import random

from swift.plugin import ORM, orms


# Math-specific reward function
class MathRandomReward(ORM):

    def __call__(self, completions, task, **kwargs):
        rewards = []
        for completion, t in zip(completions, task):
            if t == "math":
                # implement the math accuracy logic here; random.random() is a placeholder
                rewards.append(random.random())
            else:
                # return None for non-math tasks
                rewards.append(None)
        return rewards


# Coding-specific reward function
class CodeRandomReward(ORM):

    def __call__(self, completions, task, **kwargs):
        rewards = []
        for completion, t in zip(completions, task):
            if t == "code":
                # implement the coding accuracy logic here; random.random() is a placeholder
                rewards.append(random.random())
            else:
                # return None for non-coding tasks
                rewards.append(None)
        return rewards


orms['math_reward'] = MathRandomReward
orms['code_reward'] = CodeRandomReward
```

For data that does not belong to the current task, the reward function returns None, ensuring that reward statistics are computed only over the data within that task.
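Extra dataset columns such as `task` reach the reward function as keyword arguments, batched into one list entry per sampled completion. A small illustrative call of the functions above (the returned numbers are placeholders produced by `random.random()`):

```python
completions = ["x = 3", "def fib(n): ...", "x**3 / 3 + C", "def sort(a): ..."]
tasks = ["math", "code", "math", "code"]

math_rewards = MathRandomReward()(completions, task=tasks)
code_rewards = CodeRandomReward()(completions, task=tasks)

# Each function only scores its own task and returns None elsewhere, e.g.:
# math_rewards -> [0.42, None, 0.87, None]
# code_rewards -> [None, 0.13, None, 0.55]
```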

## DAPO

[Decoupled Clip and Dynamic Sampling Policy Optimization (DAPO)](https://arxiv.org/abs/2503.14476) introduces several tricks on top of GRPO, namely:
@@ -380,7 +434,21 @@ See reference: [issue](https://github.com/modelscope/ms-swift/issues/3912)

**5. Why is clip_ratio always 1?**

When num_iterations = 1 and async_generate = False, training is on-policy RL, and old_policy is equal to policy.

The core purpose of the clip mechanism is to limit the magnitude of policy updates, preventing a single overly large update from collapsing policy performance (i.e., a sudden drop in performance after the policy is updated). The clip operation is defined as follows:

$$
L_{\text{CLIP}}(\theta) = \mathbb{E}_{t} \left[ \min\left(r_{t}(\theta) \hat{A}_{t}, \text{clip}(r_{t}(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_{t} \right) \right]
$$

where $r_{t}(\theta) = \frac{\pi_{\theta}(a_{t} \mid s_{t})}{\pi_{\text{old}}(a_{t} \mid s_{t})}$ is the importance sampling ratio, measuring the difference between the new and old policies, $\hat{A}_{t}$ is the advantage function, representing the relative benefit of an action, and $\epsilon$ limits how far $r_{t}(\theta)$ may deviate from 1.

During on-policy training, every update uses data generated by the latest policy, so the new and old policies coincide, i.e., $\pi_{\theta} = \pi_{\text{old}}$.

The importance sampling ratio is therefore always equal to 1, and the clip operation never takes effect (a one-line worked substitution follows the list below).

The algorithm becomes off-policy (near-on-policy) under the following settings:

1. num_iterations > 1
2. steps_per_generation > gradient_accumulation_steps
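Returning to the on-policy case: substituting $r_{t}(\theta) = 1$ into the clip objective above shows directly why clipping is inert there,

$$
L_{\text{CLIP}}(\theta)\Big|_{r_{t}=1} = \mathbb{E}_{t}\left[\min\left(\hat{A}_{t},\ \text{clip}(1,\, 1-\epsilon,\, 1+\epsilon)\,\hat{A}_{t}\right)\right] = \mathbb{E}_{t}\left[\hat{A}_{t}\right],
$$

so the clipped and unclipped terms are identical and the clip boundary is never reached.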

See reference: [issue](https://github.com/huggingface/open-r1/issues/239#issuecomment-2646297851)

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 34 additions & 20 deletions
@@ -28,10 +28,10 @@
 from transformers import PreTrainedModel, TrainerCallback
 from transformers.trainer import Trainer
 from trl import GRPOTrainer as HFGRPOTrainer
-from trl.extras.profiling import profiling_decorator
+from trl.extras.profiling import profiling_context, profiling_decorator
 from trl.models import prepare_deepspeed
 from trl.trainer.callbacks import SyncRefModelCallback
-from trl.trainer.grpo_trainer import nanmax, nanmin
+from trl.trainer.grpo_trainer import nanmax, nanmin, nanstd

 from swift.llm import InferRequest, MultiModelKeys, RequestConfig, RowPreprocessor, get_model_arch, to_device
 from swift.llm.model.utils import get_llm_model
@@ -873,19 +873,30 @@ def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Te
         completions = [example['messages'][-1]['content'] for example in inputs]
         rewards_per_func = torch.zeros((len(inputs), len(self.reward_funcs)), device=device)

-        for i, (reward_func, reward_model_plugin) in enumerate(zip(self.reward_funcs, self.reward_model_plugins)):
-            # reward model
-            if isinstance(reward_func, nn.Module):
-                rewards_per_func[:, i] = reward_model_plugin(inputs=inputs)
-            # reward function
-            else:
-                # Repeat all input columns (except "messages" and "completion") to match the number of generations
-                reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
-                output_reward_func = reward_func(completions, **reward_kwargs)
+        for i, (reward_func, reward_model_plugin, reward_func_name) in enumerate(
+                zip(self.reward_funcs, self.reward_model_plugins, self.reward_func_names)):
+            with profiling_context(self, reward_func_name):
+                # reward model
+                if isinstance(reward_func, nn.Module):
+                    output_reward_func = reward_model_plugin(inputs=inputs)
+                # reward function
+                else:
+                    # Repeat all input columns (except "messages" and "completion") to match the number of generations
+                    reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
+                    output_reward_func = reward_func(completions, **reward_kwargs)
+                output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
             rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

+        # If all reward functions return None for a given row, issue a detailed warning
+        if torch.isnan(rewards_per_func).all(dim=1).any():
+            nan_row_idx = torch.isnan(rewards_per_func).all(dim=1).nonzero(as_tuple=True)[0][0]
+            row_reward_kwargs = {key: value[nan_row_idx] for key, value in reward_kwargs.items()}
+            row_reward_kwargs['completion'] = completions[nan_row_idx]
+            logger.warning(f'All reward functions returned None for the following kwargs: {row_reward_kwargs}. '
+                           'Please ensure that at least one reward function returns a valid reward.')
+
         total_rewards_per_func = gather(rewards_per_func)
-        total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
+        total_rewards = (total_rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1)

         return total_rewards_per_func, total_rewards, completions

@@ -1027,10 +1038,11 @@ def _log_metrics(self, inputs, messages, completions, rewards, rewards_per_func)

         self._metrics[mode]['completions/clipped_ratio'].append(clipped_completions_ratio)

+        # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
-            mean_rewards = rewards_per_func[:, i].mean().item()
+            mean_rewards = torch.nanmean(rewards_per_func[:, i]).item()
             self._metrics[mode][f'rewards/{reward_func_name}/mean'].append(mean_rewards)
-            std_rewards = rewards_per_func[:, i].std().item()
+            std_rewards = nanstd(rewards_per_func[:, i]).item()
             self._metrics[mode][f'rewards/{reward_func_name}/std'].append(std_rewards)

         # Log overall reward stats
@@ -1071,7 +1083,8 @@ def _compute_loss(self, model, inputs):
         # apply the completion_mask to exclude loss and metrics for overlong completions
         if self.args.overlong_filter and any(truncated_mask):
             if all(truncated_mask):
-                logger.info('All completions are overlong, loss and KL will be zero')
+                logger.info('All completions are overlong and truncated, '
+                            'resulting in NaN values for some metrics (e.g., KL)')
             truncated_mask = truncated_mask.unsqueeze(-1).expand_as(completion_mask).to(completion_mask.device)
             completion_mask = completion_mask * (~truncated_mask)

@@ -1341,11 +1354,12 @@ def _engine_infer(
         *,
         use_tqdm: Optional[bool] = False,
     ):
-        if self.vllm_mode == 'server':
-            self._process_infer_requests_images(infer_requests)
-            return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm)
-        else:
-            return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)
+        with profiling_context(self, 'generate'):
+            if self.vllm_mode == 'server':
+                self._process_infer_requests_images(infer_requests)
+                return self.vllm_client.infer(infer_requests, asdict(request_config), use_tqdm=use_tqdm)
+            else:
+                return self.engine.infer(infer_requests, request_config, use_tqdm=use_tqdm)

     def _process_infer_requests_images(self, infer_requests: List[InferRequest]):
         # Process image format into a format that session.post can accept
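Both new profiling scopes use `profiling_context` from `trl.extras.profiling`, which this commit wraps around reward computation and generation. As a rough mental model only (a simplified stand-in, not trl's actual implementation), it behaves like a timing context manager that reports how long the wrapped block took:

```python
import time
from contextlib import contextmanager


@contextmanager
def profiling_context(trainer, name):
    # Time the wrapped block and report the duration under the given name.
    start = time.perf_counter()
    yield
    duration = time.perf_counter() - start
    # Illustrative only: the real helper records the timing into trl's profiling metrics.
    print(f'profiling/{name}: {duration:.4f}s')
```

With `profiling_context(self, reward_func_name)` around each reward function and `profiling_context(self, 'generate')` around inference, the time spent in each reward function and in generation can be tracked separately.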
