|
1 | 1 | import copy
2 | 2 | import logging
3 | 3 | import os
| 4 | + from collections import defaultdict
4 | 5 | from functools import reduce
5 | 6 | from pathlib import Path
6 | 7 | from shutil import rmtree

10 | 11 | import torch.distributed as dist
11 | 12 | import torch.nn as nn
12 | 13 | from torch.distributed import ProcessGroup
| 14 | + from torch.optim import Optimizer
13 | 15 | from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
14 | 16 | from torch.utils._pytree import tree_map
15 | 17 |

37 | 39 | load_shard_state_dict,
38 | 40 | load_state_dict,
39 | 41 | load_state_dict_into_model,
40 | | - load_states_into_optimizer,
41 | 42 | save_config_file,
42 | 43 | save_param_groups,
43 | 44 | save_state_dict,
@@ -724,26 +725,37 @@ def _get_param_id_from_optimizer_param(

724 | 725 | state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
725 | 726 | if not low_cpu_mem_mode:
726 | 727 |     state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
727 | | - load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
| 728 | + self.load_states_into_optimizer(optimizer, state_dict, id_map)
728 | 729 | loaded_file.add(filename)
729 | 730 |
730 | | - # Then shard the loaded optimizer states if using tp/zero.
731 | | - for param, state in optimizer.optim.state.items():
732 | | -     device = param.device
733 | | -     if master_to_working_map is not None:
734 | | -         working_param = master_to_working_map[id(param)]
735 | | -     else:
736 | | -         working_param = param
737 | | -     original_shape = optimizer.param_info["param2shape"][id(working_param)]
738 | | -     sharded_state = self.shard_from_complete_optimizer_state(
739 | | -         state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
740 | | -     )
741 | | -     optimizer.optim.state[param] = sharded_state
742 | | -
743 | 731 | sharded_optimizer_loading_epilogue(optimizer.optim)
744 | 732 | if self.verbose and self.coordinator.is_master():
745 | 733 |     logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
746 | 734 |
| 735 | + def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
| 736 | +     state_dict = {int(k): v for k, v in state_dict.items()}
| 737 | +     new_states = defaultdict(dict)
| 738 | +     master_to_working_map = optimizer.get_master_to_working_map()
| 739 | +     for k, state in state_dict.items():
| 740 | +         if k in id_map:
| 741 | +             param = id_map[k]
| 742 | +             device = param.device
| 743 | +             dtype = param.dtype
| 744 | +             if master_to_working_map is not None:
| 745 | +                 working_param = master_to_working_map[id(param)]
| 746 | +             else:
| 747 | +                 working_param = param
| 748 | +             original_shape = optimizer.param_info["param2shape"][id(working_param)]
| 749 | +             new_states[param] = self.shard_from_complete_optimizer_state(
| 750 | +                 state,
| 751 | +                 current_shape=working_param.shape,
| 752 | +                 original_shape=original_shape,
| 753 | +                 device=device,
| 754 | +                 dtype=dtype,
| 755 | +                 inplace=True,
| 756 | +             )
| 757 | +     optimizer.optim.state.update(new_states)
| 758 | +
747 | 759 | def save_unsharded_model(
748 | 760 |     self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
749 | 761 | ):
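Note on the hunk above: the previous two-step flow (the standalone `load_states_into_optimizer` utility followed by a tp/zero sharding loop) is folded into a single `load_states_into_optimizer` method on the checkpoint IO class, so both optimizer-loading paths in this file share one code path and each state tensor is additionally cast to its parameter's dtype. Below is a minimal sketch of the re-keying pattern the new method relies on, using a plain torch optimizer as a stand-in; the model, `id_map`, and `loaded` dict here are toy assumptions, not ColossalAI objects.

```python
import torch
from collections import defaultdict

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters())

# id_map: param id as recorded in the checkpoint -> live parameter object.
id_map = {i: p for i, p in enumerate(model.parameters())}

# Toy "loaded" optimizer state, keyed by stringified param ids the way a saved
# checkpoint typically is; hence the int(k) cast in the new method.
loaded = {str(i): {"exp_avg": torch.zeros_like(p, dtype=torch.float32)} for i, p in id_map.items()}

new_states = defaultdict(dict)
for k, state in {int(k): v for k, v in loaded.items()}.items():
    if k in id_map:
        param = id_map[k]
        # The real method would also shard each tensor for tp/zero here via
        # shard_from_complete_optimizer_state before installing it.
        new_states[param] = {n: t.to(device=param.device, dtype=param.dtype) for n, t in state.items()}

optim.state.update(new_states)  # install states keyed by the live params
```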
|
@@ -988,22 +1000,7 @@ def _get_param_id_from_optimizer_param(

988 | 1000 | for param in pg["params"]:
989 | 1001 |     param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
990 | 1002 |     id_map[param_id] = param
991 | | - load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
992 | | -
993 | | - # Then shard the loaded optimizer states if using tp/zero.
994 | | - for param, state in optimizer.optim.state.items():
995 | | -     if param is None:
996 | | -         continue
997 | | -     device = param.device
998 | | -     if master_to_working_map is not None:
999 | | -         working_param = master_to_working_map[id(param)]
1000 | | -     else:
1001 | | -         working_param = param
1002 | | -     original_shape = optimizer.param_info["param2shape"][id(working_param)]
1003 | | -     sharded_state = self.shard_from_complete_optimizer_state(
1004 | | -         state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
1005 | | -     )
1006 | | -     optimizer.optim.state[param] = sharded_state
| 1003 | + self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
1007 | 1004 |
1008 | 1005 | sharded_optimizer_loading_epilogue(optimizer.optim)
1009 | 1006 |
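A reading note for both call sites: `get_master_to_working_map()` and `param_info["param2shape"]` are keyed by the Python object id of the tensors, which is why the lookups go through `id(param)` and `id(working_param)`. A hedged toy illustration of that keying follows; the names and structures below are stand-ins, not the library's actual wrappers.

```python
import torch

# A working (model) parameter and its fp32 master copy, as a mixed-precision
# optimizer wrapper might hold them.
working = torch.nn.Parameter(torch.zeros(2, 3, dtype=torch.bfloat16))
master = working.detach().float().requires_grad_()

master_to_working_map = {id(master): working}    # keyed by object id, not by name
param2shape = {id(working): torch.Size([2, 3])}  # full, unsharded shape

# The lookup chain used when sharding a loaded state that belongs to `master`:
working_param = master_to_working_map.get(id(master), master)
original_shape = param2shape[id(working_param)]
print(original_shape)  # torch.Size([2, 3])
```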
|
@@ -1086,6 +1083,7 @@ def shard_from_complete_optimizer_state(

1086 | 1083 | current_shape: torch.Size,
1087 | 1084 | original_shape: torch.Size,
1088 | 1085 | device: torch.device,
| 1086 | + dtype: torch.dtype,
1089 | 1087 | inplace: bool,
1090 | 1088 | ) -> OrderedDict:
1091 | 1089 | """

@@ -1135,7 +1133,7 @@ def shard_from_complete_optimizer_state(

1135 | 1133 | slice_size = v.numel() // self.global_dp_size
1136 | 1134 | v = v.split(slice_size, dim=0)[self.dp_rank]
1137 | 1135 |
1138 | | - state_[k] = v.detach().clone().to(device)
| 1136 | + state_[k] = v.detach().clone().to(device=device, dtype=dtype)
1139 | 1137 |
1140 | 1138 | return state_
1141 | 1139 |
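On the `dtype` change above: previously the restored state tensor was only moved to the target device and kept whatever dtype it was saved with; it is now also cast to the owning parameter's dtype in the same call. A self-contained illustration of that one-line difference, using toy tensors rather than the library's code path:

```python
import torch

v = torch.randn(8, dtype=torch.float32)  # state tensor as loaded from a checkpoint
param = torch.nn.Parameter(torch.zeros(8, dtype=torch.bfloat16))

old_style = v.detach().clone().to(param.device)                            # device only: stays float32
new_style = v.detach().clone().to(device=param.device, dtype=param.dtype)  # device and dtype
print(old_style.dtype, new_style.dtype)  # torch.float32 torch.bfloat16
```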
|
|