
Commit 97d646b

hjh0119 authored and Jintao-Huang committed

[grpo] support offloading reference model (#4554)

* offload ref_model
* argument
* rm comment
* doc and fix
* refactor
* clean scripts
* rm unused dict
* rm offload_ref_model argument
* doc clean
* doc

1 parent c9e4f33

File tree: 10 files changed, +28 −51 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ Running Environment:
 | transformers | >=4.33 | 4.51 | |
 | modelscope | >=1.23 | | |
 | peft | >=0.11,<0.16 | ||
-| trl | >=0.13,<0.18 | 0.17 |RLHF|
+| trl | >=0.13,<0.19 | 0.18 |RLHF|
 | deepspeed | >=0.14 | 0.14.5 | Training |
 | vllm | >=0.5.1 | 0.8 | Inference/Deployment/Evaluation |
 | lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |

README_CN.md

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ pip install -e .
 | transformers | >=4.33 | 4.51 ||
 | modelscope | >=1.23 | ||
 | peft | >=0.11,<0.16 | ||
-| trl | >=0.13,<0.18 | 0.17 |RLHF|
+| trl | >=0.13,<0.19 | 0.18 |RLHF|
 | deepspeed | >=0.14 | 0.14.5 |训练|
 | vllm | >=0.5.1 | 0.8 |推理/部署/评测|
 | lmdeploy | >=0.5 | 0.8 |推理/部署/评测|

docs/source/GetStarted/SWIFT安装.md

Lines changed: 1 addition & 1 deletion
@@ -75,7 +75,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
 | transformers | >=4.33 | 4.51 ||
 | modelscope | >=1.23 | ||
 | peft | >=0.11,<0.16 | ||
-| trl | >=0.13,<0.18 | 0.17 |RLHF|
+| trl | >=0.13,<0.19 | 0.18 |RLHF|
 | deepspeed | >=0.14 | 0.14.5 |训练|
 | vllm | >=0.5.1 | 0.8 |推理/部署/评测|
 | lmdeploy | >=0.5 | 0.8 |推理/部署/评测|

docs/source/Instruction/GRPO.md

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ GRPO 训练框架支持集成高性能推理引擎(如 vLLM)来加速采样
 --sleep_level 1
 ```
 
-2. 在vLLM 推理阶段,释放训练模型和优化器占用的显存
+2. 在vLLM 推理阶段,释放模型和优化器占用的显存
 
 ```bash
 --offload_optimizer true \
@@ -222,7 +222,7 @@ A conversation between User and Assistant. The user asks a question, and the Ass
 - vllm_enable_prefix_caching: vllm透传参数,默认为True.
 - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放.
 - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。
-- offload_model: 是否在vLLM推理时offload 模型本身,默认为False。
+- offload_model: 是否在vLLM推理时 offload 模型,默认为False。
 - gc_collect_after_offload: 是否在offload结束时进行gc(python gc和GPU gc),默认为False。
 - completion_length_limit_scope: 在多轮对话中,`max_completion_length` 的限制范围。
   `total`限制所有对话轮次的总输出长度不超过`max_completion_length`, `per_round`限制每一轮的输出长度。

docs/source/Instruction/命令行参数.md

Lines changed: 1 addition & 1 deletion
@@ -446,7 +446,7 @@ reward模型参数将在PPO、GRPO中使用。
 - vllm_enable_prefix_caching: vllm透传参数,默认为True。
 - sleep_level: 训练时释放 vLLM 显存,可选项为[0, 1], 默认为0,不释放
 - offload_optimizer: 是否在vLLM推理时offload optimizer参数,默认为False。
-- offload_model: 是否在vLLM推理时offload 模型本身,默认为False。
+- offload_model: 是否在vLLM推理时 offload 模型,默认为False。
 - gc_collect_after_offload: 是否在offload结束时进行gc(python gc和GPU gc),默认为False。
 - completion_length_limit_scope: 在多轮对话中,`max_completion_length` 的限制范围。
   `total`限制所有对话轮次的总输出长度不超过`max_completion_length`, `per_round`限制每一轮的输出长度。

docs/source_en/GetStarted/SWIFT-installation.md

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ More images can be found [here](https://modelscope.cn/docs/intro/environment-set
 | transformers | >=4.33 | 4.51 | |
 | modelscope | >=1.23 | | |
 | peft | >=0.11,<0.16 | | |
-| trl | >=0.13,<0.18 | 0.17 | RLHF |
+| trl | >=0.13,<0.19 | 0.18 | RLHF |
 | deepspeed | >=0.14 | 0.14.5 | Training |
 | vllm | >=0.5.1 | 0.8 | Inference/Deployment/Evaluation |
 | lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 2 additions & 2 deletions
@@ -457,8 +457,8 @@ The meanings of the following parameters can be referenced [here](https://huggin
 - vllm_limit_mm_per_prompt: vLLM passthrough parameter, default is None.
 - vllm_tensor_parallel_size: the tensor parallel size of vLLM engine, default is 1.
 - sleep_level: make vllm sleep when model is training. Options are 0 or 1, default is 0, no sleep
-- offload_optimizer: Whether to offload optimizer parameters during inference with vLLM/LMDeploy. The default is `False`.
-- offload_model: Whether to offload the model itself during inference with vLLM/LMDeploy. The default is `False`.
+- offload_optimizer: Whether to offload optimizer parameters during inference with vLLM. The default is `False`.
+- offload_model: Whether to offload the model during inference with vLLM. The default is `False`.
 - gc_collect_after_offload: Whether to perform garbage collection (both Python GC and GPU GC) after offloading. The default is `False`.
 - completion_length_limit_scope: Specifies the scope of the `max_completion_length` limit in multi-turn conversations.
   When set to `total`, the total output length across all turns must not exceed `max_completion_length`.

docs/source_en/Instruction/GRPO.md

Lines changed: 2 additions & 2 deletions
@@ -53,7 +53,7 @@ When running in Colocate Mode , out-of-memory (OOM) errors are common due to sim
 --sleep_level 1
 ```
 
-2. Offload training model and optimizer memory during vLLM inference:
+2. Offload model and optimizer memory during vLLM inference:
 
 ```bash
 --offload_optimizer true \
@@ -232,7 +232,7 @@ Arguments
 - vllm_tensor_parallel_size: the tensor parallel size of vLLM engine, default is 1.
 - sleep_level: make vllm sleep when model is training. Options are 0 or 1, default is 0, no sleep
 - offload_optimizer: Whether to offload optimizer parameters during inference with vLLM. The default is `False`.
-- offload_model: Whether to offload the model itself during inference with vLLM. The default is `False`.
+- offload_model: Whether to offload the model during inference with vLLM. The default is `False`.
 - gc_collect_after_offload: Whether to perform garbage collection (both Python GC and GPU GC) after offloading. The default is `False`.
 - completion_length_limit_scope: Specifies the scope of the `max_completion_length` limit in multi-turn conversations.
   When set to `total`, the total output length across all turns must not exceed `max_completion_length`.
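
Taken together, these switches govern the memory hand-off between the training side and the colocated vLLM engine. The sketch below is a hedged Python outline of the order of operations this commit implements in `GRPOTrainer._fast_infer` (offload, wake vLLM, generate, sleep vLLM, reload); `engine.wake_up`/`engine.sleep` and the `rollout_fn` callable are illustrative stand-ins rather than the trainer's exact internals, while the trainer methods and `args` fields shown come from this commit's diff.

```python
# Hedged outline of one colocate rollout step; names marked "stand-in" are not
# part of the ms-swift API.
from swift.utils import gc_collect


def colocate_rollout_step(trainer, engine, inputs, rollout_fn):
    args = trainer.args
    if args.sleep_level > 0:
        if args.offload_model:
            trainer.offload_model(trainer.accelerator.unwrap_model(trainer.model))
            if trainer.ref_model:  # new in this commit: also offload the reference model
                trainer.offload_model(trainer.ref_model)
        if args.offload_optimizer:
            trainer.offload_optimizer()
        if args.gc_collect_after_offload:
            gc_collect()  # Python GC plus GPU cache release
        engine.wake_up()  # stand-in: wake the sleeping vLLM engine

    outputs = rollout_fn(inputs)  # stand-in: vLLM generation for the batch

    if args.sleep_level > 0:
        engine.sleep(level=args.sleep_level)  # stand-in: release vLLM memory again
        if args.gc_collect_after_offload:
            gc_collect()
        if args.offload_model:
            trainer.load_model(trainer.accelerator.unwrap_model(trainer.model))
            if trainer.ref_model:
                trainer.load_model(trainer.ref_model)
        if args.offload_optimizer:
            trainer.load_optimizer()
    return inputs, outputs
```

The functional change in this commit is the two `ref_model` branches: when a reference model is kept for the KL term, its weights are now moved off the GPU for the duration of the rollout as well.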

swift/trainers/arguments.py

Lines changed: 0 additions & 1 deletion
@@ -234,7 +234,6 @@ class GRPOArgumentsMixin:
     # Dr. GRPO, https://arxiv.org/abs/2503.20783
     scale_rewards: bool = True
 
-    # compatible with trl main branch(0.17.0.dev0)
     wandb_log_unique_prompts: Optional[bool] = None
     generation_batch_size: Optional[int] = None
     steps_per_generation: Optional[int] = None

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 17 additions & 39 deletions
@@ -37,8 +37,8 @@
 from swift.llm.model.utils import get_llm_model
 from swift.llm.template.template_inputs import StdTemplateInputs
 from swift.plugin import loss_scale_map, multi_turns, orms, rm_plugins
-from swift.utils import (JsonlWriter, gc_collect, get_device, get_logger, is_vllm_available, is_wandb_available,
-                         seed_worker)
+from swift.utils import (JsonlWriter, gc_collect, get_current_device, get_device, get_logger, is_vllm_available,
+                         is_wandb_available, seed_worker)
 from ..mixin import SwiftMixin
 from .rlhf_mixin import RLHFTrainerMixin
 from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, unwrap_model_for_generation
@@ -93,10 +93,6 @@ def __init__(self,
 
         self.processing_class = kwargs.get('template').tokenizer
 
-        # for offload model/optimizer
-        self.offload_modules = {}
-        self.offload_states = {}
-
         if not isinstance(reward_funcs, list):
             reward_funcs = [reward_funcs]
 
@@ -759,7 +755,9 @@ def _prefetch(self, dataloader: DataLoader):
     def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]:
         if self.vllm_mode == 'colocate' and self.args.sleep_level > 0:
             if self.args.offload_model:
-                self.offload_model()
+                self.offload_model(self.accelerator.unwrap_model(self.model))
+                if self.ref_model:
+                    self.offload_model(self.ref_model)
             if self.args.offload_optimizer:
                 self.offload_optimizer()
             if self.args.gc_collect_after_offload:
@@ -797,7 +795,9 @@ def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]:
             if self.args.gc_collect_after_offload:
                 gc_collect()
             if self.args.offload_model:
-                self.load_model()
+                self.load_model(self.accelerator.unwrap_model(self.model))
+                if self.ref_model:
+                    self.load_model(self.ref_model)
             if self.args.offload_optimizer:
                 self.load_optimizer()
         return inputs, outputs
@@ -1387,60 +1387,38 @@ def _queue(self):
         return self.train_queue
 
     @torch.no_grad()
-    def offload_model(self):
-        if len(self.offload_modules) > 0:
-            return
-        unwrapped_model = self.accelerator.unwrap_model(self.model)
-        for name, module in unwrapped_model.named_modules():
-            if isinstance(module, torch.nn.Embedding):
-                self.offload_modules[name] = module.weight.device
-                module.to('cpu')
-            elif not hasattr(module, 'device'):
-                pass
-            elif module.device.type != 'cpu':
-                self.offload_modules[name] = module.device
-                module.to('cpu')
+    def offload_model(self, model):
+        for param in model.parameters():
+            param.data = param.data.to(torch.device('cpu'), non_blocking=True)
 
     @torch.no_grad()
-    def load_model(self):
-        if len(self.offload_modules) == 0:
-            return
-        unwrapped_model = self.accelerator.unwrap_model(self.model)
-        for name, device in self.offload_modules.items():
-            module = unwrapped_model.get_submodule(name)
-            if isinstance(module, torch.nn.Embedding):
-                module.weight.to(device)
-            else:
-                module.to(device)
-        self.offload_modules.clear()
+    def load_model(self, model):
+        device = get_current_device()
+        for param in model.parameters():
+            param.data = param.data.to(device, non_blocking=True)
 
     @torch.no_grad()
     def offload_optimizer(self):
-        if len(self.offload_states) > 0:
-            return
         if not self.optimizer.state:
             return
         for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                state = self.optimizer.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
-                        self.offload_states[key] = value.device
                        state[key] = value.to('cpu', non_blocking=True)
 
     @torch.no_grad()
     def load_optimizer(self):
-        if len(self.offload_states) == 0:
-            return
+        device = get_current_device()
         if not self.optimizer.state:
            return
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                state = self.optimizer.state[param]
                for key, value in state.items():
                    if isinstance(value, torch.Tensor):
-                        state[key] = value.to(self.offload_states[key], non_blocking=True)
-        self.offload_states.clear()
+                        state[key] = value.to(device, non_blocking=True)
 
     @contextmanager
     def multi_turn_completion_length_context(self):
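
For readers who want the pattern outside the trainer, here is a minimal self-contained sketch of what the new `offload_model`/`load_model` and `offload_optimizer`/`load_optimizer` methods do: rebind each parameter's storage (and each tensor-valued optimizer-state entry) to CPU, then back to the accelerator device. The free-function names below are illustrative, not part of ms-swift.

```python
# Minimal standalone sketch of the offload pattern adopted in this commit.
# Function names are illustrative; the trainer implements them as methods.
import torch


@torch.no_grad()
def move_params(module: torch.nn.Module, device: torch.device) -> None:
    # Rebinding .data relocates the storage while keeping the Parameter objects
    # (and any optimizer references to them) intact.
    for param in module.parameters():
        param.data = param.data.to(device, non_blocking=True)


@torch.no_grad()
def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    # Adam-style moments are the other large block of training memory; move every
    # tensor-valued state entry to the target device.
    if not optimizer.state:
        return
    for group in optimizer.param_groups:
        for param in group['params']:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device, non_blocking=True)


if __name__ == '__main__':
    # Tiny usage example: populate optimizer state, then move everything to CPU.
    model = torch.nn.Linear(8, 8)
    optim = torch.optim.AdamW(model.parameters())
    model(torch.randn(2, 8)).sum().backward()
    optim.step()  # creates exp_avg / exp_avg_sq tensors in optimizer.state
    move_params(model, torch.device('cpu'))
    move_optimizer_state(optim, torch.device('cpu'))
    # In the trainer, the reverse move targets get_current_device() after the rollout.
```

Compared with the old module-by-module bookkeeping (`self.offload_modules` / `self.offload_states`), this parameter-level move needs no saved device map, since everything is restored to `get_current_device()`, and it extends naturally to the reference model that `_fast_infer` now passes in.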
