
Commit 184a653

[checkpointio] fix pinned state dict

1 parent: 5fa657f

1 file changed: +7 -7 lines

colossalai/zero/low_level/low_level_optim.py

Lines changed: 7 additions & 7 deletions
@@ -780,19 +780,19 @@ def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tens
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
-            if pinned_state_dicts and param not in pinned_state_dicts:
+            if pinned_state_dicts is not None and param not in pinned_state_dicts:
                 pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
-                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
                     param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts:
+                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(param_state, pin_memory=True, device="cpu")
+                    if pinned_state_dicts is not None:
                         pinned_state_dicts[param][k].copy_(param_state)
                         zero_state[param][k] = pinned_state_dicts[param][k]
                     else:
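
The state_dict hunk above carries two fixes. First, the guards now test "pinned_state_dicts is not None" instead of relying on truthiness: a caller that passes a fresh, still-empty dict would previously fail the "if pinned_state_dicts" check, so the pinned buffers were never allocated. Second, the pinned CPU buffer is now created from param_state after the optimizer state has been all-gathered and reshaped, instead of from working_param before that variable is assigned in the loop body. A minimal standalone sketch of the corrected guard order, using a hypothetical helper name (copy_to_pinned) rather than ColossalAI's API:

from typing import Optional

import torch


def copy_to_pinned(gathered: torch.Tensor, cache: Optional[dict], key: str) -> torch.Tensor:
    # Hypothetical helper, not ColossalAI's API: mirrors the fixed guard order above.
    # "if cache:" would treat an empty dict as "no cache" and skip allocation;
    # "is not None" lets a freshly created (still empty) cache be filled lazily.
    if cache is not None and key not in cache:
        # Allocate the pinned CPU buffer from the gathered tensor so its shape
        # matches what will be copied into it (pinning needs a CUDA-enabled build).
        cache[key] = torch.empty_like(gathered, device="cpu", pin_memory=True)
    if cache is not None:
        cache[key].copy_(gathered)
        return cache[key]
    return gathered


cache = {}  # falsy, but still a valid cache to fill
out = copy_to_pinned(torch.randn(4), cache, "exp_avg")
assert "exp_avg" in cache and out.is_pinned()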
@@ -858,7 +858,7 @@ def state_dict_shard(
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)
-            if pinned_state_dicts and param_idx not in pinned_state_dicts:
+            if pinned_state_dicts is not None and param_idx not in pinned_state_dicts:
                 pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
@@ -869,9 +869,9 @@ def state_dict_shard(
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
-                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                    if pinned_state_dicts is not None and k not in pinned_state_dicts[param_idx]:
                         pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
-                    if pinned_state_dicts:
+                    if pinned_state_dicts is not None:
                         pinned_state_dicts[param_idx][k].copy_(state_tensor)
                         current_block[k] = pinned_state_dicts[param_idx][k]
                     else:
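
For background, pinned (page-locked) host buffers are kept around in these paths because copies from the GPU into pinned memory can be issued asynchronously and the same buffer can be reused across repeated checkpoint saves. A minimal standalone sketch of that property, independent of this repository:

import torch

# Copies from device memory into a page-locked host buffer can run asynchronously;
# the buffer can then be reused for every subsequent save instead of reallocating.
if torch.cuda.is_available():
    gpu_state = torch.randn(1024, 1024, device="cuda")
    pinned_buf = torch.empty_like(gpu_state, device="cpu", pin_memory=True)
    pinned_buf.copy_(gpu_state, non_blocking=True)  # returns before the copy finishes
    torch.cuda.synchronize()                        # wait before reading or saving the buffer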
