diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
index 660d2252..a3f2117d 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/__init__.py
@@ -13,7 +13,10 @@
 # limitations under the License.

 # Local
-from .checkpoint_utils import patch_huggingface_save_and_load_for_dtensors
+from .checkpoint_utils import (
+    patch_huggingface_save_and_load_for_dtensors,
+    recover_safetensors_from_dcp,
+)
 from .scattermoe_prepare import prepare_scattermoe

 # this is a special patch function to disable foreach for
diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 7c5c54b1..6292b002 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -457,75 +457,38 @@ def save_sharded_safetensors(

 # --------------------------- SCRIPT -------------------------

-# have it serve as a conversion script
-if __name__ == "__main__":
-    # Standard
-    import argparse
-
-    parser = argparse.ArgumentParser(
-        description=(
-            "Utility for converting ScatterMoE checkpoint back to the "
-            "orginal state dict format. "
-            "The ScatterMoE checkpoint was saved after the pretrained model "
-            "had been converted by a module swap, hence the state dict will "
-            "no longer resemble the original. This utility creaes"
-        )
-    )
-
-    parser.add_argument(
-        "checkpoint_dir",
-        help="Path to the checkpoint.",
-    )
-
-    parser.add_argument(
-        "output_dir", help="Path to the location to write the converted checkpoint."
-    )
-
-    parser.add_argument(
-        "pretrained_model_name_or_path",
-        help=(
-            "In order to reconstruct the state dict, we requre hints from "
-            "the original pretrained model checkpoint (from which this "
-            "checkpoint is obtained)."
-        ),
-        default=None,
-    )
-
-    args = parser.parse_args()
-
-    # search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
-    # start with FSDP_MODEL_NAME
-    if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
-        checkpoint_dir = args.checkpoint_dir
+def recover_safetensors_from_dcp(
+    checkpoint_dir, pretrained_model_name_or_path, output_dir
+):
+    if checkpoint_dir.startswith(FSDP_MODEL_NAME):
         loader = get_state_dict_from_dcp_checkpoint
     else:
-        checkpoint_dir = [
+        fsdp_checkpoint_dirs = [
             x
-            for x in os.listdir(args.checkpoint_dir)
-            if os.path.isdir(os.path.join(args.checkpoint_dir, x))
+            for x in os.listdir(checkpoint_dir)
+            if os.path.isdir(os.path.join(checkpoint_dir, x))
             and x.startswith(FSDP_MODEL_NAME)
         ]
-        if len(checkpoint_dir) == 1:
-            checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
+        if len(fsdp_checkpoint_dirs) == 1:
+            checkpoint_dir = os.path.join(checkpoint_dir, fsdp_checkpoint_dirs[0])
             loader = get_state_dict_from_dcp_checkpoint
-        elif len(checkpoint_dir) > 1:
+        elif len(fsdp_checkpoint_dirs) > 1:
             raise ValueError(
-                f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
-                f"that starts with {FSDP_MODEL_NAME}. Please spectify the exact dir."
+                f"Found > 1 dirs in dcp checkpoint dir {checkpoint_dir} "
+                f"that start with {FSDP_MODEL_NAME}. Please specify the exact dir."
             )
         else:
             # then take it as a safetensors checkpoint
             # - do not support .bin checkpoints
-            checkpoint_dir = args.checkpoint_dir
             loader = get_state_dict_from_safe_checkpoint

     # - pretrained model name
-    _name_or_path = args.pretrained_model_name_or_path
+    _name_or_path = pretrained_model_name_or_path

     # assume output directory exists, we do not create it
     # - copy the config file if exists
     config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
-    target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+    target_config_file = os.path.join(output_dir, CONFIG_NAME)
     if os.path.exists(config_file):
         shutil.copyfile(config_file, target_config_file)

@@ -544,6 +507,46 @@ def save_sharded_safetensors(
     # save it as a safetensors file
     save_sharded_safetensors(
         {k: v.contiguous() for k, v in state_dict.items()},
-        args.output_dir,
+        output_dir,
         metadata={"format": "pt"},
     )
+
+
+# have it serve as a conversion script
+if __name__ == "__main__":
+    # Standard
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description=(
+            "Utility for converting a ScatterMoE checkpoint back to the "
+            "original state dict format. "
+            "The ScatterMoE checkpoint was saved after the pretrained model "
+            "had been converted by a module swap; hence the state dict will "
+            "no longer resemble the original. This utility recovers the original format."
+        )
+    )
+
+    parser.add_argument(
+        "checkpoint_dir",
+        help="Path to the checkpoint.",
+    )
+
+    parser.add_argument(
+        "output_dir", help="Path to the location to write the converted checkpoint."
+    )
+
+    parser.add_argument(
+        "pretrained_model_name_or_path",
+        help=(
+            "In order to reconstruct the state dict, we require hints from "
+            "the original pretrained model checkpoint (from which this "
+            "checkpoint is obtained)."
+        ),
+        default=None,
+    )
+
+    args = parser.parse_args()
+    recover_safetensors_from_dcp(
+        args.checkpoint_dir, args.pretrained_model_name_or_path, args.output_dir
+    )
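
For reference, below is a minimal sketch of how the helper exported by this change might be called directly from Python. The import path follows the re-export added to __init__.py above; all paths in the sketch are hypothetical placeholders, not part of this change.

# Usage sketch for recover_safetensors_from_dcp (paths are hypothetical).
from fms_acceleration_moe.utils import recover_safetensors_from_dcp

# The helper accepts either a DCP checkpoint dir (or a dir containing exactly
# one subdir prefixed with FSDP_MODEL_NAME) or a safetensors checkpoint, and
# writes sharded safetensors in the original state dict format.
# Note: per the comment in the function body, output_dir must already exist;
# the helper does not create it.
recover_safetensors_from_dcp(
    checkpoint_dir="outputs/checkpoint-100",              # hypothetical path
    pretrained_model_name_or_path="path/to/base-model",   # hypothetical path
    output_dir="outputs/recovered",                       # hypothetical, must exist
)

Note that the function's positional order (checkpoint_dir, pretrained_model_name_or_path, output_dir) differs from the CLI's positional order (checkpoint_dir, output_dir, pretrained_model_name_or_path), so keyword arguments are the safer way to call it from Python.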