 from pathlib import Path
 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
+from itertools import chain
 
 import math
 import numpy as np
@@ -64,15 +65,14 @@ class Model:
     model_name: str | None
     metadata_override: Path | None
     dir_model_card: Path
-    is_lora: bool
 
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
-                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, is_lora: bool = False):
+                 split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False):
         if type(self) is Model:
             raise TypeError(f"{type(self).__name__!r} should not be directly instantiated")
 
@@ -94,7 +94,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
-        self.is_lora = is_lora  # true if model is used inside convert_lora_to_gguf.py
 
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -270,10 +269,14 @@ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims:
 
         return False
 
+    # some models need extra generated tensors (like rope_freqs)
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        return ()
+
     def prepare_tensors(self):
         max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
 
-        for name, data_torch in self.get_tensors():
+        for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()):
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
                 continue
@@ -1617,7 +1620,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(self.map_tensor_name(name), data_torch)]
 
-    def prepare_tensors(self):
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
             if rope_scaling.get("rope_type", '').lower() == "llama3":
                 base = self.hparams.get("rope_theta", 10000.0)
@@ -1644,9 +1647,9 @@ def prepare_tensors(self):
                         smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                         rope_factors.append(1 / ((1 - smooth) / factor + smooth))
 
-                if not self.is_lora:
-                    self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
+                yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
 
+    def prepare_tensors(self):
         super().prepare_tensors()
 
         if self._experts is not None:
@@ -1870,8 +1873,6 @@ class MiniCPM3Model(Model):
     def set_gguf_parameters(self):
         hparams = self.hparams
 
-        rope_dims = hparams["qk_rope_head_dim"]
-
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
@@ -1887,24 +1888,25 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         rope_scaling = self.find_hparam(['rope_scaling'], True)
-        if rope_scaling is None:
-            return
+        if rope_scaling is not None:
+            rope_dims = self.hparams["qk_rope_head_dim"]
 
-        long_factors = rope_scaling.get('long_factor', None)
-        short_factors = rope_scaling.get('short_factor', None)
+            long_factors = rope_scaling.get('long_factor', None)
+            short_factors = rope_scaling.get('short_factor', None)
 
-        if long_factors is None or short_factors is None:
-            raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
+            if long_factors is None or short_factors is None:
+                raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
 
-        if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
-            raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
+            if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
+                raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
 
-        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
-        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
 
     def set_vocab(self):
-        self._set_vocab_llama_hf()
+        self._set_vocab_sentencepiece()
 
     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
         if n_kv_head is not None and n_head != n_kv_head:
@@ -2216,6 +2218,13 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_sliding_window(self.find_hparam(["sliding_window"]))
 
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        n_embd = self.find_hparam(["hidden_size", "n_embd"])
+        n_head = self.find_hparam(["num_attention_heads", "n_head"])
+        max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"])
+        orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"])
+        rope_dims = n_embd // n_head
+
         # write rope scaling for long context (128k) model
         rope_scaling = self.find_hparam(['rope_scaling'], True)
         if rope_scaling is None:
@@ -2245,9 +2254,8 @@ def set_gguf_parameters(self):
         if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
             raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
 
-        if not self.is_lora:
-            self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
-            self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
+        yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
+        yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
 
 
 @Model.register("PlamoForCausalLM")
@@ -4071,7 +4079,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])
 
-    def prepare_tensors(self):
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
             if rope_scaling.get("rope_type", '').lower() == "llama3":
                 base = self.hparams.get("rope_theta", 10000.0)
@@ -4098,10 +4106,7 @@ def prepare_tensors(self):
                         smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
                         rope_factors.append(1 / ((1 - smooth) / factor + smooth))
 
-                if not self.is_lora:
-                    self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))
-
-        super().prepare_tensors()
+                yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
 
 
 @Model.register("GraniteForCausalLM")
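
Every hunk above applies the same refactor: instead of writing generated tensors (rope_freqs, long/short rope factors) straight to gguf_writer behind an is_lora guard, each model now yields them from generate_extra_tensors(), and the base Model.prepare_tensors() chains them in front of get_tensors() so they take the same per-tensor path (naming, dtype handling, writing) as checkpoint tensors. A minimal, self-contained sketch of that hook pattern, with illustrative class and tensor names rather than the converter's real ones:

# Minimal sketch of the generate_extra_tensors() hook (illustrative names only).
from itertools import chain
from typing import Iterable

import torch
from torch import Tensor


class Base:
    # default: most models have nothing extra to generate
    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        return ()

    def get_tensors(self) -> Iterable[tuple[str, Tensor]]:
        # stands in for tensors loaded from the checkpoint
        yield ("blk.0.attn_q.weight", torch.zeros(4, 4))

    def prepare_tensors(self) -> None:
        # generated tensors take the same path as loaded ones
        for name, data in chain(self.generate_extra_tensors(), self.get_tensors()):
            print(name, tuple(data.shape))


class WithRopeFreqs(Base):
    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
        # computed from hyperparameters rather than read from disk
        yield ("rope_freqs.weight", torch.ones(8, dtype=torch.float32))


WithRopeFreqs().prepare_tensors()

Because the extra tensors are ordinary (name, tensor) pairs, a LoRA conversion path can filter or skip them generically, which is what allows this diff to drop the per-model is_lora flag.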