 from math import prod
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+from transformers import AutoConfig
 
 import torch
 
@@ -256,8 +257,8 @@ def parse_args() -> argparse.Namespace:
         help="only print out what will be done, without writing any new files",
     )
     parser.add_argument(
-        "--base", type=Path, required=True,
-        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required",
+        "--base", type=Path,
+        help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
     )
     parser.add_argument(
         "lora_path", type=Path,
@@ -267,6 +268,12 @@ def parse_args() -> argparse.Namespace:
     return parser.parse_args()
 
 
+def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
+    # normally, an adapter does not come with the base model config, so we load it via AutoConfig
+    config = AutoConfig.from_pretrained(hf_model_id)
+    return config.to_dict()
+
+
 if __name__ == '__main__':
     args = parse_args()
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
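Note on the new helper: load_hparams_from_hf fetches only the base model's config.json through transformers.AutoConfig; no weights are downloaded. A minimal sketch of what the returned dict provides, assuming network access (the model id below is just an example, not part of this change):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B")  # example Hub id
    hparams = config.to_dict()
    # the converter later reads hparams["architectures"][0] to pick the Model subclass
    print(hparams["architectures"])                    # e.g. ['Qwen2ForCausalLM']
    print(hparams["hidden_size"], hparams["num_hidden_layers"])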
@@ -281,7 +288,7 @@ def parse_args() -> argparse.Namespace:
 
     ftype = ftype_map[args.outtype]
 
-    dir_base_model: Path = args.base
+    dir_base_model: Path | None = args.base
     dir_lora: Path = args.lora_path
     lora_config = dir_lora / "adapter_config.json"
     input_model = dir_lora / "adapter_model.safetensors"
@@ -301,9 +308,29 @@ def parse_args() -> argparse.Namespace:
         input_model = os.path.join(dir_lora, "adapter_model.bin")
         lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
 
+    # load LoRA config
+    with open(lora_config, "r") as f:
+        lparams: dict[str, Any] = json.load(f)
+
     # load base model
-    logger.info(f"Loading base model: {dir_base_model.name}")
-    hparams = Model.load_hparams(dir_base_model)
+    if dir_base_model is None:
+        if "base_model_name_or_path" in lparams:
+            model_id = lparams["base_model_name_or_path"]
+            logger.info(f"Loading base model from Hugging Face: {model_id}")
+            try:
+                hparams = load_hparams_from_hf(model_id)
+            except OSError as e:
+                logger.error(f"Failed to load base model config: {e}")
+                logger.error("Please try downloading the base model and add its path to --base")
+                sys.exit(1)
+        else:
+            logger.error("'base_model_name_or_path' is not found in adapter_config.json")
+            logger.error("Base model config is required. Please download the base model and add its path to --base")
+            sys.exit(1)
+    else:
+        logger.info(f"Loading base model: {dir_base_model.name}")
+        hparams = Model.load_hparams(dir_base_model)
+
     with torch.inference_mode():
         try:
             model_class = Model.from_model_architecture(hparams["architectures"][0])
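For reference, the two adapter_config.json keys this script relies on are base_model_name_or_path (for the Hub fallback above) and lora_alpha (written into the GGUF further down). A hypothetical PEFT adapter config showing them in context (values are illustrative only):

    {
      "base_model_name_or_path": "Qwen/Qwen2.5-0.5B",
      "peft_type": "LORA",
      "r": 8,
      "lora_alpha": 16,
      "target_modules": ["q_proj", "v_proj"]
    }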
@@ -323,13 +350,15 @@ def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
                 self.dir_model_card = dir_lora_model
                 self.lora_alpha = float(lora_alpha)
 
+            def set_vocab(self):
+                pass
+
             def set_type(self):
                 self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
                 self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
 
             def set_gguf_parameters(self):
                 self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
-                super().set_gguf_parameters()
 
             def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                 # Never add extra tensors (e.g. rope_freqs) for LoRA adapters
@@ -350,7 +379,7 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                         logger.error(f"Unexpected name '{name}'")
                         if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
                             logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning")
-                            logger.error("Hint: if you are using TRL, make sure not to call setup_chat_format()")
+                            logger.error("Please refer to https://github.com/ggerganov/llama.cpp/pull/9948")
                         sys.exit(1)
 
                     if base_name in tensor_map:
@@ -384,9 +413,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                     yield (dest_name + ".lora_a", lora_a)
                     yield (dest_name + ".lora_b", lora_b)
 
-        with open(lora_config, "r") as f:
-            lparams: dict[str, Any] = json.load(f)
-
         alpha: float = lparams["lora_alpha"]
 
         model_instance = LoraModel(
@@ -399,6 +425,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             dry_run=args.dry_run,
             dir_lora_model=dir_lora,
             lora_alpha=alpha,
+            hparams=hparams,
         )
 
         logger.info("Exporting model...")
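With --base now optional, the adapter directory alone is enough when its adapter_config.json names the base model. A usage sketch (the script name and paths are assumptions, not part of this diff):

    # with a local base model config, as before
    python convert_lora_to_gguf.py --base ./base-model ./my-lora-adapter

    # without --base: the config is fetched from the Hub via
    # base_model_name_or_path in adapter_config.json
    python convert_lora_to_gguf.py ./my-lora-adapter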