@@ -39,20 +39,14 @@ def convert_hf_checkpoint(
     config = TransformerArgs.from_params(config_args)
     print(f"Model config {config.__dict__}")

-    # Load the json file containing weight mapping
+    # Find all candidate weight mapping index files
     model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    if "mistral" not in model_name:
-        assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"

     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
     # Llama 3 has a consolidated model and tokenizer.
     # Otherwise raise an error.
-    if not model_map_json.is_file():
+    if not model_map_json_matches:
         consolidated_pth = model_dir / "original" / "consolidated.00.pth"
         tokenizer_pth = model_dir / "original" / "tokenizer.model"
         if consolidated_pth.is_file() and tokenizer_pth.is_file():
@@ -69,11 +63,29 @@ def convert_hf_checkpoint(
             return
         else:
             raise RuntimeError(
-                f"Could not find {model_map_json} or {consolidated_pth} plus {tokenizer_pth}"
+                f"Could not find a valid model weight map or {consolidated_pth} plus {tokenizer_pth}"
             )

-    with open(model_map_json) as json_map:
-        bin_index = json.load(json_map)
+    # Load the json file(s) containing weight mapping
+    #
+    # NOTE: If there are multiple index files, there are two possibilities:
+    # 1. The files could be mapped to different weight format files (e.g. .bin
+    #    vs .safetensors)
+    # 2. The files could be split subsets of the mappings that need to be
+    #    merged
+    #
+    # In either case, we can simply keep the mappings where the target file is
+    # valid in the model dir.
+    bin_files = {}
+    for weight_map_file in model_map_json_matches:
+        with open(weight_map_file, "r") as handle:
+            weight_map = json.load(handle)
+        valid_mappings = {
+            k: model_dir / v
+            for (k, v) in weight_map.get("weight_map", {}).items()
+            if (model_dir / v).is_file()
+        }
+        bin_files.update(valid_mappings)

     weight_map = {
         "model.embed_tokens.weight": "tok_embeddings.weight",
@@ -97,7 +109,6 @@ def convert_hf_checkpoint(
97109 "model.norm.weight" : "norm.weight" ,
98110 "lm_head.weight" : "output.weight" ,
99111 }
100- bin_files = {model_dir / bin for bin in bin_index ["weight_map" ].values ()}
101112
102113 def permute (w , n_heads ):
103114 return (
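
For reference, here is a minimal, self-contained sketch of the merge-and-filter behavior the new loop implements. The directory layout, shard names, and index contents below are made up for illustration and are not from this PR; the point is that when both a .bin and a .safetensors index are present, only the mappings whose target shard actually exists on disk survive the merge.

import json
from pathlib import Path
from tempfile import TemporaryDirectory

with TemporaryDirectory() as tmp:
    model_dir = Path(tmp)

    # Hypothetical checkpoint: both index files exist, but only the
    # safetensors shards were actually downloaded.
    (model_dir / "model-00001-of-00002.safetensors").touch()
    (model_dir / "model-00002-of-00002.safetensors").touch()
    (model_dir / "model.safetensors.index.json").write_text(json.dumps({
        "weight_map": {
            "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
            "lm_head.weight": "model-00002-of-00002.safetensors",
        }
    }))
    (model_dir / "pytorch_model.bin.index.json").write_text(json.dumps({
        "weight_map": {
            "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
            "lm_head.weight": "pytorch_model-00002-of-00002.bin",
        }
    }))

    # Same idea as the new loop above: read every *.index.json and keep
    # only entries whose target shard is a real file in the model dir.
    bin_files = {}
    for weight_map_file in sorted(model_dir.glob("*.index.json")):
        with open(weight_map_file, "r") as handle:
            weight_map = json.load(handle)
        bin_files.update({
            k: model_dir / v
            for (k, v) in weight_map.get("weight_map", {}).items()
            if (model_dir / v).is_file()
        })

    # Only the safetensors mappings remain; the .bin shards are missing.
    for name, path in sorted(bin_files.items()):
        print(name, "->", path.name)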