@@ -13,7 +13,10 @@
 # limitations under the License.
 
 # Local
-from .checkpoint_utils import patch_huggingface_save_and_load_for_dtensors
+from .checkpoint_utils import (
+    patch_huggingface_save_and_load_for_dtensors,
+    recover_safetensors_from_dcp,
+)
 from .scattermoe_prepare import prepare_scattermoe
 
 # this is a special patch function to disable foreach for
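With recover_safetensors_from_dcp now re-exported next to patch_huggingface_save_and_load_for_dtensors, the conversion can also be driven programmatically instead of through the script further down. A minimal sketch, assuming an import path and model name not shown in this diff; per the function's own comments, the output directory must already exist:

# Minimal sketch: programmatic use of the newly exported helper.
# NOTE: the package path and model name below are assumptions for illustration;
# the function itself is defined in checkpoint_utils.py per the import above.
import os

from fms_acceleration_moe.utils import recover_safetensors_from_dcp  # assumed path

checkpoint_dir = "outputs/checkpoint-100"          # contains (or is) an FSDP_MODEL_NAME dcp dir
output_dir = "outputs/checkpoint-100-st"           # the helper does not create this directory
pretrained_model_name_or_path = "mistralai/Mixtral-8x7B-v0.1"  # example MoE model

os.makedirs(output_dir, exist_ok=True)
recover_safetensors_from_dcp(checkpoint_dir, pretrained_model_name_or_path, output_dir)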
@@ -457,75 +457,38 @@ def save_sharded_safetensors(
 # --------------------------- SCRIPT -------------------------
 
 
-# have it serve as a conversion script
-if __name__ == "__main__":
-    # Standard
-    import argparse
-
-    parser = argparse.ArgumentParser(
-        description=(
-            "Utility for converting ScatterMoE checkpoint back to the "
-            "original state dict format. "
-            "The ScatterMoE checkpoint was saved after the pretrained model "
-            "had been converted by a module swap, hence the state dict will "
-            "no longer resemble the original. This utility converts it back."
-        )
-    )
-
-    parser.add_argument(
-        "checkpoint_dir",
-        help="Path to the checkpoint.",
-    )
-
-    parser.add_argument(
-        "output_dir", help="Path to the location to write the converted checkpoint."
-    )
-
-    parser.add_argument(
-        "pretrained_model_name_or_path",
-        help=(
-            "In order to reconstruct the state dict, we require hints from "
-            "the original pretrained model checkpoint (from which this "
-            "checkpoint is obtained)."
-        ),
-        default=None,
-    )
-
-    args = parser.parse_args()
-
-    # search for an FSDP checkpoint. If it is an FSDP checkpoint, it must
-    # start with FSDP_MODEL_NAME
-    if args.checkpoint_dir.startswith(FSDP_MODEL_NAME):
-        checkpoint_dir = args.checkpoint_dir
+def recover_safetensors_from_dcp(
+    checkpoint_dir, pretrained_model_name_or_path, output_dir
+):
+    if checkpoint_dir.startswith(FSDP_MODEL_NAME):
         loader = get_state_dict_from_dcp_checkpoint
     else:
-        checkpoint_dir = [
+        fsdp_checkpoint_dirs = [
             x
-            for x in os.listdir(args.checkpoint_dir)
-            if os.path.isdir(os.path.join(args.checkpoint_dir, x))
+            for x in os.listdir(checkpoint_dir)
+            if os.path.isdir(os.path.join(checkpoint_dir, x))
             and x.startswith(FSDP_MODEL_NAME)
         ]
-        if len(checkpoint_dir) == 1:
-            checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_dir[0])
+        if len(fsdp_checkpoint_dirs) == 1:
+            checkpoint_dir = os.path.join(checkpoint_dir, fsdp_checkpoint_dirs[0])
             loader = get_state_dict_from_dcp_checkpoint
-        elif len(checkpoint_dir) > 1:
+        elif len(fsdp_checkpoint_dirs) > 1:
             raise ValueError(
-                f"Found > 1 dirs in dcp checkpoint dir {args.checkpoint_dir} "
+                f"Found > 1 dirs in dcp checkpoint dir {checkpoint_dir} "
                 f"that start with {FSDP_MODEL_NAME}. Please specify the exact dir."
             )
         else:
             # then take it as a safetensors checkpoint
             # - do not support .bin checkpoints
-            checkpoint_dir = args.checkpoint_dir
             loader = get_state_dict_from_safe_checkpoint
 
     # - pretrained model name
-    _name_or_path = args.pretrained_model_name_or_path
+    _name_or_path = pretrained_model_name_or_path
 
     # assume output directory exists, we do not create it
    # - copy the config file if exists
     config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
-    target_config_file = os.path.join(args.output_dir, CONFIG_NAME)
+    target_config_file = os.path.join(output_dir, CONFIG_NAME)
     if os.path.exists(config_file):
         shutil.copyfile(config_file, target_config_file)
 
@@ -544,6 +507,46 @@ def save_sharded_safetensors(
     # save it as a safetensors file
     save_sharded_safetensors(
         {k: v.contiguous() for k, v in state_dict.items()},
-        args.output_dir,
+        output_dir,
         metadata={"format": "pt"},
     )
+
+
+# have it serve as a conversion script
+if __name__ == "__main__":
+    # Standard
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description=(
+            "Utility for converting ScatterMoE checkpoint back to the "
+            "original state dict format. "
+            "The ScatterMoE checkpoint was saved after the pretrained model "
+            "had been converted by a module swap, hence the state dict will "
+            "no longer resemble the original. This utility converts it back."
+        )
+    )
+
+    parser.add_argument(
+        "checkpoint_dir",
+        help="Path to the checkpoint.",
+    )
+
+    parser.add_argument(
+        "output_dir", help="Path to the location to write the converted checkpoint."
+    )
+
+    parser.add_argument(
+        "pretrained_model_name_or_path",
+        help=(
+            "In order to reconstruct the state dict, we require hints from "
+            "the original pretrained model checkpoint (from which this "
+            "checkpoint is obtained)."
+        ),
+        default=None,
+    )
+
+    args = parser.parse_args()
+    recover_safetensors_from_dcp(
+        args.checkpoint_dir, args.pretrained_model_name_or_path, args.output_dir
+    )
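The __main__ block keeps the file usable as a standalone conversion script. A sketch of the invocation, assuming the script file is checkpoint_utils.py (the name suggested by the import in __init__.py) and example paths; note the positional order is checkpoint_dir, output_dir, then pretrained_model_name_or_path:

# Sketch of invoking the conversion script; the script file name and all paths
# below are assumptions for illustration, not part of this change.
import subprocess
import sys

subprocess.run(
    [
        sys.executable,
        "checkpoint_utils.py",                 # assumed script file name
        "outputs/checkpoint-100",              # checkpoint_dir
        "outputs/checkpoint-100-st",           # output_dir (must already exist)
        "mistralai/Mixtral-8x7B-v0.1",         # pretrained_model_name_or_path
    ],
    check=True,
)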