forked from foundation-model-stack/fms-fsdp
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path fms_to_hf_mamba.py
More file actions
37 lines (26 loc) · 1.18 KB
/
fms_to_hf_mamba.py
File metadata and controls
37 lines (26 loc) · 1.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import fire
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from torch.distributed._shard.checkpoint import FileSystemReader, load_state_dict
from fms_fsdp.utils.config_utils import get_model_config
def main(model_variant, load_path, save_path, tokenizer_name_or_path):
    """Convert an FMS-FSDP Mamba checkpoint into an HF-compatible directory.

    Builds a ``MambaLMHeadModel`` from the config registered for
    ``model_variant``, fills it from the checkpoint at ``load_path``, and
    writes both the model and a tokenizer to ``save_path``.

    Args:
        model_variant: Name passed to ``get_model_config``; the returned
            mapping is splatted into ``MambaConfig``.
        load_path: Checkpoint directory, read via ``FileSystemReader``.
        save_path: Output directory for ``model.save_pretrained`` and
            ``tokenizer.save_pretrained``.
        tokenizer_name_or_path: Any identifier or path accepted by
            ``AutoTokenizer.from_pretrained``.
    """
    # Imported lazily so ``transformers`` is only required when actually
    # running a conversion.
    from transformers import AutoTokenizer

    # Resolve the tokenizer before doing any heavy work. Previously it was
    # loaded only after the model had been written, so a bad
    # ``tokenizer_name_or_path`` crashed late and left ``save_path`` holding
    # model files but no tokenizer.
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)

    print("Initializing model...")
    config_data = get_model_config(model_variant)
    mamba_config = MambaConfig(**config_data)
    model = MambaLMHeadModel(mamba_config)

    print(f"Reading state dict from {load_path}")
    # The distributed-checkpoint loader reads into the tensors it is given, so
    # pre-populate the dict with the model's own state to supply the expected
    # keys and shapes. NOTE(review): assumes the checkpoint stores weights
    # under the "model_state" key, presumably as fms-fsdp saves them; confirm.
    state_dict = {"model_state": model.state_dict()}
    # no_dist=True: load in a single process, without a process group.
    load_state_dict(
        state_dict=state_dict, storage_reader=FileSystemReader(load_path), no_dist=True
    )

    print("Loading state dict into the model...")
    model.load_state_dict(state_dict["model_state"])

    print("Saving model to HF-compatible format...")
    model.save_pretrained(save_path)

    print("Copying tokenizer...")
    tokenizer.save_pretrained(save_path)

    # Printed after both saves have completed, hence past tense.
    print(f"Model saved at {save_path}")
if __name__ == "__main__":
    # Script entry point: python-fire turns main's parameters
    # (model_variant, load_path, save_path, tokenizer_name_or_path)
    # into command-line arguments.
    fire.Fire(component=main)