
Commit 5c09d72

[checkpointio] fix checkpoint for 3d (#6187)
* fix checkpoint io for 3d
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* Update hybrid_parallel_checkpoint_io.py
* fix

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 2b415e5 commit 5c09d72


colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 30 additions & 32 deletions
@@ -1,6 +1,7 @@
 import copy
 import logging
 import os
+from collections import defaultdict
 from functools import reduce
 from pathlib import Path
 from shutil import rmtree
@@ -10,6 +11,7 @@
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
+from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils._pytree import tree_map

@@ -37,7 +39,6 @@
     load_shard_state_dict,
     load_state_dict,
     load_state_dict_into_model,
-    load_states_into_optimizer,
     save_config_file,
     save_param_groups,
     save_state_dict,
@@ -724,26 +725,37 @@ def _get_param_id_from_optimizer_param(
             state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
             if not low_cpu_mem_mode:
                 state_dict = create_pinned_state_dict(state_dict, empty=False, num_threads=num_threads)
-            load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
+            self.load_states_into_optimizer(optimizer, state_dict, id_map)
             loaded_file.add(filename)
 
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state
-
         sharded_optimizer_loading_epilogue(optimizer.optim)
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
 
+    def load_states_into_optimizer(self, optimizer: Optimizer, state_dict: dict, id_map: dict):
+        state_dict = {int(k): v for k, v in state_dict.items()}
+        new_states = defaultdict(dict)
+        master_to_working_map = optimizer.get_master_to_working_map()
+        for k, state in state_dict.items():
+            if k in id_map:
+                param = id_map[k]
+                device = param.device
+                dtype = param.dtype
+                if master_to_working_map is not None:
+                    working_param = master_to_working_map[id(param)]
+                else:
+                    working_param = param
+                original_shape = optimizer.param_info["param2shape"][id(working_param)]
+                new_states[param] = self.shard_from_complete_optimizer_state(
+                    state,
+                    current_shape=working_param.shape,
+                    original_shape=original_shape,
+                    device=device,
+                    dtype=dtype,
+                    inplace=True,
+                )
+        optimizer.optim.state.update(new_states)
+
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
     ):
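
Note: as context for the hunk above, the following is a minimal, self-contained sketch of the pattern the new load_states_into_optimizer follows: map each checkpoint param id to a live parameter, cast every state tensor to that parameter's device and dtype, and install all states at once. It is not the ColossalAI implementation; the name toy_load_states_into_optimizer and the single-process setup are assumptions for illustration, and the real method additionally slices each tensor down to the local TP/ZeRO shard via shard_from_complete_optimizer_state.

import torch
from collections import defaultdict

def toy_load_states_into_optimizer(optim, saved_state, id_map):
    # Checkpoint keys are param ids (possibly strings); normalize them to int.
    new_states = defaultdict(dict)
    for pid, state in {int(k): v for k, v in saved_state.items()}.items():
        if pid not in id_map:
            continue  # this rank does not own the parameter
        param = id_map[pid]
        for name, value in state.items():
            if torch.is_tensor(value) and value.dim() > 0:
                # Align the restored tensor with the owning parameter.
                value = value.detach().clone().to(device=param.device, dtype=param.dtype)
            new_states[param][name] = value
    # Install everything in one pass, mirroring optimizer.optim.state.update(new_states).
    optim.state.update(new_states)

# Usage (single process, no sharding): restore Adam state for one parameter.
p = torch.nn.Parameter(torch.zeros(4))
adam = torch.optim.Adam([p])
saved = {"0": {"step": torch.tensor(10.0), "exp_avg": torch.randn(4), "exp_avg_sq": torch.rand(4)}}
toy_load_states_into_optimizer(adam, saved, id_map={0: p})
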
@@ -988,22 +1000,7 @@ def _get_param_id_from_optimizer_param(
             for param in pg["params"]:
                 param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
-        load_states_into_optimizer(optimizer.optim, state_dict["state"], id_map, strict=True)
-
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            if param is None:
-                continue
-            device = param.device
-            if master_to_working_map is not None:
-                working_param = master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.shard_from_complete_optimizer_state(
-                state, current_shape=working_param.shape, original_shape=original_shape, device=device, inplace=True
-            )
-            optimizer.optim.state[param] = sharded_state
+        self.load_states_into_optimizer(optimizer, state_dict["state"], id_map)
 
         sharded_optimizer_loading_epilogue(optimizer.optim)
 
@@ -1086,6 +1083,7 @@ def shard_from_complete_optimizer_state(
         current_shape: torch.Size,
         original_shape: torch.Size,
         device: torch.device,
+        dtype: torch.dtype,
         inplace: bool,
     ) -> OrderedDict:
         """
"""
@@ -1135,7 +1133,7 @@ def shard_from_complete_optimizer_state(
                         slice_size = v.numel() // self.global_dp_size
                         v = v.split(slice_size, dim=0)[self.dp_rank]
 
-            state_[k] = v.detach().clone().to(device)
+            state_[k] = v.detach().clone().to(device=device, dtype=dtype)
 
         return state_
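
Note: the last hunk threads dtype through shard_from_complete_optimizer_state because .to(device) alone keeps whatever dtype the state tensor was saved with, while .to(device=..., dtype=...) also casts it to the target dtype. A tiny illustration (the tensors and dtypes below are made up, not taken from the checkpoint code):

import torch

param = torch.zeros(4, dtype=torch.bfloat16)  # stand-in for a bf16 parameter
saved = torch.randn(4, dtype=torch.float32)   # state tensor as read from the shard file

old_style = saved.detach().clone().to(param.device)                             # stays float32
new_style = saved.detach().clone().to(device=param.device, dtype=param.dtype)   # becomes bfloat16

print(old_style.dtype, new_style.dtype)  # torch.float32 torch.bfloat16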
