Commit 5ff5323

[hotfix] fix zero optim save (#6191)
1 parent 014837e

File tree: 1 file changed (+71, -61 lines)

colossalai/zero/low_level/low_level_optim.py

Lines changed: 71 additions & 61 deletions
@@ -786,30 +786,36 @@ def state_dict(
         """
         zero_state = dict()
         device = get_accelerator().get_current_device()
-        for param, state in self.optim.state.items():
-            working_param = self.master_to_working_param[id(param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                zero_state[param] = copy.deepcopy(state)
-            else:
-                zero_state[param] = {}
-
-            if pinned_state_dicts is not None and param not in pinned_state_dicts:
-                pinned_state_dicts[param] = {}
-
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
-                            pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param][k].copy_(param_state)
-                            zero_state[param][k] = pinned_state_dicts[param][k]
-                        else:
-                            zero_state[param][k] = param_state.cpu()
+        for param_group in self.optim.param_groups:
+            for param in param_group["params"]:
+                if param not in self.optim.state:
+                    continue
+                state = self.optim.state[param]
+                working_param = self.master_to_working_param[id(param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    zero_state[param] = copy.deepcopy(state)
+                else:
+                    zero_state[param] = {}
+
+                if pinned_state_dicts is not None and param not in pinned_state_dicts:
+                    pinned_state_dicts[param] = {}
+
+                for k, v in state.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
+                        param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                                pinned_state_dicts[param][k] = torch.empty_like(
+                                    param_state, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param][k].copy_(param_state)
+                                zero_state[param][k] = pinned_state_dicts[param][k]
+                            else:
+                                zero_state[param][k] = param_state.cpu()
 
         states_dict = self._pack_state(zero_state)
 
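Note on the hunk above: the new code walks self.optim.param_groups instead of self.optim.state and skips parameters that have no optimizer state yet. A plausible reading (my inference, not stated in the commit message) is that optim.state is filled lazily in the order parameters first receive updates, so iterating it directly gives a non-deterministic order relative to the parameter groups. The standalone plain-PyTorch sketch below, unrelated to ColossalAI internals, shows that mismatch; the parameter names p0/p1 are purely illustrative.

# Standalone sketch (plain PyTorch, not ColossalAI): optimizer.state is filled
# lazily, so iterating it can visit parameters in first-update order rather
# than param_group order. Walking param_groups, as the new code does, yields a
# deterministic order, and params without state are simply skipped.
import torch

p0 = torch.nn.Parameter(torch.zeros(2))
p1 = torch.nn.Parameter(torch.zeros(2))
opt = torch.optim.Adam([p0, p1], lr=1e-3)

# First step: only p1 has a gradient, so only p1 gets a state entry.
p1.grad = torch.ones(2)
opt.step()

# Second step: both have gradients; p0's state entry is created after p1's.
p0.grad = torch.ones(2)
p1.grad = torch.ones(2)
opt.step()

state_order = [id(p) for p in opt.state]                               # p1 before p0
group_order = [id(p) for g in opt.param_groups for p in g["params"]]   # p0 before p1
print(state_order == group_order)  # False: the two iteration orders disagree
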
@@ -865,48 +871,52 @@ def state_dict_shard(
         device = get_accelerator().get_current_device()
         local_states = self.optim.state_dict()["state"]
 
-        idx2master = {}
+        master2idx = {}
         cnt = 0
         for param_group in self.optim.param_groups:
             for param in param_group["params"]:
-                idx2master[cnt] = param
+                master2idx[param] = cnt
                 cnt += 1
-        for param_idx, states in local_states.items():
-            current_block_size = 0
-            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
-                pinned_state_dicts[param_idx] = {}
-            master_param = idx2master[param_idx]
-            working_param = self.master_to_working_param[id(master_param)]
-            pg = self.param_to_pg[working_param]
-            if not only_on_master or get_nd_rank(pg) == 0:
-                current_block = copy.deepcopy(states)
-            else:
-                current_block = {}
-
-            for k, v in states.items():
-                if isinstance(v, torch.Tensor) and k != "step":
-                    state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
-                    all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if not only_on_master or get_nd_rank(pg) == 0:
-                        if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
-                            pinned_state_dicts[param_idx][k] = torch.empty_like(
-                                state_tensor, pin_memory=True, device="cpu"
-                            )
-                        if pinned_state_dicts is not None:
-                            pinned_state_dicts[param_idx][k].copy_(state_tensor)
-                            current_block[k] = pinned_state_dicts[param_idx][k]
-                        else:
-                            current_block[k] = state_tensor.cpu()
-                    current_block_size += calculate_tensor_size(state_tensor)
-
-            if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
-                yield ret_block, ret_block_size
-                ret_block = dict()
-                ret_block_size = 0
 
-            ret_block[param_idx] = current_block
-            ret_block_size += current_block_size
+        for param_group in self.optim.param_groups:
+            for master_param in param_group["params"]:
+                param_idx = master2idx[master_param]
+                states = local_states[param_idx]
+
+                current_block_size = 0
+                if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
+                    pinned_state_dicts[param_idx] = {}
+                working_param = self.master_to_working_param[id(master_param)]
+                pg = self.param_to_pg[working_param]
+                if not only_on_master or get_nd_rank(pg) == 0:
+                    current_block = copy.deepcopy(states)
+                else:
+                    current_block = {}
+
+                for k, v in states.items():
+                    if isinstance(v, torch.Tensor) and k != "step":
+                        state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
+                        all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
+                        state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                        if not only_on_master or get_nd_rank(pg) == 0:
+                            if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
+                                pinned_state_dicts[param_idx][k] = torch.empty_like(
+                                    state_tensor, pin_memory=True, device="cpu"
+                                )
+                            if pinned_state_dicts is not None:
+                                pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                                current_block[k] = pinned_state_dicts[param_idx][k]
+                            else:
+                                current_block[k] = state_tensor.cpu()
+                        current_block_size += calculate_tensor_size(state_tensor)
+
+                if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
+                    yield ret_block, ret_block_size
+                    ret_block = dict()
+                    ret_block_size = 0
+
+                ret_block[param_idx] = current_block
+                ret_block_size += current_block_size
 
         yield ret_block, ret_block_size
 

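Note on the second hunk: state_dict_shard now emits blocks by walking param_groups and looking up each master parameter's flat-state entry through master2idx, rather than iterating local_states directly; the gather-trim-reshape of each sharded state tensor is unchanged. The single-process sketch below only illustrates that gather-trim-reshape pattern under stated assumptions: shard_flat and unshard are made-up helper names, and torch.cat stands in for the real all_gather_into_flat_tensor_nd collective.

# Single-process sketch of the gather-trim-reshape pattern, not the ColossalAI API.
# A working parameter of numel N is stored as world_size equal-length flat shards
# (right-padded); reassembly concatenates them, trims to N, and reshapes back.
import math
import torch

def shard_flat(tensor, world_size):
    """Flatten, right-pad to a multiple of world_size, and split into equal shards."""
    flat = tensor.flatten()
    padded_len = math.ceil(flat.numel() / world_size) * world_size
    padded = torch.zeros(padded_len, dtype=flat.dtype)
    padded[: flat.numel()] = flat
    return list(padded.chunk(world_size))

def unshard(shards, like):
    """Concatenate the shards, trim the padding, and reshape to the param's shape."""
    gathered = torch.cat(shards)                    # stand-in for the all-gather
    return gathered[: like.numel()].reshape_as(like)

working_param = torch.arange(10, dtype=torch.float32).reshape(2, 5)
shards = shard_flat(working_param, world_size=4)    # each rank would hold one shard
restored = unshard(shards, working_param)
assert torch.equal(restored, working_param)
print(restored.shape)                               # torch.Size([2, 5])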