1212from math import prod
1313from pathlib import Path
1414from typing import TYPE_CHECKING , Any , Callable , Iterable , Iterator , Sequence , SupportsIndex , cast
15+ from transformers import AutoConfig
1516
1617import torch
1718
@@ -256,7 +257,7 @@ def parse_args() -> argparse.Namespace:
256257 help = "only print out what will be done, without writing any new files" ,
257258 )
258259 parser .add_argument (
259- "--base" , type = Path , required = True ,
260+ "--base" , type = Path ,
260261 help = "directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required" ,
261262 )
262263 parser .add_argument (
@@ -267,6 +268,12 @@ def parse_args() -> argparse.Namespace:
267268 return parser .parse_args ()
268269
269270
def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
    """Resolve the base model's hyperparameters via the Hugging Face Hub.

    A LoRA adapter is normally distributed without the base model's
    config.json, so when no local --base directory is given we fetch the
    configuration through AutoConfig and return it as a plain dict.
    """
    return AutoConfig.from_pretrained(hf_model_id).to_dict()
275+
276+
270277if __name__ == '__main__' :
271278 args = parse_args ()
272279 logging .basicConfig (level = logging .DEBUG if args .verbose else logging .INFO )
@@ -281,7 +288,7 @@ def parse_args() -> argparse.Namespace:
281288
282289 ftype = ftype_map [args .outtype ]
283290
284- dir_base_model : Path = args .base
291+ dir_base_model : Path | None = args .base
285292 dir_lora : Path = args .lora_path
286293 lora_config = dir_lora / "adapter_config.json"
287294 input_model = dir_lora / "adapter_model.safetensors"
@@ -301,9 +308,25 @@ def parse_args() -> argparse.Namespace:
301308 input_model = os .path .join (dir_lora , "adapter_model.bin" )
302309 lora_model = torch .load (input_model , map_location = "cpu" , weights_only = True )
303310
311+ # load LoRA config
312+ with open (lora_config , "r" ) as f :
313+ lparams : dict [str , Any ] = json .load (f )
314+
304315 # load base model
305- logger .info (f"Loading base model: { dir_base_model .name } " )
306- hparams = Model .load_hparams (dir_base_model )
316+ if dir_base_model is None :
317+ if "base_model_name_or_path" in lparams :
318+ model_id = lparams ["base_model_name_or_path" ]
319+ logger .info (f"Loading base model from Hugging Face: { model_id } " )
320+ hparams = load_hparams_from_hf (model_id )
321+ else :
322+ logger .error ("'base_model_name_or_path' is not found in adapter_config.json" )
323+ logger .error ("Base model config is required. Please download the base model and add its path to --base" )
324+ sys .exit (1 )
325+ else :
326+ logger .info (f"Loading base model: { dir_base_model .name } " )
327+ hparams = Model .load_hparams (dir_base_model )
328+
329+
307330 with torch .inference_mode ():
308331 try :
309332 model_class = Model .from_model_architecture (hparams ["architectures" ][0 ])
@@ -323,6 +346,9 @@ def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
323346 self .dir_model_card = dir_lora_model
324347 self .lora_alpha = float (lora_alpha )
325348
349+ def set_vocab (self ):
350+ pass
351+
326352 def set_type (self ):
327353 self .gguf_writer .add_type (gguf .GGUFType .ADAPTER )
328354 self .gguf_writer .add_string (gguf .Keys .Adapter .TYPE , "lora" )
@@ -384,9 +410,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
384410 yield (dest_name + ".lora_a" , lora_a )
385411 yield (dest_name + ".lora_b" , lora_b )
386412
387- with open (lora_config , "r" ) as f :
388- lparams : dict [str , Any ] = json .load (f )
389-
390413 alpha : float = lparams ["lora_alpha" ]
391414
392415 model_instance = LoraModel (
@@ -399,6 +422,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
399422 dry_run = args .dry_run ,
400423 dir_lora_model = dir_lora ,
401424 lora_alpha = alpha ,
425+ hparams = hparams ,
402426 )
403427
404428 logger .info ("Exporting model..." )
0 commit comments