
Commit 5fa657f

[checkpointio] fix size compute
1 parent eb69e64 commit 5fa657f

File tree

1 file changed: +4 −5 lines changed


colossalai/zero/low_level/low_level_optim.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -18,6 +18,7 @@
     FP16MixedPrecisionMixin,
     MixedPrecisionMixin,
 )
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
```
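
The only change in this hunk is the new import of calculate_tensor_size from the checkpoint I/O utilities, used in the hunk below. The diff does not show the helper's body; as a point of reference, here is a minimal sketch under the assumption that it follows the usual numel-times-element-width pattern (the actual implementation and units in colossalai/checkpoint_io/utils.py may differ):

```python
import torch

def calculate_tensor_size(tensor: torch.Tensor) -> float:
    # Hypothetical sketch: storage footprint in MB, i.e. element count
    # multiplied by bytes per element. The real ColossalAI helper may
    # use different units or rounding.
    return tensor.numel() * tensor.element_size() / 1024 / 1024
```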
```diff
@@ -865,19 +866,17 @@ def state_dict_shard(
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
-                        pinned_state_dicts[param_idx][k] = torch.empty_like(
-                            working_param, pin_memory=True, device="cpu"
-                        )
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                        pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
                     if pinned_state_dicts:
                         pinned_state_dicts[param_idx][k].copy_(state_tensor)
                         current_block[k] = pinned_state_dicts[param_idx][k]
                     else:
                         current_block[k] = state_tensor.cpu()
-                    current_block_size += state_tensor.numel()
+                    current_block_size += calculate_tensor_size(state_tensor)
 
             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                 yield ret_block, ret_block_size
```
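
This hunk makes two related fixes. First, the pinned CPU buffer is now allocated with torch.empty_like(state_tensor) after the state has been gathered and reshaped, rather than with torch.empty_like(working_param) beforehand. Since empty_like inherits dtype from its argument, the old buffer took on working_param's dtype, which plausibly mismatched the gathered optimizer state (allocated with v.dtype). A small standalone demonstration of that pitfall, with hypothetical dtypes and pin_memory omitted so it runs without CUDA:

```python
import torch

# empty_like() inherits dtype from its argument, so a buffer modelled on a
# bf16 working_param silently down-casts an fp32 optimizer state on copy_().
working_param = torch.zeros(4, dtype=torch.bfloat16)
state_tensor = torch.arange(4, dtype=torch.float32) + 0.123456

old_buf = torch.empty_like(working_param)  # pre-fix: buffer is bf16
old_buf.copy_(state_tensor)                # values rounded to bf16
new_buf = torch.empty_like(state_tensor)   # post-fix: buffer matches fp32
new_buf.copy_(state_tensor)                # exact copy

print(old_buf.dtype, old_buf)  # torch.bfloat16, precision lost
print(new_buf.dtype, new_buf)  # torch.float32, values preserved
```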

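Second, matching the commit title, current_block_size now accumulates calculate_tensor_size(state_tensor) instead of state_tensor.numel(). The running total is compared against max_shard_size, which is a storage budget, so counting raw elements under-reports by at least the element width (4x for fp32 states). A quick illustration of the gap:

```python
import torch

# numel() counts elements; a shard budget is a storage size, so the two
# quantities differ by element_size() bytes per element.
state_tensor = torch.zeros(1_000_000, dtype=torch.float32)
print(state_tensor.numel())                                # 1000000 elements
print(state_tensor.numel() * state_tensor.element_size())  # 4000000 bytes
```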