
Commit 702de2f

willmj and fabianlim authored

feat: Checkpoint utils safetensors (#116)
* checkpoint conversion handle non-dcp case
  Signed-off-by: Yu Chin Fabian Lim <[email protected]>
* improvements
  Signed-off-by: Yu Chin Fabian Lim <[email protected]>
* fix: sharded safetensors save
  Signed-off-by: Will Johnson <[email protected]>
* fix: lint
  Signed-off-by: Will Johnson <[email protected]>
* fmt
  Signed-off-by: Will Johnson <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Co-authored-by: Yu Chin Fabian Lim <[email protected]>
1 parent 733992a commit 702de2f

1 file changed: +124 −43 lines

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 124 additions & 43 deletions
@@ -14,21 +14,25 @@
 
 # Standard
 from collections import defaultdict
-from typing import List
+from typing import Dict, List, Union
 import json
 import os
 import re
+import shutil
 
 # Third Party
 from accelerate.logging import get_logger
 from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
+from huggingface_hub import split_torch_state_dict_into_shards
+from safetensors.torch import load_file, save_file
 from torch.distributed.checkpoint.default_planner import (
     DefaultLoadPlanner,
     DefaultSavePlanner,
 )
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 from transformers import PretrainedConfig
+from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
 import torch
 import torch.distributed.checkpoint as dcp

@@ -213,24 +217,10 @@ def _dict_from_json_file(resolved_config_file):
     return os.path.dirname(result)
 
 
-# function to get the ScatterMoE state dict from its DCP checkpoint
-# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
-# to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
-# can restore the checkpoint to be loaded by the original architecture.
-def recover_original_state_dict_from_dcp_checkpoint(
+# function to get the state dict from dcp_checkpoint
+def get_state_dict_from_dcp_checkpoint(
     dcp_checkpoint_dir: str,
-    pretrained_model_name_or_path: str = None,
 ):
-    """
-    Parameters:
-        dcp_checkpoint_dir (str): the DCP to be converted.
-        pretrained_model_name_or_path (str): Optional, if provided we will
-            use the hints to remap the
-    """
-
-    # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
-    # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap
-
     # guarded, load some internal functions
     # pylint: disable=import-outside-toplevel
     # Third Party
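
The new get_state_dict_from_dcp_checkpoint is modelled on PyTorch's dcp_to_torch_save (referenced in the code comments), which materializes a DCP checkpoint into a single torch.save file. For comparison, a minimal sketch of that stock utility; the paths are illustrative:

# Stock PyTorch equivalent of the DCP-loading step: convert a DCP checkpoint
# directory into a regular torch.save file (paths below are illustrative).
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

dcp_to_torch_save("outputs/checkpoint-100/pytorch_model_fsdp_0", "outputs/model.pt")
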
@@ -245,11 +235,46 @@ def recover_original_state_dict_from_dcp_checkpoint(
         planner=_EmptyStateDictLoadPlanner(),
         no_dist=True,
     )
-    sd = sd[KEY_MODEL]
+    return sd[KEY_MODEL]
+
+
+# function to get state dict from regular checkpoint
+# - note this assumes sharded safetensors, we do not support
+# the non-sharded case for now
+def get_state_dict_from_safe_checkpoint(
+    safe_checkpoint_dir: str,
+):
+    # Load the index
+    safe_index_file = os.path.join(safe_checkpoint_dir, SAFE_WEIGHTS_INDEX_NAME)
+    with open(safe_index_file, "r", encoding="utf-8") as f:
+        index = json.load(f)
+
+    sd = {}
+    shard_files = list(set(index["weight_map"].values()))
+    for shard_file in shard_files:
+        for key, v in load_file(os.path.join(safe_checkpoint_dir, shard_file)).items():
+            sd[key] = v
+
+    return sd
 
-    # if not provided
-    if pretrained_model_name_or_path is None:
-        return sd
+
+# function to get the ScatterMoE state dict from its DCP checkpoint
+# - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
+# to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
+# can restore the checkpoint to be loaded by the original architecture.
+def recover_original_state_dict_from_checkpoint(
+    sd: Dict,
+    pretrained_model_name_or_path: str = None,
+):
+    """
+    Parameters:
+        sd (Dict): the state dict to be converted.
+        pretrained_model_name_or_path (str): Optional, if provided we will
+            use the hints to remap the state dict.
+    """
+
+    # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
+    # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap
 
     # now do the remap
     loc = get_resolved_checkpoint_location(pretrained_model_name_or_path)
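
get_state_dict_from_safe_checkpoint relies on the standard Hugging Face sharded-safetensors index file (SAFE_WEIGHTS_INDEX_NAME, i.e. model.safetensors.index.json). A rough sketch of the layout it expects; tensor names, shard names, and sizes are illustrative:

# Illustrative shape of model.safetensors.index.json, shown as a Python dict.
index = {
    "metadata": {"total_size": 29486882816},
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00003.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00003.safetensors",
        "lm_head.weight": "model-00003-of-00003.safetensors",
    },
}
# The loader walks list(set(index["weight_map"].values())) and merges the
# tensors from each shard into a single flat state dict.
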
@@ -398,6 +423,37 @@ def _infer_prefixes_and_module_names(
     return sd
 
 
+def save_sharded_safetensors(
+    input_state_dict: Dict,
+    save_directory: str,
+    metadata: Dict,
+    max_shard_size: Union[int, str] = "5GB",
+):
+    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
+        ".safetensors", "{suffix}.safetensors"
+    )
+    state_dict_split = split_torch_state_dict_into_shards(
+        input_state_dict,
+        filename_pattern=filename_pattern,
+        max_shard_size=max_shard_size,
+    )
+    index = {
+        "metadata": state_dict_split.metadata,
+        "weight_map": state_dict_split.tensor_to_filename,
+    }
+    # Save the index
+    with open(
+        os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
+    ) as f:
+        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+        f.write(content)
+
+    filename_to_tensors = state_dict_split.filename_to_tensors.items()
+    for shard_file, tensors in filename_to_tensors:
+        shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
+        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
+
+
 # --------------------------- SCRIPT -------------------------
 
 
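
A minimal usage sketch of the new save_sharded_safetensors helper; the tensors, shard size, and output directory are illustrative, and the import path is inferred from the file location. Note the helper writes into an existing directory and does not create it:

# Toy example: shard a small state dict into safetensors files plus an index.
import os

import torch

from fms_acceleration_moe.utils.checkpoint_utils import save_sharded_safetensors

out_dir = "converted_ckpt"
os.makedirs(out_dir, exist_ok=True)  # the helper assumes the directory exists

state_dict = {
    "linear.weight": torch.randn(1024, 1024),
    "linear.bias": torch.randn(1024),
}

save_sharded_safetensors(
    state_dict,
    out_dir,
    metadata={"format": "pt"},
    max_shard_size="100MB",  # small cap so larger models split into multiple shards
)
# writes the *.safetensors shard(s) and model.safetensors.index.json into out_dir
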
@@ -417,8 +473,8 @@ def _infer_prefixes_and_module_names(
     )
 
     parser.add_argument(
-        "dcp_checkpoint_dir",
-        help="Path to the distributed checkpoint.",
+        "checkpoint_dir",
+        help="Path to the checkpoint.",
     )
 
     parser.add_argument(
@@ -432,37 +488,62 @@ def _infer_prefixes_and_module_names(
             "the original pretrained model checkpoint (from which this "
             "checkpoint is obtained)."
         ),
+        default=None,
     )
 
     args = parser.parse_args()
 
-    # search for the checkpoint. By the code above, it must
+    # search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
     # start with FSDP_MODEL_NAME
-    if args.dcp_checkpoint_dir.startswith(FSDP_MODEL_NAME):
-        checkpoint_dir = args.dcp_checkpoint_dir
+    if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
+        checkpoint_dir = args.checkpoint_dir
+        loader = get_state_dict_from_dcp_checkpoint
     else:
         checkpoint_dir = [
             x
-            for x in os.listdir(args.dcp_checkpoint_dir)
-            if os.path.isdir(os.path.join(args.dcp_checkpoint_dir, x))
+            for x in os.listdir(args.checkpoint_dir)
+            if os.path.isdir(os.path.join(args.checkpoint_dir, x))
             and x.startswith(FSDP_MODEL_NAME)
         ]
-        if len(checkpoint_dir) > 1:
+        if len(checkpoint_dir) == 1:
+            checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
+            loader = get_state_dict_from_dcp_checkpoint
+        elif len(checkpoint_dir) > 1:
             raise ValueError(
-                f"Found > 1 dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
+                f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
                 f"that starts with {FSDP_MODEL_NAME}. Please specify the exact dir."
             )
-        if len(checkpoint_dir) == 0:
-            raise ValueError(
-                f"Found no dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
-                f"that starts with {FSDP_MODEL_NAME}. Nothing to convert"
-            )
-        checkpoint_dir = os.path.join(args.dcp_checkpoint_dir, checkpoint_dir[0])
-
-    # get the converted statedict
-    state_dict = recover_original_state_dict_from_dcp_checkpoint(
-        checkpoint_dir, args.pretrained_model_name_or_path
+        else:
+            # then take it as a safetensors checkpoint
+            # - we do not support .bin checkpoints
+            checkpoint_dir = args.checkpoint_dir
+            loader = get_state_dict_from_safe_checkpoint
+
+    # - pretrained model name
+    _name_or_path = args.pretrained_model_name_or_path
+
+    # assume the output directory exists, we do not create it
+    # - copy the config file if it exists
+    config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
+    target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+    if os.path.exists(config_file):
+        shutil.copyfile(config_file, target_config_file)
+
+    # try to populate pretrained_model_name_or_path from the config file
+    # if it was None
+    if not _name_or_path:
+        with open(target_config_file, "r", encoding="utf-8") as file:
+            _name_or_path = json.load(file).get("_name_or_path")
+
+    # get the state_dict
+    state_dict = loader(checkpoint_dir)
+
+    # recover the original state dict
+    state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)
+
+    # save it as sharded safetensors
+    save_sharded_safetensors(
+        {k: v.contiguous() for k, v in state_dict.items()},
+        args.output_dir,
+        metadata={"format": "pt"},
     )
-
-    # save it
-    torch.save(state_dict, args.output_dir)
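
Taken together, the script now accepts either a DCP checkpoint or a sharded-safetensors checkpoint, remaps the ScatterMoE keys, and writes sharded safetensors instead of a single torch.save file. A rough end-to-end sketch at the Python level; the import path is inferred from the file location, and the paths and model name are illustrative:

# Sketch of the conversion flow for a sharded-safetensors checkpoint.
import os

from fms_acceleration_moe.utils.checkpoint_utils import (
    get_state_dict_from_safe_checkpoint,
    recover_original_state_dict_from_checkpoint,
    save_sharded_safetensors,
)

checkpoint_dir = "outputs/checkpoint-100"        # illustrative ScatterMoE checkpoint
output_dir = "outputs/checkpoint-100-converted"  # the script assumes this exists already
os.makedirs(output_dir, exist_ok=True)

# 1. load the flat state dict from the sharded safetensors shards
sd = get_state_dict_from_safe_checkpoint(checkpoint_dir)

# 2. remap the ScatterMoE keys back to the original architecture
#    (model name is illustrative; the script can also read _name_or_path from config.json)
sd = recover_original_state_dict_from_checkpoint(sd, "mistralai/Mixtral-8x7B-Instruct-v0.1")

# 3. write the converted weights back out as sharded safetensors
save_sharded_safetensors(
    {k: v.contiguous() for k, v in sd.items()},
    output_dir,
    metadata={"format": "pt"},
)
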
