diff --git a/README.md b/README.md index 90c7364dfcba0..5d45185ebf979 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [GLM-4-0414](https://huggingface.co/collections/THUDM/glm-4-0414-67f3cbcb34dd9d252707cb2e) - [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966) - [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct) +- [x] [EXAONE4](https://huggingface.co/models?search=EXAONE4) - [x] [FalconMamba Models](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) - [x] [Jais](https://huggingface.co/inceptionai/jais-13b-chat) - [x] [Bielik-11B-v2.3](https://huggingface.co/collections/speakleash/bielik-11b-v23-66ee813238d9b526a072408a) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3d5e7e5a45600..4638370c9d4d8 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -300,6 +300,7 @@ def prepare_tensors(self): gguf.MODEL_TENSOR.POS_EMBD, gguf.MODEL_TENSOR.TOKEN_TYPES, gguf.MODEL_TENSOR.SSM_CONV1D, + gguf.MODEL_TENSOR.SHORTCONV_CONV, gguf.MODEL_TENSOR.TIME_MIX_FIRST, gguf.MODEL_TENSOR.TIME_MIX_W1, gguf.MODEL_TENSOR.TIME_MIX_W2, @@ -836,6 +837,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "f6791d196f87ce6b56a7d234be618e0d58f8cda3549416635b2bebcd22cd95c4": # ref: https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct res = "midm-2.0" + if chkhsh == "169bf0296a13c4d9b7672313f749eb36501d931022de052aad6e36f2bf34dd51": + # ref: https://huggingface.co/LiquidAI/LFM2-Tokenizer + res = "lfm2" if res is None: logger.warning("\n") @@ -6382,12 +6386,19 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: elif wavelen > low_freq_wavelen: rope_factors.append(factor) else: - smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth = ( + old_context_len / wavelen - low_freq_factor + ) / (high_freq_factor - low_freq_factor) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) +@ModelBase.register("Exaone4ForCausalLM") +class Exaone4Model(ExaoneModel): + model_arch = gguf.MODEL_ARCH.EXAONE4 + + @ModelBase.register("GraniteForCausalLM") class GraniteModel(LlamaModel): """Conversion for IBM's GraniteForCausalLM""" @@ -7073,6 +7084,57 @@ def set_vocab(self): chat_template = tokenizer.chat_template.replace("[:]", "") self.gguf_writer.add_chat_template(chat_template) + +@ModelBase.register("LFM2ForCausalLM") +class LFM2Model(TextModel): + model_arch = gguf.MODEL_ARCH.LFM2 + + def _add_feed_forward_length(self): + ff_dim = self.hparams["block_ff_dim"] + + auto_adjust_ff_dim = self.hparams["block_auto_adjust_ff_dim"] + ff_dim = self.hparams["block_ff_dim"] + ffn_dim_multiplier = self.hparams["block_ffn_dim_multiplier"] + multiple_of = self.hparams["block_multiple_of"] + + if auto_adjust_ff_dim: + ff_dim = int(2 * ff_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + ff_dim = int(ffn_dim_multiplier * ff_dim) + ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of) + + self.gguf_writer.add_feed_forward_length(ff_dim) + + def set_gguf_parameters(self): + # set num_key_value_heads only for attention layers + self.hparams["num_key_value_heads"] = [ + self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0 + for layer_type in self.hparams["layer_types"] + ] + + super().set_gguf_parameters() + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["norm_eps"]) + self._add_feed_forward_length() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if 'operator_norm' in name: + name = name.replace('operator_norm', 'norm') + elif 'attention.k_layernorm' in name or 'attention.q_layernorm' in name: + name = name.replace('attention', 'self_attn') + elif name.startswith("model.embedding_norm"): + name = name.replace("model.embedding_norm", 'word_embeddings_layernorm') + elif 'conv.conv' in name: + # conv op requires 2d tensor + data_torch = data_torch.squeeze(1) + elif 'self_attn.out_proj' in name: + name = name.replace('out_proj', 'o_proj') + + return [(self.map_tensor_name(name), data_torch)] + + ###### CONVERSION LOGIC ###### diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 9f9b88da8785a..16f4acfe7834f 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -130,6 +130,7 @@ class TOKENIZER_TYPE(IntEnum): {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, {"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", }, {"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", }, + {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"}, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index f637571963730..41979733601d2 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -107,8 +107,11 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int if (nc == 4) { ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + } else if (nc == 3) { + ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { - GGML_ABORT("Only support kernel size = 4 now."); + GGML_ABORT("Only support kernel size = 3 or size = 4 right now."); } } else { if (nc == 4) { @@ -116,8 +119,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); ssm_conv_long_token_f32<<>>( src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + } else if (nc == 3) { + const int64_t split_n_t = 32; + dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); + ssm_conv_long_token_f32<<>>( + src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { - GGML_ABORT("Only support kernel size = 4 right now."); + GGML_ABORT("Only support kernel size = 3 or size = 4 right now."); } } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 37df17fa1cdd0..479b426aec153 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -187,6 +187,9 @@ class ConvNext: class Classifier: OUTPUT_LABELS = "{arch}.classifier.output_labels" + class ShortConv: + L_CACHE = "{arch}.shortconv.l_cache" + class Tokenizer: MODEL = "tokenizer.ggml.model" PRE = "tokenizer.ggml.pre" @@ -350,6 +353,7 @@ class MODEL_ARCH(IntEnum): JAIS = auto() NEMOTRON = auto() EXAONE = auto() + EXAONE4 = auto() GRANITE = auto() GRANITE_MOE = auto() GRANITE_HYBRID = auto() @@ -362,6 +366,7 @@ class MODEL_ARCH(IntEnum): ERNIE4_5 = auto() HUNYUAN_MOE = auto() SMOLLM3 = auto() + LFM2 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -533,6 +538,9 @@ class MODEL_TENSOR(IntEnum): POSNET_ATTN_K = auto() POSNET_ATTN_V = auto() POSNET_ATTN_OUT = auto() + SHORTCONV_CONV = auto() + SHORTCONV_INPROJ = auto() + SHORTCONV_OUTPROJ = auto() # vision V_MMPROJ = auto() V_MMPROJ_FC = auto() @@ -660,6 +668,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.JAIS: "jais", MODEL_ARCH.NEMOTRON: "nemotron", MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.EXAONE4: "exaone4", MODEL_ARCH.GRANITE: "granite", MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.GRANITE_HYBRID: "granitehybrid", @@ -673,6 +682,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.FALCON_H1: "falcon-h1", MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe", MODEL_ARCH.SMOLLM3: "smollm3", + MODEL_ARCH.LFM2: "lfm2", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -844,6 +854,9 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + MODEL_TENSOR.SHORTCONV_CONV: "blk.{bid}.shortconv.conv", + MODEL_TENSOR.SHORTCONV_INPROJ: "blk.{bid}.shortconv.in_proj", + MODEL_TENSOR.SHORTCONV_OUTPROJ: "blk.{bid}.shortconv.out_proj", # vision MODEL_TENSOR.V_MMPROJ: "mm.{bid}", MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", @@ -2113,6 +2126,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.EXAONE4: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.GRANITE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2356,6 +2385,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.LFM2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.SHORTCONV_CONV, + MODEL_TENSOR.SHORTCONV_INPROJ, + MODEL_TENSOR.SHORTCONV_OUTPROJ, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.ATTN_NORM, # operator_norm + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a7ecf3d31209f..4f23f9b024619 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -648,6 +648,9 @@ def add_convnext_embedding_length(self, length: int) -> None: def add_convnext_block_count(self, length: int) -> None: self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length) + def add_shortconv_l_cache(self, length: int) -> None: + self.add_uint32(Keys.ShortConv.L_CACHE.format(arch=self.arch), length) + def add_block_count(self, length: int) -> None: self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 7a4f275ceec28..2fbae6f461448 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1015,6 +1015,18 @@ class TensorNameMap: "backbone.posnet.{bid}.proj_out", # wavtokenizer ), + MODEL_TENSOR.SHORTCONV_CONV: ( + "model.layers.{bid}.conv.conv", + ), + + MODEL_TENSOR.SHORTCONV_INPROJ: ( + "model.layers.{bid}.conv.in_proj", + ), + + MODEL_TENSOR.SHORTCONV_OUTPROJ: ( + "model.layers.{bid}.conv.out_proj", + ), + ############################################################################# ## Vision encoder diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 1105139fcdbb8..c1dd1b4c4d862 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -67,6 +67,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_JAIS, "jais" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_EXAONE, "exaone" }, + { LLM_ARCH_EXAONE4, "exaone4" }, { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV7, "rwkv7" }, @@ -83,6 +84,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5, "ernie4_5" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_SMOLLM3, "smollm3" }, + { LLM_ARCH_LFM2, "lfm2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -188,6 +190,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, @@ -1474,6 +1478,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_EXAONE4, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_RWKV6, { @@ -1830,6 +1853,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_LFM2, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + } + }, { LLM_ARCH_UNKNOWN, { @@ -1997,6 +2041,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} @@ -2068,6 +2115,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_JAMBA: case LLM_ARCH_FALCON_H1: case LLM_ARCH_GRANITE_HYBRID: + case LLM_ARCH_LFM2: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index a9dd188a8f27d..a8212230e0822 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -71,6 +71,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, + LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, @@ -87,6 +88,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_SMOLLM3, + LLM_ARCH_LFM2, LLM_ARCH_UNKNOWN, }; @@ -227,6 +229,8 @@ enum llm_kv { LLM_KV_CLASSIFIER_OUTPUT_LABELS, + LLM_KV_SHORTCONV_L_CACHE, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -396,6 +400,9 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_SHORTCONV_CONV, + LLM_TENSOR_SHORTCONV_INPROJ, + LLM_TENSOR_SHORTCONV_OUTPROJ, }; enum llm_tensor_layer { diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 86c814d51b901..7aa736e2f39db 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -71,6 +71,11 @@ uint32_t llama_hparams::n_embd_r() const { return token_shift_count * n_embd; } + if (n_shortconv_l_cache != 0) { + // for LFM2 models + return n_embd * (n_shortconv_l_cache - 1); + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 476d0a5eade28..d0500e4d0fd77 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -55,6 +55,8 @@ struct llama_hparams { struct llama_hparams_posnet posnet; struct llama_hparams_convnext convnext; + uint32_t n_shortconv_l_cache = 0; + std::array n_head_arr; std::array n_head_kv_arr; std::array n_ff_arr; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 8fc025afa0005..bc8eb5fc88556 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -43,15 +43,18 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_256M: return "256M"; case LLM_TYPE_270M: return "270M"; case LLM_TYPE_335M: return "335M"; + case LLM_TYPE_350M: return "350M"; case LLM_TYPE_410M: return "410M"; case LLM_TYPE_450M: return "450M"; case LLM_TYPE_475M: return "475M"; + case LLM_TYPE_700M: return "700M"; case LLM_TYPE_770M: return "770M"; case LLM_TYPE_780M: return "780M"; case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; case LLM_TYPE_1B: return "1B"; + case LLM_TYPE_1_2B: return "1.2B"; case LLM_TYPE_1_3B: return "1.3B"; case LLM_TYPE_1_4B: return "1.4B"; case LLM_TYPE_1_5B: return "1.5B"; @@ -1443,6 +1446,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EXAONE4: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: { @@ -1663,6 +1675,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_LFM2: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_350M; break; + case 1536: type = LLM_TYPE_700M; break; + case 2048: type = LLM_TYPE_1_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -4215,6 +4241,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_EXAONE4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_RWKV6: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4906,6 +4961,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_LFM2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + // ffn is same for transformer and conv layers + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // for operator_norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (!hparams.is_recurrent(i)) { + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } else { + layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); + layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); + } + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -15859,6 +15947,180 @@ struct llm_build_smollm3 : public llm_graph_context { } }; +struct llm_build_lfm2 : public llm_graph_context { + const llama_model & model; + + llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + + ggml_tensor * cur = build_inp_embd(model.tok_embd); + cb(cur, "model.embed_tokens", -1); + + ggml_tensor * inp_pos = build_inp_pos(); + + auto *inp_hybrid = build_inp_mem_hybrid(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + auto *prev_cur = cur; + cur = lfm2_rms_norm(cur, model.layers[il].attn_norm); + cb(cur, "model.layers.{}.operator_norm", il); + + cur = hparams.is_recurrent(il) ? + build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) : + build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ; + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + prev_cur = ggml_get_rows(ctx0, prev_cur, inp_out_ids); + } + + cur = ggml_add(ctx0, prev_cur, cur); + cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + } + + cur = lfm2_rms_norm(cur, model.tok_norm); + cb(cur, "model.embedding_norm", -1); + res->t_embd = cur; + + // lm_head is tied with embeddings + cur = ggml_mul_mat(ctx0, model.tok_embd, cur); + cb(cur, "lm_head", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + ggml_tensor *build_feed_forward( + ggml_tensor * cur, + int il) const { + cur = lfm2_rms_norm(cur, model.layers[il].ffn_norm); + cb(cur, "model.layers.{}.ffn_norm", il); + + GGML_ASSERT(!model.layers[il].ffn_up_b); + GGML_ASSERT(!model.layers[il].ffn_gate_b); + GGML_ASSERT(!model.layers[il].ffn_down_b); + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "model.layers.{}.feed_forward.w2", il); + + return cur; + } + + ggml_tensor *build_attn_block(ggml_cgraph *gf, + ggml_tensor *cur, + ggml_tensor *inp_pos, + llm_graph_input_attn_kv_unified *inp_attn, + int il) const { + GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); + auto const n_embd_head = hparams.n_embd_head_v; + auto const n_head_kv = hparams.n_head_kv(il); + + auto *q = build_lora_mm(model.layers[il].wq, cur); + cb(q, "model.layers.{}.self_attn.q_proj", il); + auto *k = build_lora_mm(model.layers[il].wk, cur); + cb(k, "model.layers.{}.self_attn.k_proj", il); + auto *v = build_lora_mm(model.layers[il].wv, cur); + cb(v, "model.layers.{}.self_attn.v_proj", il); + + q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); + k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); + v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); + + // qk norm + q = lfm2_rms_norm(q, model.layers[il].attn_q_norm); + cb(q, "model.layers.{}.self_attn.q_layernorm", il); + k = lfm2_rms_norm(k, model.layers[il].attn_k_norm); + cb(k, "model.layers.{}.self_attn.k_layernorm", il); + + // RoPE + q = ggml_rope_ext( + ctx0, q, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + k = ggml_rope_ext( + ctx0, k, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, + q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + + cb(cur, "model.layers.{}.self_attn.out_proj", il); + + return cur; + } + + ggml_tensor * build_shortconv_block(ggml_cgraph * gf, + ggml_tensor * cur, + llm_graph_input_rs *inp_recr, + int il) { + const auto * mctx_cur = static_cast(mctx)->get_recr(); + + auto *bcx = ggml_mul_mat(ctx0, model.layers[il].shortconv.in_proj, cur); + cb(bcx, "model.layers.{}.conv.in_proj", il); + + constexpr auto n_chunks = 3; + GGML_ASSERT(bcx->ne[0] % n_chunks == 0); + auto const chunk_size = bcx->ne[0] / n_chunks; + auto *b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx)); + auto *c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx)); + auto *x = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx)); + + auto *bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); + + // read conv state directly, with build_rs generation is slower + ggml_tensor * conv_state = mctx_cur->get_r_l(il); + const int64_t n_seqs = ubatch.n_seqs; + ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs); + + bx = ggml_concat(ctx0, conv, bx, 0); + GGML_ASSERT(bx->ne[0] > conv->ne[0]); + + auto *new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); + GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); + + // write conv state + ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state)); + + auto *conv_kernel = model.layers[il].shortconv.conv; + GGML_ASSERT(hparams.n_shortconv_l_cache > 0); + + // construct ssm_conv op + ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); + cb(conv_out, "model.layers.{}.conv.conv", il); + + auto *y = ggml_mul(ctx0, c, conv_out); + + y = build_lora_mm(model.layers[il].shortconv.out_proj, y); + cb(y, "model.layers.{}.conv.out_proj", il); + + return y; + } + + // upcast to f32 before rms norm + ggml_tensor *lfm2_rms_norm(ggml_tensor *t, ggml_tensor *w) const { + auto *t_float = t; + if (t_float->type != GGML_TYPE_F32) { + t_float = ggml_cast(ctx0, t, GGML_TYPE_F32); + } + + auto *output = ggml_rms_norm(ctx0, t_float, hparams.f_norm_rms_eps); + if (output->type != t->type) { + output = ggml_cast(ctx0, output, t->type); + } + + return ggml_mul(ctx0, output, w); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -16195,6 +16457,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_EXAONE4: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_RWKV6: { llm = std::make_unique(*this, params, gf); @@ -16261,6 +16527,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_LFM2: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -16451,9 +16721,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ORION: case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: + case LLM_ARCH_EXAONE4: case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_LFM2: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index 431efbd516783..027a7f0c3e2c6 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -35,15 +35,18 @@ enum llm_type { LLM_TYPE_256M, LLM_TYPE_270M, LLM_TYPE_335M, + LLM_TYPE_350M, LLM_TYPE_410M, LLM_TYPE_450M, LLM_TYPE_475M, + LLM_TYPE_700M, LLM_TYPE_770M, LLM_TYPE_780M, LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, LLM_TYPE_1B, + LLM_TYPE_1_2B, LLM_TYPE_1_3B, LLM_TYPE_1_4B, LLM_TYPE_1_5B, @@ -155,6 +158,12 @@ struct llama_layer_convnext { struct ggml_tensor * gamma = nullptr; }; +struct llama_layer_shortconv { + struct ggml_tensor * in_proj = nullptr; + struct ggml_tensor * conv = nullptr; + struct ggml_tensor * out_proj = nullptr; +}; + struct llama_layer { // normalization struct ggml_tensor * attn_norm = nullptr; @@ -341,6 +350,8 @@ struct llama_layer { struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; + + struct llama_layer_shortconv shortconv; }; struct llama_model { diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index f4b5713d7dd9a..4dbd1e309919a 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -844,6 +844,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: // do not quantize Mamba's small yet 2D weights // NOTE: can't use LLM_TN here because the layer number is not known quantize &= name.find("ssm_conv1d.weight") == std::string::npos; + quantize &= name.find("shortconv.conv.weight") == std::string::npos; // do not quantize RWKV's small yet 2D weights quantize &= name.find("time_mix_first.weight") == std::string::npos; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 02cdc244adbef..e0e578d6394d8 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1525,7 +1525,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "falcon3" || tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral" || - tokenizer_pre == "midm-2.0") { + tokenizer_pre == "midm-2.0" || + tokenizer_pre == "lfm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; add_bos = true;