Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# Standard
from collections import defaultdict
from typing import Dict, List
from typing import Dict, List, Union
import json
import os
import re
Expand All @@ -32,6 +32,7 @@
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from transformers import PretrainedConfig
from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from huggingface_hub import split_torch_state_dict_into_shards
import torch
import torch.distributed.checkpoint as dcp

Expand Down Expand Up @@ -422,16 +423,32 @@ def _infer_prefixes_and_module_names(
return sd


def save_sharded_safetensors(
    state_dict: Dict,
    save_directory: str,
    metadata: Dict,
    max_shard_size: Union[int, str] = "5GB",
):
    """Save ``state_dict`` into ``save_directory`` as (possibly sharded) safetensors.

    The state dict is split into shards no larger than ``max_shard_size`` and
    each shard is written with ``save_file``. When more than one shard is
    produced, a ``*.safetensors.index.json`` file mapping each tensor to its
    shard is written as well, mirroring the layout ``transformers`` produces
    with ``save_pretrained``.

    Args:
        state_dict: Mapping of tensor names to ``torch.Tensor`` values.
        save_directory: Existing directory to write the shard files into.
        metadata: Safetensors metadata dict stored in every shard,
            e.g. ``{"format": "pt"}``.
        max_shard_size: Maximum shard size as a byte count or a size string
            such as ``"5GB"``.
    """
    # Derive the per-shard filename pattern from the canonical weights name,
    # e.g. "model.safetensors" -> "model{suffix}.safetensors".
    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
        ".safetensors", "{suffix}.safetensors"
    )
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
    )

    # Write each shard. ``.contiguous()`` is required because safetensors
    # refuses to serialize non-contiguous tensors.
    for shard_file, tensors in state_dict_split.filename_to_tensors.items():
        shard = {name: state_dict[name].contiguous() for name in tensors}
        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)

    # Only a sharded checkpoint gets an index file: a single-file checkpoint
    # must not ship one, or loaders that key off the index will look for
    # shard files that do not exist.
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }
        with open(
            os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME),
            "w",
            encoding="utf-8",
        ) as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")


# --------------------------- SCRIPT -------------------------
Expand Down Expand Up @@ -522,7 +539,7 @@ def save_single_safetensor(
state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)

# save it as a safetensors file
save_single_safetensor(
save_sharded_safetensors(
{k: v.contiguous() for k, v in state_dict.items()},
args.output_dir,
metadata={"format": "pt"},
Expand Down