|
37 | 37 | from swift.llm.model.utils import get_llm_model |
38 | 38 | from swift.llm.template.template_inputs import StdTemplateInputs |
39 | 39 | from swift.plugin import loss_scale_map, multi_turns, orms, rm_plugins |
40 | | -from swift.utils import (JsonlWriter, gc_collect, get_device, get_logger, is_vllm_available, is_wandb_available, |
41 | | - seed_worker) |
| 40 | +from swift.utils import (JsonlWriter, gc_collect, get_current_device, get_device, get_logger, is_vllm_available, |
| 41 | + is_wandb_available, seed_worker) |
42 | 42 | from ..mixin import SwiftMixin |
43 | 43 | from .rlhf_mixin import RLHFTrainerMixin |
44 | 44 | from .utils import _ForwardRedirection, patch_lora_merge, patch_lora_unmerge, unwrap_model_for_generation |
@@ -93,10 +93,6 @@ def __init__(self, |
93 | 93 |
|
94 | 94 | self.processing_class = kwargs.get('template').tokenizer |
95 | 95 |
|
96 | | - # for offload model/optimizer |
97 | | - self.offload_modules = {} |
98 | | - self.offload_states = {} |
99 | | - |
100 | 96 | if not isinstance(reward_funcs, list): |
101 | 97 | reward_funcs = [reward_funcs] |
102 | 98 |
|
@@ -759,7 +755,9 @@ def _prefetch(self, dataloader: DataLoader): |
759 | 755 | def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: |
760 | 756 | if self.vllm_mode == 'colocate' and self.args.sleep_level > 0: |
761 | 757 | if self.args.offload_model: |
762 | | - self.offload_model() |
| 758 | + self.offload_model(self.accelerator.unwrap_model(self.model)) |
| 759 | + if self.ref_model: |
| 760 | + self.offload_model(self.ref_model) |
763 | 761 | if self.args.offload_optimizer: |
764 | 762 | self.offload_optimizer() |
765 | 763 | if self.args.gc_collect_after_offload: |
@@ -797,7 +795,9 @@ def _fast_infer(self, inputs: InputsType) -> Tuple[InputsType, OutputsType]: |
797 | 795 | if self.args.gc_collect_after_offload: |
798 | 796 | gc_collect() |
799 | 797 | if self.args.offload_model: |
800 | | - self.load_model() |
| 798 | + self.load_model(self.accelerator.unwrap_model(self.model)) |
| 799 | + if self.ref_model: |
| 800 | + self.load_model(self.ref_model) |
801 | 801 | if self.args.offload_optimizer: |
802 | 802 | self.load_optimizer() |
803 | 803 | return inputs, outputs |
@@ -1387,60 +1387,38 @@ def _queue(self): |
1387 | 1387 | return self.train_queue |
1388 | 1388 |
|
1389 | 1389 | @torch.no_grad() |
1390 | | - def offload_model(self): |
1391 | | - if len(self.offload_modules) > 0: |
1392 | | - return |
1393 | | - unwrapped_model = self.accelerator.unwrap_model(self.model) |
1394 | | - for name, module in unwrapped_model.named_modules(): |
1395 | | - if isinstance(module, torch.nn.Embedding): |
1396 | | - self.offload_modules[name] = module.weight.device |
1397 | | - module.to('cpu') |
1398 | | - elif not hasattr(module, 'device'): |
1399 | | - pass |
1400 | | - elif module.device.type != 'cpu': |
1401 | | - self.offload_modules[name] = module.device |
1402 | | - module.to('cpu') |
| 1390 | + def offload_model(self, model): |
| 1391 | + for param in model.parameters(): |
| 1392 | + param.data = param.data.to(torch.device('cpu'), non_blocking=True) |
1403 | 1393 |
|
1404 | 1394 | @torch.no_grad() |
1405 | | - def load_model(self): |
1406 | | - if len(self.offload_modules) == 0: |
1407 | | - return |
1408 | | - unwrapped_model = self.accelerator.unwrap_model(self.model) |
1409 | | - for name, device in self.offload_modules.items(): |
1410 | | - module = unwrapped_model.get_submodule(name) |
1411 | | - if isinstance(module, torch.nn.Embedding): |
1412 | | - module.weight.to(device) |
1413 | | - else: |
1414 | | - module.to(device) |
1415 | | - self.offload_modules.clear() |
| 1395 | + def load_model(self, model): |
| 1396 | + device = get_current_device() |
| 1397 | + for param in model.parameters(): |
| 1398 | + param.data = param.data.to(device, non_blocking=True) |
1416 | 1399 |
|
1417 | 1400 | @torch.no_grad() |
1418 | 1401 | def offload_optimizer(self): |
1419 | | - if len(self.offload_states) > 0: |
1420 | | - return |
1421 | 1402 | if not self.optimizer.state: |
1422 | 1403 | return |
1423 | 1404 | for param_group in self.optimizer.param_groups: |
1424 | 1405 | for param in param_group['params']: |
1425 | 1406 | state = self.optimizer.state[param] |
1426 | 1407 | for key, value in state.items(): |
1427 | 1408 | if isinstance(value, torch.Tensor): |
1428 | | - self.offload_states[key] = value.device |
1429 | 1409 | state[key] = value.to('cpu', non_blocking=True) |
1430 | 1410 |
|
1431 | 1411 | @torch.no_grad() |
1432 | 1412 | def load_optimizer(self): |
1433 | | - if len(self.offload_states) == 0: |
1434 | | - return |
| 1413 | + device = get_current_device() |
1435 | 1414 | if not self.optimizer.state: |
1436 | 1415 | return |
1437 | 1416 | for param_group in self.optimizer.param_groups: |
1438 | 1417 | for param in param_group['params']: |
1439 | 1418 | state = self.optimizer.state[param] |
1440 | 1419 | for key, value in state.items(): |
1441 | 1420 | if isinstance(value, torch.Tensor): |
1442 | | - state[key] = value.to(self.offload_states[key], non_blocking=True) |
1443 | | - self.offload_states.clear() |
| 1421 | + state[key] = value.to(device, non_blocking=True) |
1444 | 1422 |
|
1445 | 1423 | @contextmanager |
1446 | 1424 | def multi_turn_completion_length_context(self): |
|
0 commit comments