@@ -1597,6 +1597,184 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")


+@Model.register("DeciLMForCausalLM")
+class DeciModel(Model):
+    model_arch = gguf.MODEL_ARCH.DECI
+
+    @staticmethod
+    def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+        # DeciLM-specific code
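+        # e.g. ffn_mult=1.3 and n_embd=4096 give int(2 * 1.3 * 4096 / 3) = 3549,
+        # which is rounded up to the next multiple of 256 -> 3584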
+        intermediate_size = int(2 * ffn_mult * n_embd / 3)
+        return DeciModel._find_multiple(intermediate_size, 256)
+
+    @staticmethod
+    def _find_multiple(n: int, k: int) -> int:
+        # DeciLM-specific code
+        if n % k == 0:
+            return n
+        return n + k - (n % k)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "block_configs" in self.hparams:  # Llama-3_1-Nemotron-51B
+            _block_configs: list[dict[str, Any]] = self.hparams["block_configs"]
+            assert self.block_count == len(_block_configs)
+            self._num_kv_heads = list()
+            self._num_heads = list()
+            _ffn_multipliers = list()
+            # ***linear attention layer***
+            # if n_heads_in_group is None and replace_with_linear is True,
+            # then _num_kv_heads[il] is 0 and _num_heads[il] is num_attention_heads
+            # ***attention-free layer***
+            # if n_heads_in_group is None and replace_with_linear is False,
+            # then _num_kv_heads[il] is 0 and _num_heads[il] is 0
+            # ***normal attention layer***
+            # if n_heads_in_group is not None, then
+            # _num_kv_heads[il] is num_attention_heads // n_heads_in_group and
+            # _num_heads[il] is num_attention_heads
+            # ***dummy layer*** for Nemotron 253B
+            # if n_heads_in_group is None and ffn_mult is None,
+            # then _num_kv_heads[il] is 0, _num_heads[il] is 0 and _ffn_dims is 0
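+            # illustrative shape of one entry (values hypothetical, keys as used below):
+            #   {"attention": {"n_heads_in_group": 8, "replace_with_linear": False},
+            #    "ffn": {"ffn_mult": 2.625}}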
+            for il in range(len(_block_configs)):
+                if _block_configs[il]["attention"]["n_heads_in_group"] is None:
+                    if _block_configs[il]["attention"]["replace_with_linear"] is True:
+                        self._num_kv_heads.append(0)
+                        self._num_heads.append(self.hparams["num_attention_heads"])
+                    else:
+                        self._num_kv_heads.append(0)
+                        self._num_heads.append(0)
+                else:
+                    self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"])
+                    self._num_heads.append(self.hparams["num_attention_heads"])
+                if _block_configs[il]["ffn"]["ffn_mult"] is None:  # dummy layer
+                    _ffn_multipliers.append(0.0)
+                else:
+                    _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"])
+            assert self.block_count == len(self._num_kv_heads)
+            assert self.block_count == len(self._num_heads)
+            assert self.block_count == len(_ffn_multipliers)
+            assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int)
+            assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int)
+            assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float)
+            self._ffn_dims: list[int] = [
+                DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"])
+                for multiplier in _ffn_multipliers
+            ]
+
+    def set_vocab(self):
+        # Before converting Llama-3_1-Nemotron-51B, change its eos_token in
+        # tokenizer_config.json from '|eot_id|' to '|end_of_text|'
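+        # a vocab_size of 128256 indicates the Llama-3 tokenizer (Nemotron-51B);
+        # anything else falls through to the Llama HF vocab path used by DeciLM-7B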
+        if self.hparams.get("vocab_size", 128256) == 128256:
+            tokens, toktypes, tokpre = self.get_vocab_base()
+            self.gguf_writer.add_tokenizer_model("gpt2")
+            self.gguf_writer.add_tokenizer_pre(tokpre)
+            self.gguf_writer.add_token_list(tokens)
+            self.gguf_writer.add_token_types(toktypes)
+
+            special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+            special_vocab.add_to_gguf(self.gguf_writer)
+        else:
+            # DeciLM-7B
+            self._set_vocab_llama_hf()
+
+    def set_gguf_parameters(self):
+        if "block_configs" in self.hparams:  # Llama-3_1-Nemotron-51B
+            assert self.block_count == len(self._num_kv_heads)
+            assert self.block_count == len(self._num_heads)
+            assert self.block_count == len(self._ffn_dims)
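+            # head counts and FFN sizes are written as per-layer arrays so each
+            # block can have a different shape (linear, attention-free, dummy, ...)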
+            if (rope_theta := self.hparams.get("rope_theta")) is not None:
+                self.gguf_writer.add_rope_freq_base(rope_theta)
+            self.gguf_writer.add_head_count_kv(self._num_kv_heads)
+            self.gguf_writer.add_head_count(self._num_heads)
+            self.gguf_writer.add_feed_forward_length(self._ffn_dims)
+            self.gguf_writer.add_block_count(self.block_count)
+            self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+            self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+            self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
+            self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+            self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+            self.gguf_writer.add_file_type(self.ftype)
+        else:  # DeciLM-7B
+            super().set_gguf_parameters()
+            if "num_key_value_heads_per_layer" in self.hparams:  # DeciLM-7B
+                self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"]
+                assert self.block_count == len(self._num_kv_heads)
+                self.gguf_writer.add_head_count_kv(self._num_kv_heads)
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        if "head_dim" in hparams:
+            rope_dim = hparams["head_dim"]
+        else:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "linear":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+
+    @staticmethod
+    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
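+        # same Q/K permutation used by the Llama-style converters in this script:
+        # reorders each head's rows so RoPE is applied consistently after conversion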
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams["num_attention_heads"]
+        if bid is not None:
+            if "num_key_value_heads_per_layer" in self.hparams:
+                n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid]
+            elif "block_configs" in self.hparams:
+                n_kv_head = self._num_kv_heads[bid]
+                n_head = self._num_heads[bid]
+            else:
+                n_kv_head = self.hparams.get("num_key_value_heads")
+        else:
+            n_kv_head = self.hparams.get("num_key_value_heads")
+
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
+            data_torch = DeciModel.permute(data_torch, n_head, n_head)
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
+            data_torch = DeciModel.permute(data_torch, n_head, n_kv_head)
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+        if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+            if rope_scaling.get("rope_type", '').lower() == "llama3":
+                base = self.hparams.get("rope_theta", 10000.0)
+                dim = self.hparams.get("head_dim", self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+                freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+                factor = rope_scaling.get("factor", 8.0)
+                low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+                high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+                old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+                low_freq_wavelen = old_context_len / low_freq_factor
+                high_freq_wavelen = old_context_len / high_freq_factor
+                assert low_freq_wavelen != high_freq_wavelen
+
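+                # llama3 rope scaling: low-frequency bands get `factor`, high-frequency
+                # bands get 1 (unscaled), and the band in between is smoothly interpolated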
+                rope_factors = []
+                for freq in freqs:
+                    wavelen = 2 * math.pi / freq
+                    if wavelen < high_freq_wavelen:
+                        rope_factors.append(1)
+                    elif wavelen > low_freq_wavelen:
+                        rope_factors.append(factor)
+                    else:
+                        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+                        rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+                yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+
 @Model.register("BitnetForCausalLM")
 @Model.register("BitNetForCausalLM")
 class BitnetModel(Model):