@@ -18,6 +18,7 @@
     FP16MixedPrecisionMixin,
     MixedPrecisionMixin,
 )
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.quantization.fp8 import all_gather_fp8, all_reduce_fp8, reduce_scatter_fp8
@@ -865,19 +866,17 @@ def state_dict_shard(
 
             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
-                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
-                        pinned_state_dicts[param_idx][k] = torch.empty_like(
-                            working_param, pin_memory=True, device="cpu"
-                        )
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
                     state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                        pinned_state_dicts[param_idx][k] = torch.empty_like(state_tensor, pin_memory=True, device="cpu")
                     if pinned_state_dicts:
                         pinned_state_dicts[param_idx][k].copy_(state_tensor)
                         current_block[k] = pinned_state_dicts[param_idx][k]
                     else:
                         current_block[k] = state_tensor.cpu()
-                    current_block_size += state_tensor.numel()
+                    current_block_size += calculate_tensor_size(state_tensor)
 
             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                 yield ret_block, ret_block_size
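Context on the hunk above: the pinned CPU buffer is now allocated from the gathered state_tensor (its shape and dtype) instead of working_param, and the shard budget switches from counting elements (state_tensor.numel()) to calculate_tensor_size(state_tensor), so current_block_size tracks memory footprint rather than element count. A minimal standalone sketch of that accounting difference, not part of the PR; tensor_size_mb below is a hypothetical stand-in for the imported helper, assumed to return numel * element_size scaled to megabytes:

import torch

def tensor_size_mb(tensor: torch.Tensor) -> float:
    # Assumed behaviour of the sharding helper: bytes occupied, scaled to MB.
    return tensor.numel() * tensor.element_size() / 1024 / 1024

fp32_state = torch.zeros(1024, 1024, dtype=torch.float32)
fp16_state = torch.zeros(1024, 1024, dtype=torch.float16)

# Same element count, different memory footprint:
print(fp32_state.numel(), fp16_state.numel())  # 1048576 1048576
print(tensor_size_mb(fp32_state))              # 4.0
print(tensor_size_mb(fp16_state))              # 2.0

# Accumulating a size in the same unit as max_shard_size keeps the
# shard-boundary check meaningful; a raw element count does not
# distinguish fp32 from fp16 optimizer states.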