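"""Convert a torch_dist (torch.distributed.checkpoint) FSDP checkpoint into a
Hugging Face model directory, reusing config/tokenizer assets from the original
HF model."""
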
import argparse
import os
import pickle
import shutil
import time

import torch
import torch.distributed.checkpoint as dist_cp
from transformers import AutoConfig, AutoModelForCausalLM
from typing_extensions import override


class UnpicklerWrapper(pickle.Unpickler):
    """Unpickler that stubs out Megatron/GLM-specific classes with a no-op dummy
    so checkpoint metadata can be unpickled without those packages installed."""

    @override
    def find_class(self, mod_name, name):
        class DummyClass:
            def __init__(self, *args, **kwargs):
                pass

        if mod_name.startswith("megatron") or mod_name.startswith("glm"):
            return DummyClass
        return super().find_class(mod_name, name)


class WrappedStorageReader(dist_cp.FileSystemReader):
    """FileSystemReader that loads `.metadata` through UnpicklerWrapper and
    backfills fields that may be absent from older checkpoints."""

    @override
    def read_metadata(self):
        path = self.fs.concat_path(self.path, ".metadata")
        with self.fs.create_stream(path, "rb") as metadata_file:
            metadata = UnpicklerWrapper(metadata_file).load()
        if getattr(metadata, "storage_meta", None) is None:
            metadata.storage_meta = dist_cp.StorageMeta()
        metadata.storage_meta.load_id = self.load_id
        if metadata.planner_data is None:
            metadata.planner_data = {}
        return metadata


class EmptyStateDictLoadPlanner(dist_cp.default_planner.DefaultLoadPlanner):
    """Load planner that builds the target state dict from checkpoint metadata,
    allocating an empty tensor for every non-optimizer entry."""

    @override
    def set_up_planner(
        self,
        state_dict: dist_cp.metadata.STATE_DICT_TYPE,
        metadata: dist_cp.metadata.Metadata | None = None,
        is_coordinator: bool = False,
    ) -> None:
        assert metadata is not None, "checkpoint metadata is required"
        for k, v in metadata.state_dict_metadata.items():
            if "optimizer" in k:
                continue
            print(f"found {k} in torch_dist checkpoint")
            if isinstance(v, dist_cp.metadata.TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            state_dict[k] = v
        super().set_up_planner(state_dict, metadata, is_coordinator)


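# Checkpoint directories may contain the weights directly or under a "model/"
# subdirectory (e.g. iter_xxx/model); prefer the subdirectory when it exists.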
def _detect_model_dir(input_dir: str) -> str:
    model_dir = os.path.join(input_dir, "model")
    return model_dir if os.path.isdir(model_dir) else input_dir


def _load_fsdp_state_dict(input_dir: str) -> dict[str, torch.Tensor]:
    state_dict: dict[str, torch.Tensor] = {}
    # _load_state_dict is a private torch.distributed.checkpoint API; no_dist=True
    # lets us read the checkpoint without initializing a process group.
    dist_cp.state_dict_loader._load_state_dict(
        state_dict,
        storage_reader=WrappedStorageReader(input_dir),
        planner=EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    return state_dict


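# Checkpoint keys are often nested under framework-specific prefixes such as
# "model_state.model." or "module."; the helpers below pick the prefix whose
# removal best aligns the checkpoint keys with the HF model's state_dict keys.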
def _get_candidate_prefixes(keys: list[str]) -> list[str]:
    predefined = [
        "model_state.model.",
        "model_state.",
        "model.",
        "module.",
        "",
    ]

    detected: set[str] = set()
    for key in keys:
        for prefix in predefined:
            if prefix and key.startswith(prefix):
                detected.add(prefix)

    # Always keep the empty string as a fallback for exact matches.
    detected.add("")
    # Preserve predefined order while keeping only detected prefixes.
    return [p for p in predefined if p in detected]


def _strip_best_prefix(keys: list[str], target_keys: set[str]) -> tuple[str, int]:
    """Pick the candidate prefix whose removal maps the most keys onto target_keys."""
    best_prefix = ""
    best_match = -1

    for prefix in _get_candidate_prefixes(keys):
        mapped_keys = {k.removeprefix(prefix) for k in keys}
        match_count = len(mapped_keys & target_keys)
        if match_count > best_match:
            best_match = match_count
            best_prefix = prefix

    return best_prefix, best_match


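# End-to-end conversion: load the distributed checkpoint, build an HF model from
# the original config, remap checkpoint keys onto HF parameter names, and save
# the weights as safetensors.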
def _convert_fsdp_to_hf(
    origin_hf_dir: str,
    input_dir: str,
    output_dir: str,
) -> None:
    print(f"loading FSDP model from {input_dir}")
    t = time.time()
    state_dict = _load_fsdp_state_dict(input_dir)
    print(f"FSDP model loaded in {time.time() - t:.2f} sec.")

    tensor_items = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}

    config = AutoConfig.from_pretrained(origin_hf_dir, trust_remote_code=True)
    hf_model = AutoModelForCausalLM.from_config(config)
    target_keys = set(hf_model.state_dict().keys())

    best_prefix, best_match = _strip_best_prefix(list(tensor_items.keys()), target_keys)
    total_keys = len(tensor_items)

    print(f"Using prefix '{best_prefix}' for key mapping. Matched {best_match}/{total_keys} parameter keys.")

    model_state = {k.removeprefix(best_prefix): v for k, v in tensor_items.items()}

    if not model_state:
        raise ValueError(
            "No model weights found in checkpoint. "
            "Please pass the checkpoint directory (e.g. iter_xxx or iter_xxx/model)."
        )

    missing, unexpected = hf_model.load_state_dict(model_state, strict=False)
    print(f"Missing keys: {missing}\nUnexpected keys: {unexpected}")

    os.makedirs(output_dir, exist_ok=True)
    hf_model.save_pretrained(output_dir, safe_serialization=True)
    print(f"Model weights saved to {output_dir}")


def copy_assets(origin_hf_dir: str, output_dir: str) -> None:
    """Copy config/tokenizer files from the original HF directory, skipping weights."""
    for filename in os.listdir(origin_hf_dir):
        if filename == "model.safetensors.index.json" or filename.endswith(".safetensors"):
            continue
        origin_filename = os.path.join(origin_hf_dir, filename)
        if not os.path.isfile(origin_filename):
            print(f"Skipping {filename}: not a file.")
            continue
        src, dst = origin_filename, os.path.join(output_dir, filename)
        print(f"copying {src} -> {dst}")
        shutil.copy(src, dst)


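# Example invocation (script name and paths are illustrative):
#   python convert_torch_dist_to_hf.py \
#       --input-dir checkpoints/iter_0000100 \
#       --output-dir checkpoints/hf_model \
#       --origin-hf-dir /models/base-hf-model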
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dir", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    parser.add_argument(
        "--origin-hf-dir",
        type=str,
        required=True,
        help="The original Hugging Face model directory from which to copy config/tokenizer assets.",
    )
    parser.add_argument(
        "-f", "--force", action="store_true", help="Force overwrite the output directory if it exists."
    )
    args = parser.parse_args()

    if os.path.exists(args.output_dir) and not args.force:
        raise ValueError(f"Output directory {args.output_dir} already exists. Use --force to overwrite it.")

    model_dir = _detect_model_dir(args.input_dir)
    _convert_fsdp_to_hf(args.origin_hf_dir, model_dir, args.output_dir)
    copy_assets(args.origin_hf_dir, args.output_dir)