
# Standard
from collections import defaultdict
- from typing import List
+ from typing import Dict, List, Union
import json
import os
import re
+ import shutil

# Third Party
from accelerate.logging import get_logger
from accelerate.utils.constants import FSDP_MODEL_NAME, OPTIMIZER_NAME
+ from huggingface_hub import split_torch_state_dict_into_shards
+ from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from transformers import PretrainedConfig
+ from transformers.utils import CONFIG_NAME, SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
import torch
import torch.distributed.checkpoint as dcp

@@ -213,24 +217,10 @@ def _dict_from_json_file(resolved_config_file):
    return os.path.dirname(result)


- # function to get the ScatterMoE state dict from its DCP checkpoint
- # - if the original pretrained_model_name_or_path is specified, will use the checkpoint as hints
- #   to map the ScatterMoE checkpoint to that of the original model. This is useful so that we
- #   can restore the checkpoint to be loaded by the original architecture.
- def recover_original_state_dict_from_dcp_checkpoint(
+ # function to get the state dict from a DCP checkpoint
+ def get_state_dict_from_dcp_checkpoint(
    dcp_checkpoint_dir: str,
-     pretrained_model_name_or_path: str = None,
):
224- """
225- Parameters:
226- dcp_checkpoint_dir (str): the DCP to be converted.
227- pretrained_model_name_or_path (str): Optional, if provided we will
228- use the hints to remap the
229- """
230-
231- # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
232- # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap
233-
    # guarded, load some internal functions
    # pylint: disable=import-outside-toplevel
    # Third Party
@@ -245,11 +235,46 @@ def recover_original_state_dict_from_dcp_checkpoint(
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
-     sd = sd[KEY_MODEL]
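+     # the loaded DCP state dict nests the model weights under KEY_MODEL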
+     return sd[KEY_MODEL]
+
+
+ # function to get the state dict from a regular checkpoint
+ # - note this assumes sharded safetensors, we do not support
+ #   the non-sharded case for now
+ def get_state_dict_from_safe_checkpoint(
+     safe_checkpoint_dir: str,
+ ):
+     # Load the index
+     safe_index_file = os.path.join(safe_checkpoint_dir, SAFE_WEIGHTS_INDEX_NAME)
+     with open(safe_index_file, "r", encoding="utf-8") as f:
+         index = json.load(f)
+
+     sd = {}
+     shard_files = list(set(index["weight_map"].values()))
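+     # load every shard and merge all tensors into a single flat state dict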
+     for shard_file in shard_files:
+         for key, v in load_file(os.path.join(safe_checkpoint_dir, shard_file)).items():
+             sd[key] = v
+
+     return sd

-     # if not provided
-     if pretrained_model_name_or_path is None:
-         return sd
+
+ # function to recover the original state dict from a ScatterMoE state dict
+ # - if the original pretrained_model_name_or_path is specified, will use that checkpoint as hints
+ #   to map the ScatterMoE state dict to that of the original model. This is useful so that we
+ #   can restore the checkpoint to be loaded by the original architecture.
+ def recover_original_state_dict_from_checkpoint(
+     sd: Dict,
+     pretrained_model_name_or_path: str = None,
+ ):
+     """
+     Parameters:
+         sd (Dict): the ScatterMoE state dict to be converted.
+         pretrained_model_name_or_path (str): Optional, if provided we will
+             use the original checkpoint as hints to remap the keys.
+     """
+
+     # reference dcp_to_torch_save from torch.distributed.checkpoint.format_utils.py
+     # - strategy is to use _EmptyStateDictLoadPlanner to populate the state dict, then we remap

    # now do the remap
    loc = get_resolved_checkpoint_location(pretrained_model_name_or_path)
@@ -398,6 +423,37 @@ def _infer_prefixes_and_module_names(
    return sd


+ def save_sharded_safetensors(
+     input_state_dict: Dict,
+     save_directory: str,
+     metadata: Dict,
+     max_shard_size: Union[int, str] = "5GB",
+ ):
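+     # build the shard filename pattern from SAFE_WEIGHTS_NAME,
+     # e.g. model{suffix}.safetensors -> model-00001-of-00002.safetensors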
+     filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
+         ".safetensors", "{suffix}.safetensors"
+     )
+     state_dict_split = split_torch_state_dict_into_shards(
+         input_state_dict,
+         filename_pattern=filename_pattern,
+         max_shard_size=max_shard_size,
+     )
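+     # the index maps every tensor name to the shard file that stores it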
+     index = {
+         "metadata": state_dict_split.metadata,
+         "weight_map": state_dict_split.tensor_to_filename,
+     }
+     # Save the index
+     with open(
+         os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
+     ) as f:
+         content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+         f.write(content)
+
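+     # write each shard; safetensors requires contiguous tensors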
+     filename_to_tensors = state_dict_split.filename_to_tensors.items()
+     for shard_file, tensors in filename_to_tensors:
+         shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
+         save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
+
+
401457# --------------------------- SCRIPT -------------------------
402458
403459
@@ -417,8 +473,8 @@ def _infer_prefixes_and_module_names(
    )

    parser.add_argument(
-         "dcp_checkpoint_dir",
-         help="Path to the distributed checkpoint.",
+         "checkpoint_dir",
+         help="Path to the checkpoint.",
    )

    parser.add_argument(
@@ -432,37 +488,62 @@ def _infer_prefixes_and_module_names(
432488 "the original pretrained model checkpoint (from which this "
433489 "checkpoint is obtained)."
434490 ),
491+ default = None ,
435492 )
436493
437494 args = parser .parse_args ()
438495
-     # search for the checkpint. By the code above, it must
+     # search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
    # start with FSDP_MODEL_NAME
-     if args.dcp_checkpoint_dir.startswith(FSDP_MODEL_NAME):
-         checkpoint_dir = args.dcp_checkpoint_dir
+     if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
+         checkpoint_dir = args.checkpoint_dir
+         loader = get_state_dict_from_dcp_checkpoint
    else:
        checkpoint_dir = [
            x
-             for x in os.listdir(args.dcp_checkpoint_dir)
-             if os.path.isdir(os.path.join(args.dcp_checkpoint_dir, x))
+             for x in os.listdir(args.checkpoint_dir)
+             if os.path.isdir(os.path.join(args.checkpoint_dir, x))
            and x.startswith(FSDP_MODEL_NAME)
        ]
-         if len(checkpoint_dir) > 1:
+         if len(checkpoint_dir) == 1:
+             checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
+             loader = get_state_dict_from_dcp_checkpoint
+         elif len(checkpoint_dir) > 1:
            raise ValueError(
-                 f"Found > 1 dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
+                 f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
                f"that starts with {FSDP_MODEL_NAME}. Please specify the exact dir."
            )
-         if len(checkpoint_dir) == 0:
-             raise ValueError(
-                 f"Found no dirs in dcp checkpoint dir {args.dcp_checkpoint_dir} "
-                 f"that starts with {FSDP_MODEL_NAME}. Nothing to convert"
-             )
-         checkpoint_dir = os.path.join(args.dcp_checkpoint_dir, checkpoint_dir[0])
-
-     # get the converted statedict
-     state_dict = recover_original_state_dict_from_dcp_checkpoint(
-         checkpoint_dir, args.pretrained_model_name_or_path
+         else:
+             # then take it as a safetensors checkpoint
+             # - do not support .bin checkpoints
+             checkpoint_dir = args.checkpoint_dir
+             loader = get_state_dict_from_safe_checkpoint
+
+     # - pretrained model name
+     _name_or_path = args.pretrained_model_name_or_path
+
+     # assume output directory exists, we do not create it
+     # - copy the config file if it exists
+     config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
+     target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+     if os.path.exists(config_file):
+         shutil.copyfile(config_file, target_config_file)
+
+     # try to populate pretrained_model_name_or_path from the config path
+     # if it was None
+     if not _name_or_path:
+         with open(target_config_file, "r", encoding="utf-8") as file:
+             _name_or_path = json.load(file).get("_name_or_path")
+
+     # get the state_dict
+     state_dict = loader(checkpoint_dir)
+
+     # recover the original state dict
+     state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)
+
+     # save it as a safetensors file
+     save_sharded_safetensors(
+         {k: v.contiguous() for k, v in state_dict.items()},
+         args.output_dir,
+         metadata={"format": "pt"},
    )
-
-     # save it
-     torch.save(state_dict, args.output_dir)
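+
+     # example invocation (script name and argument order are hypothetical):
+     #   python checkpoint_utils.py <checkpoint_dir> <output_dir> \
+     #       --pretrained_model_name_or_path <original-model-id>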