@@ -73,7 +73,7 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
7373 use_temp_file : bool = False , eager : bool = False ,
7474 metadata_override : Path | None = None , model_name : str | None = None ,
7575 split_max_tensors : int = 0 , split_max_size : int = 0 , dry_run : bool = False ,
76- small_first_shard : bool = False , hparams : dict [str , Any ] | None = None ):
76+ small_first_shard : bool = False , hparams : dict [str , Any ] | None = None , remote_hf_model_id : str | None = None ):
7777 if type (self ) is Model :
7878 raise TypeError (f"{ type (self ).__name__ !r} should not be directly instantiated" )
7979
@@ -83,11 +83,23 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
8383 self .is_big_endian = is_big_endian
8484 self .endianess = gguf .GGUFEndian .BIG if is_big_endian else gguf .GGUFEndian .LITTLE
8585 self .use_temp_file = use_temp_file
86- self .lazy = not eager
87- self .part_names = Model .get_model_part_names (self .dir_model , "model" , ".safetensors" )
88- self .is_safetensors = len (self .part_names ) > 0
89- if not self .is_safetensors :
90- self .part_names = Model .get_model_part_names (self .dir_model , "pytorch_model" , ".bin" )
86+ self .lazy = not eager or (remote_hf_model_id is not None )
87+ if remote_hf_model_id is not None :
88+ self .is_safetensors = True
89+
def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
    """Yield ``(name, lazy_tensor)`` pairs for every tensor of the remote model.

    The tensor listing for ``remote_hf_model_id`` is requested once and
    reused both to record the tensor names and to build the lazy tensors.

    Because this is a generator, nothing runs when it is called:
    ``self.tensor_names`` is only assigned once iteration starts.
    """
    logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id} ")
    remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id)
    self.tensor_names = set(remote_tensors.keys())
    # Iterate the listing fetched above instead of requesting it a second time.
    for name, remote_tensor in remote_tensors.items():
        yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor))
96+
97+ self .get_tensors = get_remote_tensors
98+ else :
99+ self .part_names = Model .get_model_part_names (self .dir_model , "model" , ".safetensors" )
100+ self .is_safetensors = len (self .part_names ) > 0
101+ if not self .is_safetensors :
102+ self .part_names = Model .get_model_part_names (self .dir_model , "pytorch_model" , ".bin" )
91103 self .hparams = Model .load_hparams (self .dir_model ) if hparams is None else hparams
92104 self .block_count = self .find_hparam (["n_layers" , "num_hidden_layers" , "n_layer" , "num_layers" ])
93105 self .tensor_map = gguf .get_tensor_name_map (self .model_arch , self .block_count )
@@ -5393,6 +5405,14 @@ def from_safetensors_slice(cls, st_slice: Any) -> Tensor:
53935405 lazy = cls (meta = cls .meta_with_dtype_and_shape (dtype , shape ), args = (st_slice ,), func = lambda s : s [:])
53945406 return cast (torch .Tensor , lazy )
53955407
@classmethod
def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
    """Build a lazy tensor describing ``remote_tensor``.

    The meta tensor is created from the remote tensor's dtype string
    (looked up in ``cls._dtype_str_map``) and its shape. The stored
    function calls ``remote_tensor.data()`` and views the returned buffer
    with that dtype and shape; when it is invoked is decided by the lazy
    tensor machinery, not here.

    Returns the lazy object, typed as ``torch.Tensor`` for callers.
    """
    torch_dtype = cls._dtype_str_map[remote_tensor.dtype]
    dims = remote_tensor.shape

    def materialize(r):
        # Interpret the raw buffer with the recorded dtype, then restore the shape.
        return torch.frombuffer(r.data(), dtype=torch_dtype).reshape(dims)

    lazy = cls(
        meta=cls.meta_with_dtype_and_shape(torch_dtype, dims),
        args=(remote_tensor,),
        func=materialize,
    )
    return cast(torch.Tensor, lazy)
5415+
53965416 @classmethod
53975417 def __torch_function__ (cls , func , types , args = (), kwargs = None ):
53985418 del types # unused
@@ -5516,8 +5536,9 @@ def main() -> None:
55165536
55175537 if args .remote :
55185538 from huggingface_hub import snapshot_download
5539+ args .remote = str (dir_model )
55195540 local_dir = snapshot_download (
5520- repo_id = str ( dir_model ) ,
5541+ repo_id = args . remote ,
55215542 allow_patterns = ["LICENSE" , "*.json" , "*.md" , "*.txt" , "tokenizer.model" ])
55225543 dir_model = Path (local_dir )
55235544 logger .info (f"Downloaded config and tokenizer to { local_dir } " )
@@ -5569,7 +5590,7 @@ def main() -> None:
55695590 metadata_override = args .metadata , model_name = args .model_name ,
55705591 split_max_tensors = args .split_max_tensors ,
55715592 split_max_size = split_str_to_n_bytes (args .split_max_size ), dry_run = args .dry_run ,
5572- small_first_shard = args .no_tensor_first_split )
5593+ small_first_shard = args .no_tensor_first_split , remote_hf_model_id = args . remote or None )
55735594
55745595 if args .vocab_only :
55755596 logger .info ("Exporting model vocab..." )
0 commit comments