@@ -226,6 +226,9 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
     base_name = lora_tensor_name.replace("base_model.model.", "")
     base_name = base_name.replace(".lora_A.weight", ".weight")
     base_name = base_name.replace(".lora_B.weight", ".weight")
+    # models produced by mergekit-extract-lora have token embeddings in the adapter
+    base_name = base_name.replace(".lora_embedding_A", ".weight")
+    base_name = base_name.replace(".lora_embedding_B", ".weight")
     return base_name
 
 
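For illustration only (not part of the patch): with a hypothetical PEFT-style tensor key, the new replacements map an embedding adapter tensor back to the base weight name the rest of the converter expects.

# Sketch only; the tensor name below is a hypothetical mergekit-extract-lora key.
name = "base_model.model.model.embed_tokens.lora_embedding_A"
base = name.replace("base_model.model.", "")
base = base.replace(".lora_embedding_A", ".weight")
assert base == "model.embed_tokens.weight"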
@@ -260,6 +263,10 @@ def parse_args() -> argparse.Namespace:
         "--base", type=Path,
         help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
     )
+    parser.add_argument(
+        "--base-model-id", type=str,
+        help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')",
+    )
     parser.add_argument(
         "lora_path", type=Path,
         help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
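A quick standalone sketch (hypothetical, not the script itself) of how the new option surfaces in the parsed namespace: argparse maps --base-model-id to args.base_model_id, which the main block reads in the next hunk.

import argparse
from pathlib import Path

# Minimal stand-in parser mirroring only the two options shown above.
p = argparse.ArgumentParser()
p.add_argument("--base", type=Path)
p.add_argument("--base-model-id", type=str)

args = p.parse_args(["--base-model-id", "meta-llama/Llama-3.2-1B-Instruct"])
assert args.base is None
assert args.base_model_id == "meta-llama/Llama-3.2-1B-Instruct"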
@@ -290,6 +297,7 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
 
     dir_base_model: Path | None = args.base
     dir_lora: Path = args.lora_path
+    base_model_id: str | None = args.base_model_id
     lora_config = dir_lora / "adapter_config.json"
     input_model = dir_lora / "adapter_model.safetensors"
 
@@ -313,7 +321,10 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
         lparams: dict[str, Any] = json.load(f)
 
     # load base model
-    if dir_base_model is None:
+    if base_model_id is not None:
+        logger.info(f"Loading base model from Hugging Face: {base_model_id}")
+        hparams = load_hparams_from_hf(base_model_id)
+    elif dir_base_model is None:
         if "base_model_name_or_path" in lparams:
             model_id = lparams["base_model_name_or_path"]
             logger.info(f"Loading base model from Hugging Face: {model_id}")
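Taken together with the branches that follow (not all shown in this hunk), the intended precedence appears to be: --base-model-id first, then a local --base directory, then the adapter's own base_model_name_or_path. A hedged restatement, where base_config_source is a placeholder name and nothing is actually loaded:

from pathlib import Path

# Placeholder sketch of the config-resolution order implied above; it only
# reports where the base config would come from.
def base_config_source(base_model_id, dir_base_model, lparams):
    if base_model_id is not None:
        return f"hub:{base_model_id}"                       # 1. explicit --base-model-id
    if dir_base_model is not None:
        return f"local:{dir_base_model / 'config.json'}"    # 2. --base directory
    return f"hub:{lparams.get('base_model_name_or_path')}"  # 3. adapter config fallback

print(base_config_source("meta-llama/Llama-3.2-1B-Instruct", None, {}))  # hub:meta-llama/Llama-3.2-1B-Instruct
print(base_config_source(None, Path("base-model"), {}))                  # local:base-model/config.json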
@@ -371,17 +382,26 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                     if self.lazy:
                         tensor = LazyTorchTensor.from_eager(tensor)
                     base_name = get_base_tensor_name(name)
-                    is_lora_a = ".lora_A.weight" in name
-                    is_lora_b = ".lora_B.weight" in name
+                    # note: lora_embedding is transposed by mergekit-extract-lora, so A and B are swapped here
+                    is_lora_a = ".lora_A.weight" in name or ".lora_embedding_B" in name
+                    is_lora_b = ".lora_B.weight" in name or ".lora_embedding_A" in name
                     if not is_lora_a and not is_lora_b:
                         if ".base_layer.weight" in name:
                             continue
+                        # mergekit-extract-lora adds these layernorm tensors to the adapter, we need to keep them
+                        if ".layernorm" in name or ".norm" in name:
+                            yield (base_name, tensor)
+                            continue
                         logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
                         if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
                             logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
                             logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
                         sys.exit(1)
 
+                    # mergekit-extract-lora transposes this tensor, we need to transpose it back
+                    if ".lora_embedding" in name:
+                        tensor = tensor.T
+
                 if base_name in tensor_map:
                     if is_lora_a:
                         tensor_map[base_name].A = tensor
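A toy shape check (illustrative sizes; assumes the PEFT-style layout where lora_embedding_A is stored as (rank, vocab) and lora_embedding_B as (dim, rank)) showing why the A/B swap plus the transpose yields the usual (rank, in) / (out, rank) pair for the embedding weight:

import torch

rank, vocab, dim = 2, 10, 6                  # toy sizes only
lora_embedding_A = torch.zeros(rank, vocab)  # assumed stored layout
lora_embedding_B = torch.zeros(dim, rank)

# Mirror the swap + transpose above: embedding_B becomes lora_A, embedding_A becomes lora_B.
lora_a = lora_embedding_B.T                  # (rank, dim)
lora_b = lora_embedding_A.T                  # (vocab, rank)
delta = lora_b @ lora_a                      # (vocab, dim), same shape as embed_tokens.weight
assert delta.shape == (vocab, dim)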
@@ -407,6 +427,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                 if name == "lm_head.weight" and len(dest) == 0:
                     raise ValueError("lm_head is present in adapter, but is ignored in base model")
                 for dest_name, dest_data in dest:
+                    # mergekit-extract-lora adds these layernorm tensors to the adapter
+                    if "_norm" in dest_name:
+                        assert dest_data.dim() == 1
+                        yield (dest_name, dest_data)
+                        continue
+
+                    # otherwise, we must get the lora_A and lora_B tensors
                     assert isinstance(dest_data, LoraTorchTensor)
                     lora_a, lora_b = dest_data.get_lora_A_B()
 
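Condensed restatement (the function and suffix names here are illustrative, not the script's API): 1-D norm vectors coming from mergekit-extract-lora are passed through unchanged, while everything else is expected to be a LoRA pair that gets split back into its A and B halves.

# Illustrative dispatch only; "dest_data" stands in for either a plain 1-D
# tensor (a layernorm from mergekit-extract-lora) or a LoRA A/B pair object.
def split_for_output(dest_name, dest_data):
    if "_norm" in dest_name:
        assert dest_data.dim() == 1      # norm vectors are stored as-is
        return [(dest_name, dest_data)]
    lora_a, lora_b = dest_data.get_lora_A_B()
    return [(dest_name + ".lora_a", lora_a), (dest_name + ".lora_b", lora_b)]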