Commit 4ddb7fa

[grpo] pass trainer state to reward funcs(#4779)
1 parent 3356b67 commit 4ddb7fa

4 files changed: +36 −15 lines changed


docs/source/Instruction/GRPO/DeveloperGuide/奖励函数.md

Lines changed: 13 additions & 5 deletions

````diff
@@ -1,6 +1,10 @@
 # Reward Function
 ## Custom Reward Function
-The reward function takes the model-generated completions and other dataset columns as arguments (kwargs) and scores the generated text. Below is an example showing how to implement a simple length reward function: it gives a reward of 1.0 when the generated text is longer than 1024, and 0.0 otherwise.
+The reward function takes the model-generated completions, other dataset columns, and the trainer state as arguments (kwargs) and scores the generated text. The [trainer state](https://huggingface.co/docs/transformers/main/main_classes/callback#transformers.TrainerState) contains information such as the current training step.
+
+Note: the columns related to the model input (such as query and response) are folded into the messages key, and the original assistant response in the dataset is discarded; keep it in an extra column if you need it.
+
+Below is an example showing how to implement a simple length reward function: it gives a reward of 1.0 when the generated text is longer than 1024, and 0.0 otherwise.

 ```python
 from swift.plugin import ORM, orms
@@ -13,23 +17,27 @@ orms['dummy']= DummyLengthRewardFunction

 **Accessing other columns in the dataset**

-For example, if the reward function needs the dataset's `solution` column for auxiliary computation, there are two ways to obtain it:
+For example, if the reward function needs the dataset's `solution` column, the current training step, and the total number of steps for auxiliary computation, there are two ways to obtain them:

-First: declare the solution column name explicitly in the __call__ signature
+First: declare the column names explicitly in the __call__ signature
 ```python
-def __call__(completions, solution, **kwargs):
+def __call__(completions, solution, trainer_state, **kwargs):
     print(solution)
+    global_step = trainer_state.global_step
+    max_steps = trainer_state.max_steps
     ...
 ```

 Second: read them from kwargs
 ```python
 def __call__(completions, **kwargs):
     solution = kwargs.get('solution')
+    trainer_state = kwargs.get('trainer_state')
+    global_step = trainer_state.global_step
+    max_steps = trainer_state.max_steps
     ...
 ```

-Note: columns related to messages (such as query and response) are processed, and the original assistant response in the dataset is discarded; keep it in an extra column if you need it.

 **Using a custom reward function**
````
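
The hunk above omits the body of `DummyLengthRewardFunction`. A minimal sketch of what the documented behaviour implies (reward 1.0 when the completion is longer than 1024 characters, 0.0 otherwise) is shown below; the class body and the optional read of the injected `trainer_state` are reconstructions for illustration, not the exact code in the repository.

```python
from typing import List

from swift.plugin import ORM, orms


class DummyLengthRewardFunction(ORM):
    """Reward 1.0 when a completion is longer than 1024 characters, else 0.0."""

    def __call__(self, completions, **kwargs) -> List[float]:
        # trainer_state is injected by the GRPO trainer as of this commit; it is not
        # needed for this dummy reward but is read here to show the access pattern.
        trainer_state = kwargs.get('trainer_state')
        if trainer_state is not None:
            _ = trainer_state.global_step  # e.g. for step-dependent schedules
        return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]


orms['dummy'] = DummyLengthRewardFunction
```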

docs/source_en/Instruction/GRPO/DeveloperGuide/reward_function.md

Lines changed: 13 additions & 6 deletions

````diff
@@ -1,6 +1,10 @@
 # Reward Function
 ## Custom Reward Function
-The reward function takes the model-generated text `completions` and other columns from the dataset as parameters (`kwargs`) and scores the model-generated text. Below is an example demonstrating how to implement a simple length-based reward function. This function assigns a reward signal of 1.0 if the length of the model-generated text exceeds 1024; otherwise, the reward signal is 0.0.
+The reward function takes as arguments (via kwargs) the model-generated completions, other columns from the dataset, and the training state, and calculates a reward score. The [trainer state](https://huggingface.co/docs/transformers/main/main_classes/callback#transformers.TrainerState) includes information such as the current training step.
+
+Note: The columns related to model input (such as query and response) are converted to the messages key. The original assistant response in the dataset will be discarded, so please use extra columns if you wish to retain it.
+
+Below is an example illustrating how to implement a simple length-based reward function. This function assigns a reward of 1.0 if the length of the generated completion exceeds 1024, and 0.0 otherwise.

 ```python
 from swift.plugin import ORM, orms
@@ -12,25 +16,28 @@ orms['dummy']= DummyLengthRewardFunction
 ```

 **Accessing Other Columns in the Dataset**
+For example, if the reward function needs to access the solution column from the dataset, as well as the current training step and the total number of steps for calculation, there are two ways to retrieve these values:

-For example, if the reward function needs to access the solution column from the dataset for auxiliary calculations, here are two ways to achieve this:

-Explicitly define the solution column name in the __call__ parameters:
+Explicitly define the column name in the __call__ parameters:
 ```python
-def __call__(completions, solution, **kwargs):
+def __call__(completions, solution, trainer_state, **kwargs):
     print(solution)
+    global_step = trainer_state.global_step
+    max_steps = trainer_state.max_steps
     ...
 ```

 Retrieve it from kwargs:
 ```python
 def __call__(completions, **kwargs):
     solution = kwargs.get('solution')
+    trainer_state = kwargs.get('trainer_state')
+    global_step = trainer_state.global_step
+    max_steps = trainer_state.max_steps
     ...
 ```

-Note: Columns related to messages (e.g., query, response) will be processed, and the original assistant responses in the dataset will be discarded. Use additional columns to retain such information.
-
 **Using Custom Reward Functions**

 You can add the reward function in [plugin program](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/plugin/plugin.py), register it using the parameter `--external_plugins examples/train/grpo/plugin/plugin.py`, and specify it via the `reward_funcs` parameter.
````
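
To make the new `trainer_state` kwarg concrete, here is a hedged, self-contained sketch of a reward function that scales its score with training progress. The class name `StepAwareAccuracyReward`, the registry key `step_aware_accuracy`, the exact-match check against a `solution` column, and the 0.5-to-1.0 annealing schedule are illustrative assumptions, not part of this commit.

```python
from typing import List

from swift.plugin import ORM, orms


class StepAwareAccuracyReward(ORM):
    """Illustrative reward that weights an exact-match check by training progress."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        # The GRPO trainer now forwards its TrainerState under 'trainer_state'.
        trainer_state = kwargs.get('trainer_state')
        # Fraction of training completed; max() guards against division by zero.
        progress = trainer_state.global_step / max(trainer_state.max_steps, 1)
        rewards = []
        for completion, sol in zip(completions, solution):
            correct = 1.0 if sol.strip() and sol.strip() in completion else 0.0
            # Ramp the reward weight from 0.5 to 1.0 over the course of training.
            rewards.append(correct * (0.5 + 0.5 * progress))
        return rewards


orms['step_aware_accuracy'] = StepAwareAccuracyReward
```

If placed in the plugin file referenced above, such a function could then be selected with `--external_plugins examples/train/grpo/plugin/plugin.py --reward_funcs step_aware_accuracy`, following the documentation's own registration example.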

examples/train/grpo/plugin/plugin.py

Lines changed: 9 additions & 3 deletions

````diff
@@ -462,7 +462,9 @@ def __init__(self):
         self.format_max_possible = 1.0
         self.format_min_possible = 0.0

-    def __call__(self, completions, solution, global_step, **kwargs) -> List[float]:
+    def __call__(self, completions, solution, **kwargs) -> List[float]:
+        trainer_state = kwargs.get('trainer_state')
+        global_step = trainer_state.global_step
         max_possible_reward = self.format_max_possible
         min_possible_reward = self.format_min_possible
         # Two stage (Coarse) Setting, divide training into two phases. Format Reward in [0,0.5] if step < 30 else [0,1]
@@ -521,9 +523,11 @@ def __init__(self):
         self.length_min_possible = 0.0

     # customized reward functions: length
-    def __call__(self, completions, solution, global_step, **kwargs):
+    def __call__(self, completions, solution, **kwargs):
         max_possible_reward = self.length_max_possible
         min_possible_reward = self.length_min_possible
+        trainer_state = kwargs.get('trainer_state')
+        global_step = trainer_state.global_step
         # SCHEDULELENGTH: enable Dynamic Length Reward
         if os.getenv('SCHEDULELENGTH', 0) == '1':
             max_reward_len = (640 - 384) * global_step / 105 + 384
@@ -639,7 +643,9 @@ def compute_tool_call_reward(self, gt_tools, pd_tools, max_possible_reward, min_
         return (max_possible_reward - min_possible_reward) * score / local_max_possible + min_possible_reward

     # custoimzed reward functions: tool call correctness
-    def __call__(self, completions, solution, global_step, **kwargs):
+    def __call__(self, completions, solution, **kwargs):
+        trainer_state = kwargs.get('trainer_state')
+        global_step = trainer_state.global_step
         max_possible_reward = self.tool_max_possible
         min_possible_reward = self.tool_min_possible
         # two stage (Coarse) Setting, divide training into two phases.
````
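
All three hunks above apply the same mechanical change: drop the explicit `global_step` parameter and recover it from `kwargs['trainer_state']`. For third-party plugins that may run against both older and newer swift versions, a hedged compatibility sketch is shown below; the fallback helper, the placeholder format check, and the two-stage threshold of 30 steps (mirroring the comment in the first hunk) are assumptions for illustration, not code from this commit.

```python
from typing import List


def _resolve_global_step(kwargs: dict) -> int:
    """Return the current training step from whichever kwarg the trainer provides.

    Newer swift versions pass a full TrainerState under 'trainer_state' (this commit);
    older versions passed an integer 'global_step' directly.
    """
    trainer_state = kwargs.get('trainer_state')
    if trainer_state is not None:
        return trainer_state.global_step
    return kwargs.get('global_step', 0)


class TwoStageFormatRewardSketch:
    """Illustrative reward whose ceiling rises from 0.5 to 1.0 after step 30."""

    def __call__(self, completions, solution, **kwargs) -> List[float]:
        global_step = _resolve_global_step(kwargs)
        max_possible_reward = 0.5 if global_step < 30 else 1.0
        # Placeholder format check; the real plugin inspects the completion format in detail.
        return [max_possible_reward if c.strip() else 0.0 for c in completions]
```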

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 1 addition & 1 deletion

````diff
@@ -907,7 +907,7 @@ def _score_completions(self, inputs: InputsType) -> Tuple[torch.Tensor, torch.Te
         else:
             # Repeat all input columns (but "messages" and "completion") to match the number of generations
             reward_kwargs = RowPreprocessor.rows_to_batched(inputs)
-            reward_kwargs['global_step'] = self.state.global_step
+            reward_kwargs['trainer_state'] = self.state
             output_reward_func = reward_func(completions, **reward_kwargs)
             output_reward_func = [reward if reward is not None else torch.nan for reward in output_reward_func]
             rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
````
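
The trainer-side change is a one-liner: the whole `TrainerState` object is placed into the reward kwargs instead of just its `global_step` field. A toy, standalone sketch of that dispatch (using a stand-in dataclass rather than the real `transformers.TrainerState`, and a hypothetical `dummy_reward`) illustrates what reward functions now receive:

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class FakeTrainerState:
    """Stand-in for transformers.TrainerState with the fields reward funcs typically read."""
    global_step: int
    max_steps: int


def score_completions(reward_func: Callable[..., List[float]], completions: List[str],
                      reward_kwargs: Dict, state: FakeTrainerState) -> List[float]:
    # Mirrors the changed line: forward the whole state, not just state.global_step.
    reward_kwargs['trainer_state'] = state
    return reward_func(completions, **reward_kwargs)


def dummy_reward(completions, **kwargs):
    state = kwargs['trainer_state']
    scale = state.global_step / max(state.max_steps, 1)
    return [scale * min(len(c) / 1024, 1.0) for c in completions]


print(score_completions(dummy_reward, ['hello world'], {}, FakeTrainerState(10, 100)))
# -> [0.1 * 11/1024], roughly [0.00107]
```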
