
Commit 92c4f57

[RM] support margin & update doc (#4817)
* reward modeling document
* support margin
* add margin to standard keys
* pop margin
* margin wip
* rm_encode
* data collator
* convert margin to float
* fix template inputs from dict
* convert to tensor in data_collator
* revert rm mode
* fix judge
* doc
* padding_free & liger check
1 parent f9afe2d commit 92c4f57

File tree: 9 files changed (+59, -6 lines)


docs/source/Customization/自定义数据集.md

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,8 @@ alpaca format:
 {"messages": [{"role": "system", "content": "You are a useful and harmless math calculator"}, {"role": "user", "content": "What is 1 + 1?"}, {"role": "assistant", "content": "It equals 2"}, {"role": "user", "content": "What about adding 1?"}, {"role": "assistant", "content": "It equals 3"}], "rejected_response": "I don't know"}
 ```
 
+> Note: RM additionally supports a margin column; see the [RM documentation](../Instruction/人类对齐.md#rm).
+
 #### KTO
 
 ```jsonl

docs/source/Instruction/人类对齐.md

Lines changed: 12 additions & 0 deletions
@@ -39,6 +39,18 @@ The Reward Modeling stage of RLHF
 
 The weights of the added value head will be saved in the `value_head.safetensors` or `value_head.bin` file.
 
+The RM loss function is as follows:
+
+$
+\text{loss} = -\log \sigma \left( r^{(c)} - r^{(r)} - m \right) + \lambda \left( r^{(c)} + r^{(r)} \right)^2
+$
+
+- $r^{(c)}$: the model's score for the chosen response
+- $r^{(r)}$: the model's score for the rejected response
+- $\lambda$: coefficient of the L2 regularization term, which encourages the model outputs to stay close to 0; set with the `center_rewards_coefficient` parameter, from [the paper](https://arxiv.org/pdf/2307.09288); defaults to 0
+- $m$: margin term, which encourages the model to separate samples of different difficulty; requires a `margin` column in the dataset; defaults to 0; from [the paper](https://arxiv.org/pdf/2307.09288)
+
 See the training script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/rm.sh).
 
 ## PPO

docs/source_en/Customization/Custom-dataset.md

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@ The following outlines the standard dataset format for ms-swift, where the "system"
 {"messages": [{"role": "system", "content": "You are a useful and harmless math calculator"}, {"role": "user", "content": "What is 1 + 1?"}, {"role": "assistant", "content": "It equals 2"}, {"role": "user", "content": "What about adding 1?"}, {"role": "assistant", "content": "It equals 3"}], "rejected_response": "I don't know"}
 ```
 
+> Note: RM additionally supports the margin column. For details, refer to the [RM documentation](../Instruction/RLHF.md#rm).
+
 #### KTO
 
 ```jsonl
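
For context on the new column: a minimal sketch of what an RM sample carrying the optional `margin` value could look like, written as Python that appends one JSONL line. The field names follow the standard format above; the concrete text, file name, and margin value are made up for illustration.

```python
import json

# Hypothetical RM sample: the chosen answer is the last assistant message,
# the rejected answer goes in "rejected_response", and "margin" is an
# optional float consumed by the RM loss (omit the key if unavailable).
row = {
    "messages": [
        {"role": "user", "content": "What is 1 + 1?"},
        {"role": "assistant", "content": "It equals 2"},
    ],
    "rejected_response": "I don't know",
    "margin": 0.5,  # made-up value
}

with open("rm_train.jsonl", "a", encoding="utf-8") as f:
    f.write(json.dumps(row, ensure_ascii=False) + "\n")
```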

docs/source_en/Instruction/RLHF.md

Lines changed: 12 additions & 0 deletions
@@ -38,6 +38,18 @@ Use the base model or instruct model trained with SFT as the foundation model. A
 
 The weights of the added value head will be saved in `value_head.safetensors` or `value_head.bin`.
 
+The loss function for reward modeling is as follows:
+
+$
+\text{loss} = -\log \sigma \left( r^{(c)} - r^{(r)} - m \right) + \lambda \left( r^{(c)} + r^{(r)} \right)^2
+$
+
+- $r^{(c)}$: The score assigned by the model to the chosen response.
+- $r^{(r)}$: The score assigned by the model to the rejected response.
+- $\lambda$: L2 regularization coefficient that encourages the model outputs to be close to zero. It is set by the parameter `center_rewards_coefficient`, as described in [the paper](https://arxiv.org/pdf/2307.09288), and defaults to 0.
+- $m$: Margin term that encourages the model to distinguish between samples of different difficulty levels. The dataset needs to provide a `margin` column for this; by default, it is 0. This term is also introduced in [the paper](https://arxiv.org/pdf/2307.09288).
+
 Reference the training script [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/rm.sh).
 
 ## PPO
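
To make the formula above concrete, here is a minimal PyTorch sketch of the pairwise RM loss with an optional margin and the centering term. The function name, shapes, and example values are illustrative assumptions, not the trainer's exact implementation.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def rm_loss(rewards_chosen: torch.Tensor,
            rewards_rejected: torch.Tensor,
            margin: Optional[torch.Tensor] = None,
            center_rewards_coefficient: Optional[float] = None) -> torch.Tensor:
    # -log sigmoid(r_c - r_r - m), averaged over the batch
    diff = rewards_chosen - rewards_rejected
    if margin is not None:
        diff = diff - margin
    loss = -F.logsigmoid(diff).mean()
    # optional centering term: lambda * (r_c + r_r)^2
    if center_rewards_coefficient is not None:
        loss = loss + center_rewards_coefficient * (rewards_chosen + rewards_rejected).pow(2).mean()
    return loss


# Made-up reward scores for a batch of two preference pairs:
r_c = torch.tensor([1.2, 0.3])
r_r = torch.tensor([0.4, 0.1])
print(rm_loss(r_c, r_r, margin=torch.tensor([0.5, 0.0]), center_rewards_coefficient=0.01))
```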

swift/llm/argument/rlhf_args.py

Lines changed: 11 additions & 1 deletion
@@ -126,6 +126,7 @@ def __post_init__(self):
         self._set_default()
         self._init_external_vllm()
         super().__post_init__()
+        self._check_padding_free()
         self._check_grpo()
         self._external_vllm_warning()
 

@@ -261,10 +262,12 @@ def _check_grpo(self):
                 raise ValueError('Liger loss does not support two-sided GRPO loss yet.')
             if self.sequence_parallel_size > 1:
                 raise ValueError('Liger loss does not support sequence parallel yet.')
+            if self.padding_free:
+                raise ValueError('Liger loss does not support padding free yet.')
+
             from trl.import_utils import is_liger_kernel_available
             assert is_liger_kernel_available(), (
                 'Please install/update liger-kernel by running: pip install -U liger-kernel')
-
         if self.vllm_mode == 'server':
             assert not self.use_vllm or self.vllm_server_host is not None
 

@@ -333,3 +336,10 @@ def _deprecated_warning(self):
         if self.gc_collect_after_offload:
             logger.warning(
                 "The parameter 'gc_collect_after_offload' has been deprecated and will be removed in version 3.7. ")
+
+    def _check_padding_free(self):
+        if self.padding_free:
+            supported_types = ['grpo', 'dpo', 'gkd']
+            if self.rlhf_type not in supported_types:
+                raise NotImplementedError(f"The current rlhf_type '{self.rlhf_type}' does not support padding_free. "
+                                          'Please set --padding_free to false.')
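
In practice, the new check means `--padding_free` is only accepted for the grpo, dpo, and gkd RLHF types; for reward modeling it has to stay false. A standalone sketch of that rule follows (a stand-in function mirroring the diff, not the real `RLHFArguments` method).

```python
SUPPORTED_PADDING_FREE = ['grpo', 'dpo', 'gkd']  # from the diff above


def check_padding_free(rlhf_type: str, padding_free: bool) -> None:
    # Reject padding_free for rlhf types that do not support it yet.
    if padding_free and rlhf_type not in SUPPORTED_PADDING_FREE:
        raise NotImplementedError(f"The current rlhf_type '{rlhf_type}' does not support padding_free. "
                                  'Please set --padding_free to false.')


check_padding_free('dpo', padding_free=True)   # passes
# check_padding_free('rm', padding_free=True)  # would raise NotImplementedError
```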

swift/llm/dataset/preprocessor/core.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 
 class RowPreprocessor:
     standard_keys = [
-        'messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects', 'channel'
+        'messages', 'rejected_response', 'label', 'images', 'videos', 'audios', 'tools', 'objects', 'channel', 'margin'
     ]
 
     def __init__(self,

swift/llm/template/base.py

Lines changed: 13 additions & 1 deletion
@@ -320,6 +320,7 @@ def get_base_model(model):
         return model
 
     def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
+        margin = inputs.margin
        chosen_inputs, rejected_inputs = inputs, deepcopy(inputs)
        assert chosen_inputs.rejected_response is not None, f'inputs: {inputs}'
        rejected_inputs.messages[-1]['content'] = chosen_inputs.rejected_response

@@ -331,6 +332,8 @@ def _rlhf_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
             data = locals()[f'{prefix}_encoded']
             for k, v in data.items():
                 encoded[f'{prefix}_{k}'] = v
+        if margin:
+            encoded['margin'] = float(margin)
         return encoded
 
     def _kto_encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:

@@ -1391,7 +1394,14 @@ def _rlhf_data_collator(self,
         new_batch = []
         for prefix in [chosen_prefix, rejected_prefix]:
             new_batch += self._fetch_inputs_startswith(batch, prefix)
-        return self._data_collator(new_batch, padding_to=padding_to)
+        res = self._data_collator(new_batch, padding_to=padding_to)
+
+        # reward modeling
+        margin = [b['margin'] for b in batch if b.get('margin') is not None]
+        if margin:
+            res['margin'] = torch.tensor(margin, dtype=torch.float)
+
+        return res
 
     def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
         new_batch = self._fetch_inputs_startswith(batch, 'chosen_')

@@ -1532,12 +1542,14 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
         inputs_embeds = [b['inputs_embeds'] for b in batch if b.get('inputs_embeds') is not None]
         input_ids = [b['input_ids'] for b in batch if b.get('input_ids') is not None]
         channel = [b['channel'] for b in batch if b.get('channel') is not None]
+
         if inputs_embeds:
             res['inputs_embeds'] = inputs_embeds
         if input_ids:
             res['input_ids'] = input_ids
         if channel:
             res['channel'] = channel
+
         for key in ['labels', 'loss_scale', 'position_ids', 'token_type_ids']:
             val = [b[key] for b in batch if b.get(key) is not None]
             if val:
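
The collator change above only attaches a `margin` tensor when at least one sample in the batch carries a margin. Below is a standalone sketch of that gathering step with made-up batch dicts; it is not the template's full collator.

```python
from typing import Any, Dict, List, Optional

import torch


def collect_margin(batch: List[Dict[str, Any]]) -> Optional[torch.Tensor]:
    # Keep only samples that actually carry a margin and stack them as floats.
    margin = [b['margin'] for b in batch if b.get('margin') is not None]
    return torch.tensor(margin, dtype=torch.float) if margin else None


print(collect_margin([{'margin': 0.5}, {'margin': 1.0}]))  # tensor([0.5000, 1.0000])
print(collect_margin([{}, {}]))                            # None (no margin column in the data)
```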

swift/llm/template/template_inputs.py

Lines changed: 3 additions & 1 deletion
@@ -110,6 +110,8 @@ class StdTemplateInputs:
     videos: List[str] = field(default_factory=list)
     objects: Dict[str, List[Any]] = field(default_factory=dict)
 
+    margin: Optional[float] = None  # for reward modeling
+
     def __post_init__(self):
         self.image_idx = 0
         self.audio_idx = 0

@@ -135,7 +137,7 @@ def is_multimodal(self):
     @classmethod
     def from_dict(cls, inputs: Dict[str, Any]) -> Tuple['StdTemplateInputs', Dict[str, Any]]:
         kwargs = {}
-        for key in ['rejected_response', 'label', 'channel']:
+        for key in ['rejected_response', 'label', 'channel', 'margin']:
             if key in inputs:
                 kwargs[key] = inputs[key]
         messages = inputs['messages']

swift/trainers/rlhf_trainer/reward_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -24,12 +24,13 @@ def compute_loss(self,
                      return_outputs=False,
                      num_items_in_batch=None) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
         inputs.pop('labels', None)  # not use
+        margin = inputs.pop('margin', None)
         attention_mask = inputs['attention_mask']
         batch_size = attention_mask.shape[0] // 2
         rewards = model(**inputs).logits
         rewards_chosen, rewards_rejected = torch.split(rewards, batch_size, dim=0)
-        if 'margin' in inputs:
-            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs['margin']).mean()
+        if margin is not None:
+            loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
         else:
             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
         if self.args.center_rewards_coefficient is not None:
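
The trainer relies on `_rlhf_data_collator` stacking all chosen inputs first and all rejected inputs second, so a single forward pass produces `2 * batch_size` reward logits that are split down the middle. A toy illustration of that split with made-up logits (no model involved):

```python
import torch

# Four logits for a batch of two preference pairs: chosen samples first, then rejected.
rewards = torch.tensor([[1.2], [0.3], [0.4], [0.1]])
batch_size = rewards.shape[0] // 2
rewards_chosen, rewards_rejected = torch.split(rewards, batch_size, dim=0)
print(rewards_chosen.flatten())    # tensor([1.2000, 0.3000])
print(rewards_rejected.flatten())  # tensor([0.4000, 0.1000])
```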
