@@ -12,6 +12,7 @@
 from math import prod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+from transformers import AutoConfig
 
 import torch
 
@@ -256,8 +257,8 @@ def parse_args() -> argparse.Namespace:
         help="only print out what will be done, without writing any new files",
     )
     parser.add_argument(
-        "--base", type=Path, required=True,
-        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required",
+        "--base", type=Path,
+        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
     )
     parser.add_argument(
         "lora_path", type=Path,
@@ -267,6 +268,12 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
+def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
+    # an adapter normally does not ship with the base model config, so load it via AutoConfig
+    config = AutoConfig.from_pretrained(hf_model_id)
+    return config.to_dict()
+
+
 if __name__ == '__main__':
     args = parse_args()
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
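
For reference, `load_hparams_from_hf()` fetches only the model's `config.json` from the Hugging Face hub; no weights are downloaded. A minimal sketch of the equivalent call (using `gpt2` as an arbitrary public model id):

```python
from transformers import AutoConfig

# Downloads (and caches) only config.json for the given model id.
hparams = AutoConfig.from_pretrained("gpt2").to_dict()
print(hparams["architectures"])  # ['GPT2LMHeadModel']
```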
@@ -281,7 +288,7 @@ def parse_args() -> argparse.Namespace:
 
     ftype = ftype_map[args.outtype]
 
-    dir_base_model: Path = args.base
+    dir_base_model: Path | None = args.base
     dir_lora: Path = args.lora_path
     lora_config = dir_lora / "adapter_config.json"
     input_model = dir_lora / "adapter_model.safetensors"
@@ -301,9 +308,29 @@ def parse_args() -> argparse.Namespace:
         input_model = os.path.join(dir_lora, "adapter_model.bin")
         lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
 
+    # load LoRA config
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
     # load base model
-    logger.info(f"Loading base model: {dir_base_model.name}")
-    hparams = Model.load_hparams(dir_base_model)
+    if dir_base_model is None:
+        if "base_model_name_or_path" in lparams:
+            model_id = lparams["base_model_name_or_path"]
+            logger.info(f"Loading base model from Hugging Face: {model_id}")
+            try:
+                hparams = load_hparams_from_hf(model_id)
+            except OSError as e:
+                logger.error(f"Failed to load base model config: {e}")
+                logger.error("Please try downloading the base model and add its path to --base")
+                sys.exit(1)
+        else:
+            logger.error("'base_model_name_or_path' is not found in adapter_config.json")
+            logger.error("Base model config is required. Please download the base model and add its path to --base")
+            sys.exit(1)
+    else:
+        logger.info(f"Loading base model: {dir_base_model.name}")
+        hparams = Model.load_hparams(dir_base_model)
+
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
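
The fallback path relies on the `base_model_name_or_path` field that PEFT records in `adapter_config.json`. A hypothetical excerpt of such a config, parsed the same way the script does (the model id and LoRA settings are illustrative):

```python
import json

# Hypothetical adapter_config.json contents; only relevant fields shown.
lparams = json.loads("""
{
    "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
    "lora_alpha": 32,
    "r": 16,
    "target_modules": ["q_proj", "v_proj"]
}
""")

model_id = lparams["base_model_name_or_path"]  # used when --base is not given
alpha = lparams["lora_alpha"]                  # written into the GGUF adapter later
```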
@@ -323,13 +350,15 @@ def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
                 self.dir_model_card = dir_lora_model
                 self.lora_alpha = float(lora_alpha)
 
+            def set_vocab(self):
+                pass
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
 
             def set_gguf_parameters(self):
                 self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
-                super().set_gguf_parameters()
 
             def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                 # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
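
Since a hub-loaded config comes without tokenizer files on disk, `set_vocab()` becomes a no-op, and dropping the `super().set_gguf_parameters()` call keeps base-model hyperparameters out of the adapter file. A minimal sketch of the adapter-specific metadata these overrides still write, assuming the `gguf` Python package (the output path and architecture are illustrative):

```python
import gguf

# Only adapter-specific keys are emitted; base-model hparams are not duplicated.
writer = gguf.GGUFWriter("adapter.gguf", arch="llama")
writer.add_type(gguf.GGUFType.ADAPTER)                  # from set_type()
writer.add_string(gguf.Keys.Adapter.TYPE, "lora")       # from set_type()
writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, 32.0)  # from set_gguf_parameters()
```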
@@ -350,7 +379,7 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                         logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
                         if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
                             logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
-                            logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
+                            logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
                         sys.exit(1)
 
                     if base_name in tensor_map:
@@ -384,9 +413,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)
 
-    with open(lora_config, "r") as f:
-        lparams: dict[str, Any] = json.load(f)
-
     alpha: float = lparams["lora_alpha"]
 
     model_instance = LoraModel(
@@ -399,6 +425,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         dry_run=args.dry_run,
         dir_lora_model=dir_lora,
         lora_alpha=alpha,
+        hparams=hparams,
     )
 
     logger.info("Exporting model...")