diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 63b54a9cf6b48..c151b1648d1b9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -699,6 +699,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "b3f499bb4255f8ca19fccd664443283318f2fd2414d5e0b040fbdd0cc195d6c5": # ref: https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B res = "deepseek-r1-qwen" + if chkhsh == "f4f37b6c8eb9ea29b3eac6bb8c8487c5ab7885f8d8022e67edc1c68ce8403e95": + # ref: https://huggingface.co/MiniMaxAI/MiniMax-Text-01 + res = "minimax-01" if res is None: logger.warning("\n") @@ -4906,6 +4909,70 @@ def _reverse_hf_permute(data_torch, n_heads, hidden_dim): return data_torch +@Model.register("MiniMaxText01ForCausalLM") +class MiniMaxText01Model(Model): + model_arch = gguf.MODEL_ARCH.MINIMAX01 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + layernorm_full_attention_alpha = self.hparams["layernorm_full_attention_alpha"] + layernorm_full_attention_beta = self.hparams["layernorm_full_attention_beta"] + layernorm_linear_attention_alpha = self.hparams["layernorm_linear_attention_alpha"] + layernorm_linear_attention_beta = self.hparams["layernorm_linear_attention_beta"] + layernorm_mlp_alpha = self.hparams["layernorm_mlp_alpha"] + layernorm_mlp_beta = self.hparams["layernorm_mlp_beta"] + assert layernorm_full_attention_alpha == layernorm_linear_attention_alpha == layernorm_mlp_alpha + assert layernorm_full_attention_beta == layernorm_linear_attention_beta == layernorm_mlp_beta == 1.0 + # we do not store the layernorm betas as they are all 1.0 + # layernorm alphas are stored as single residual_scale hparam + self.gguf_writer.add_residual_scale(layernorm_full_attention_alpha) + + self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"]) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + n_head = self.hparams["num_attention_heads"] + n_kv_head = self.hparams.get("num_key_value_heads") + + # process the experts separately + if name.find("block_sparse_moe.experts") != -1: + n_experts = self.hparams["num_local_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + tensors: list[tuple[str, Tensor]] = [] + + # merge the experts into a single 3d tensor + for wid in ["w1", "w2", "w3"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"layers.{bid}.feed_forward.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + return tensors + else: + return [] + + return [(self.map_tensor_name(name), data_torch)] + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 8fe84df21ea20..5d6356bbe6dfe 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -279,6 +279,7 @@ class MODEL_ARCH(IntEnum): GRANITE_MOE = auto() CHAMELEON = auto() WAVTOKENIZER_DEC = auto() + MINIMAX01 = auto() class MODEL_TENSOR(IntEnum): @@ -301,6 +302,7 @@ class MODEL_TENSOR(IntEnum): ATTN_OUT_NORM = auto() ATTN_POST_NORM = auto() ATTN_ROT_EMBD = auto() + ATTN_GATE = auto() 
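
Note: the expert merge in modify_tensors() above amounts to collecting the per-expert w1/w2/w3 matrices of one layer and stacking them along a new leading dimension, which is the layout the ffn_*_exps tensors expect. A minimal standalone sketch of that step; the shapes are made up and torch is assumed available, this is not the converter's actual code path:

import torch

n_experts, n_embd, n_ff = 4, 8, 16

# fake per-expert weights, keyed like the HF checkpoint names handled above
experts = {
    f"model.layers.0.block_sparse_moe.experts.{xid}.{wid}.weight":
        torch.randn(n_embd if wid == "w2" else n_ff,
                    n_ff if wid == "w2" else n_embd)
    for xid in range(n_experts)
    for wid in ("w1", "w2", "w3")
}

for wid in ("w1", "w2", "w3"):
    stacked = torch.stack(
        [experts[f"model.layers.0.block_sparse_moe.experts.{xid}.{wid}.weight"]
         for xid in range(n_experts)],
        dim=0)
    # one 3D tensor per projection: (n_experts, rows, cols)
    print(wid, tuple(stacked.shape))
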
FFN_GATE_INP = auto() FFN_GATE_INP_SHEXP = auto() FFN_NORM = auto() @@ -466,6 +468,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", + MODEL_ARCH.MINIMAX01: "minimax01", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -490,6 +493,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm", + MODEL_TENSOR.ATTN_GATE: "blk.{bid}.attn_gate", MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", @@ -1535,6 +1539,24 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.POSNET_ATTN_OUT, ], # TODO + MODEL_ARCH.MINIMAX01: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_NORM_2, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], } # tensors that will not be serialized diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 617791e240b60..2f28d0205e98c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -130,6 +130,7 @@ class TensorNameMap: "transformer.h.{bid}.ln_attn", # falcon40b "encoder.layer.{bid}.layer_norm_1", # jina-v2-code "rwkv.blocks.{bid}.ln2", # rwkv + "model.layers.{bid}.self_attn.norm", # minimax_text-01 ), # Attention query-key-value @@ -214,6 +215,7 @@ class TensorNameMap: "encoder.layers.{bid}.self_attention.dense", # chatglm "transformer.layers.{bid}.attn.out_proj", # openelm "transformer.h.{bid}.attn.attention.out_proj", # exaone + "model.layers.{bid}.self_attn.out_proj", # minimax_text-01 ), # Attention output norm @@ -236,6 +238,10 @@ class TensorNameMap: "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell ), + MODEL_TENSOR.ATTN_GATE: ( + "model.layers.{bid}.self_attn.output_gate", # minimax-text-01 + ), + # Feed-forward norm MODEL_TENSOR.FFN_NORM: ( "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox diff --git a/include/llama.h b/include/llama.h index 3b75e760780ef..9540a27d9e172 100644 --- a/include/llama.h +++ b/include/llama.h @@ -105,6 +105,7 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_MINIMAX = 29, }; enum llama_rope_type { diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index a7260f495d945..8690966d4559f 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -62,6 +62,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_CHAMELEON, "chameleon" }, { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, + { LLM_ARCH_MINIMAX01, "minimax01" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -1293,6 +1294,27 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" }, }, }, + { + LLM_ARCH_MINIMAX01, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + 
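
Note: judging from the surrounding tensor_mapping.py context, the three new HF names land in the ATTN_NORM_2, ATTN_OUT and ATTN_GATE buckets. A toy illustration of the resulting renames (this is not gguf-py's actual TensorNameMap API, and the converter's map_tensor_name() re-attaches the .weight/.bias suffix):

# HF checkpoint name template            -> GGUF base name template
new_mappings = {
    "model.layers.{bid}.self_attn.norm":        "blk.{bid}.attn_norm_2",
    "model.layers.{bid}.self_attn.out_proj":    "blk.{bid}.attn_output",
    "model.layers.{bid}.self_attn.output_gate": "blk.{bid}.attn_gate",
}

for hf_name, gguf_name in new_mappings.items():
    print(hf_name.format(bid=0), "->", gguf_name.format(bid=0))
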
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1320,6 +1342,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 122fdcebe0af6..3d77b49ec4809 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -66,6 +66,7 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_MINIMAX01, LLM_ARCH_UNKNOWN, }; @@ -215,6 +216,7 @@ enum llm_tensor { LLM_TENSOR_ATTN_V, LLM_TENSOR_ATTN_QKV, LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_GATE, LLM_TENSOR_ATTN_NORM, LLM_TENSOR_ATTN_NORM_2, LLM_TENSOR_ATTN_OUT_NORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 671d2a81adabf..36d94006e90ad 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -463,6 +463,88 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) { } } } + + if (lctx.inp_slopes) { + const int64_t n_head = hparams.n_head(); + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_slopes->buffer)); + + float * data = (float *) lctx.inp_slopes->data; + + float start = powf(2, -powf(2, -(log2f(n_head) - 3))); + float ratio = start; + + for (int h = 0; h < n_head; ++h) { + data[h] = start * powf(ratio, h); + } + } + + if (lctx.inp_q_decay) { + const int64_t n_head = hparams.n_head(); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_q_decay->buffer)); + + float * slopes = (float *) lctx.inp_slopes->data; + float * data = (float *) lctx.inp_q_decay->data; + + for (int i = 0; i < n_seq_tokens; ++i) { + for (int h = 0; h < n_head; ++h) { + data[i * n_head + h] = -slopes[h] * (i + 1); + } + } + } + + if (lctx.inp_k_decay) { + const int64_t n_head = hparams.n_head(); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_k_decay->buffer)); + + float * slopes = (float *) lctx.inp_slopes->data; + float * data = (float *) lctx.inp_k_decay->data; + + for (int i = 0; i < n_seq_tokens; ++i) { + for (int h = 0; h < n_head; ++h) { + data[i * n_head + h] = -slopes[h] * (n_seq_tokens - i - 1); + } + } + } + + if (lctx.inp_diag_decay) { + const int64_t n_head = hparams.n_head(); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_diag_decay->buffer)); + + float * slopes = (float *) lctx.inp_slopes->data; + float * data = (float *) lctx.inp_diag_decay->data; + + for (int j = 0; j < n_seq_tokens; ++j) { + for (int i = 0; i < n_seq_tokens; ++i) { + int index = j - i; + for (int h = 0; h < n_head; ++h) { + float s_index = index >= 0 ? 
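
Note: the host-side fills above use the standard ALiBi slope schedule and write the per-token decay exponents in log domain; they are only exponentiated inside the graph, after the per-layer slope scaling. A sketch of the same arithmetic (helper name made up here), assuming n_head is a power of two as the formula expects:

import math

def alibi_slopes(n_head: int) -> list[float]:
    start = 2.0 ** (-(2.0 ** -(math.log2(n_head) - 3)))
    ratio = start
    return [start * ratio ** h for h in range(n_head)]

n_head, n_seq_tokens = 8, 4
s = alibi_slopes(n_head)

# inp_q_decay[i][h] = -slope_h * (i + 1)
q_decay = [[-s[h] * (i + 1) for h in range(n_head)] for i in range(n_seq_tokens)]
# inp_k_decay[i][h] = -slope_h * (n_seq_tokens - i - 1)
k_decay = [[-s[h] * (n_seq_tokens - i - 1) for h in range(n_head)] for i in range(n_seq_tokens)]
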
-slopes[h] * index : -INFINITY; + data[j * n_head * n_seq_tokens + i * n_head + h] = s_index; + } + } + } + } + + if (lctx.inp_seq_ids) { + const int64_t n_head = hparams.n_head(); + const int64_t n_seqs = ubatch.n_seqs; + + GGML_ASSERT(n_seqs != 0); + + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_seq_ids->buffer)); + + uint32_t * data = (uint32_t *) lctx.inp_seq_ids->data; + + for (int s = 0; s < n_seqs; ++s) { + data[s] = (ubatch.seq_id ? ubatch.seq_id[s][0] : 0); + } + } } // llama output diff --git a/src/llama-context.h b/src/llama-context.h index a9268b2920908..cb579c9f77809 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -107,6 +107,11 @@ struct llama_context { struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] + struct ggml_tensor * inp_slopes; // F32 [n_head] + struct ggml_tensor * inp_q_decay; // F32 [n_batch, n_head] + struct ggml_tensor * inp_k_decay; // F32 [n_batch, n_head] + struct ggml_tensor * inp_diag_decay; // F32 [n_batch, n_batch, n_head] + struct ggml_tensor * inp_seq_ids; // F32 [n_batch, n_batch, n_head] }; // TODO: make these methods of llama_context diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index feffdf0de52cf..cd0e7c9341df5 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -27,6 +27,8 @@ bool llama_kv_cache_init( const struct llama_hparams & hparams = model.hparams; const int32_t n_layer = hparams.n_layer; + const int n_head = hparams.n_head(); + const int n_embd_head_k = hparams.n_embd_head_k; cache.has_shift = false; @@ -53,7 +55,7 @@ bool llama_kv_cache_init( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { struct ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(3u*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -70,6 +72,7 @@ bool llama_kv_cache_init( cache.k_l.reserve(n_layer); cache.v_l.reserve(n_layer); + cache.kv_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); @@ -93,10 +96,13 @@ bool llama_kv_cache_init( ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * kv = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_head_k, n_embd_head_k, n_head, cparams.n_seq_max); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); + ggml_format_name(kv, "cache_kv_l%d", i); cache.k_l.push_back(k); cache.v_l.push_back(v); + cache.kv_l.push_back(kv); } // allocate tensors and initialize the buffers to avoid NaNs in the padding diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index dca6f3998c645..e7b428680cd5c 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -54,6 +54,8 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; + std::vector kv_l; + std::vector ctxs; std::vector bufs; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 031b4c30b75dd..37a9bd5fb0ee8 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -65,6 +65,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_70B: return "70B"; case LLM_TYPE_236B: return "236B"; case LLM_TYPE_314B: return "314B"; + case LLM_TYPE_456B: return "456B"; case LLM_TYPE_671B: return "671B"; case LLM_TYPE_SMALL: return "0.1B"; case 
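
Note: inp_diag_decay above is the causal intra-chunk decay mask, again in log domain: entry (j, i) is -slope_h * (j - i) on and below the diagonal and -inf above it, so after ggml_exp it becomes a lower-triangular matrix of decayed weights. The new per-layer cache_kv tensors added in llama-kv-cache.cpp hold the recurrent state read by this path, one n_embd_head x n_embd_head matrix per head and per sequence. A per-head sketch of the mask (toy slope and size):

import math

def diag_decay(slope: float, n: int) -> list[list[float]]:
    return [[-slope * (j - i) if j >= i else -math.inf for i in range(n)]
            for j in range(n)]

mask = [[math.exp(x) for x in row] for row in diag_decay(0.5, 4)]
# mask[j][i] == exp(-0.5 * (j - i)) for j >= i, and 0.0 above the diagonal
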
LLM_TYPE_MEDIUM: return "0.4B"; @@ -1233,6 +1234,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); } break; + case LLM_ARCH_MINIMAX01: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); + + switch (hparams.n_layer) { + case 80: type = LLM_TYPE_456B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -3392,6 +3403,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); } break; + case LLM_ARCH_MINIMAX01: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (i % 8 == 7) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + } else { + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd_head_k * n_head}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd_head_k * n_head}, 0); + layer.wg = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -3891,6 +3939,7 @@ enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_MINICPM3: + case LLM_ARCH_MINIMAX01: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: diff --git a/src/llama-model.h b/src/llama-model.h index a7c30444786fd..2ed52f79079f3 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -64,6 +64,7 @@ enum llm_type { LLM_TYPE_70B, LLM_TYPE_236B, LLM_TYPE_314B, + LLM_TYPE_456B, LLM_TYPE_671B, LLM_TYPE_SMALL, LLM_TYPE_MEDIUM, @@ -157,6 +158,7 @@ struct llama_layer { struct ggml_tensor * wv = nullptr; struct ggml_tensor * wo = nullptr; struct ggml_tensor * wqkv = 
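
Note: the tensor layout above hard-codes the hybrid schedule: blocks with i % 8 == 7 get separate Q/K/V (regular softmax attention with RoPE), all other blocks get the fused QKV, output gate and second norm of the linear-attention path. Quick sketch of the schedule for the first 16 of the 80 layers (helper name made up here):

def layer_kind(i: int) -> str:
    return "softmax-attn" if i % 8 == 7 else "linear-attn"

print([layer_kind(i) for i in range(16)])
# 7 linear-attn layers, then one softmax-attn layer, repeated
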
nullptr; + struct ggml_tensor * wg = nullptr; struct ggml_tensor * wq_a = nullptr; struct ggml_tensor * wq_b = nullptr; struct ggml_tensor * wkv_a_mqa = nullptr; @@ -284,6 +286,9 @@ struct llama_layer { struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; + + // minimax-01 + struct ggml_tensor * kv = nullptr; }; struct llama_model { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 0782d3a41a1f5..10c1f2b8d11e6 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; break; + case LLAMA_VOCAB_PRE_TYPE_MINIMAX: + // original regex from tokenizer.json + // "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -1587,6 +1594,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; + } else if ( + tokenizer_pre == "minimax-01") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } diff --git a/src/llama.cpp b/src/llama.cpp index e8cfe5012819c..c283939c84253 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1157,6 +1157,11 @@ struct llm_build_context { lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; + lctx.inp_slopes = nullptr; + lctx.inp_q_decay = nullptr; + lctx.inp_k_decay = nullptr; + lctx.inp_diag_decay = nullptr; + lctx.inp_seq_ids = nullptr; } void free() { @@ -1473,6 +1478,45 @@ struct llm_build_context { return lctx.inp_KQ_mask_cross; } + struct ggml_tensor * llm_build_inp_slopes() { + lctx.inp_slopes = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_head); + ggml_set_input(lctx.inp_slopes); + cb(lctx.inp_slopes, "slopes", -1); + return lctx.inp_slopes; + } + + struct ggml_tensor * llm_build_inp_q_decay() { + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + lctx.inp_q_decay = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, n_head, n_seq_tokens); + ggml_set_input(lctx.inp_q_decay); + cb(lctx.inp_q_decay, "q_decay_exp", -1); + return lctx.inp_q_decay; + } + + struct ggml_tensor * llm_build_inp_k_decay() { + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + lctx.inp_k_decay = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, n_head, n_seq_tokens); + ggml_set_input(lctx.inp_k_decay); + cb(lctx.inp_k_decay, "k_decay_exp", -1); + return lctx.inp_k_decay; + } + + struct ggml_tensor * llm_build_inp_diag_decay() { + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + lctx.inp_diag_decay = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_head, n_seq_tokens, n_seq_tokens); + ggml_set_input(lctx.inp_diag_decay); + cb(lctx.inp_diag_decay, "diag_decay_exp", -1); + return 
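
Note: the fallback regex above replaces the Unicode category classes of the original tokenizer.json pattern with lookahead tricks, presumably because the splitter here does not handle all of those classes verbatim: (?=[\p{L}])[^a-z] consumes a letter that is not an ASCII lowercase letter, standing in for [\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}] (the match sets are close but not identical, e.g. non-ASCII lowercase letters such as ß also match). A quick check using the third-party regex package (assumed installed):

import regex

upperish = regex.compile(r"(?=[\p{L}])[^a-z]")
for ch in ["A", "É", "ß", "a", "1"]:
    print(ch, bool(upperish.match(ch)))   # True, True, True, False, False
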
lctx.inp_diag_decay; + } + + struct ggml_tensor * llm_build_inp_seq_ids() { + const int64_t n_seqs = ubatch.n_seqs; + lctx.inp_seq_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_seqs); + ggml_set_input(lctx.inp_seq_ids); + cb(lctx.inp_seq_ids, "seq_ids", -1); + return lctx.inp_seq_ids; + } + struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); @@ -8099,6 +8143,326 @@ struct llm_build_context { return gf; } + + struct ggml_cgraph * build_minimax01() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, model.max_nodes(), false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seq_max = cparams.n_seq_max; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + // slope tensor + struct ggml_tensor * slopes = llm_build_inp_slopes(); + struct ggml_tensor * q_decay_exp = (n_seq_tokens != 1 ? llm_build_inp_q_decay() : nullptr); + struct ggml_tensor * k_decay_exp = (n_seq_tokens != 1 ? llm_build_inp_k_decay() : nullptr); + struct ggml_tensor * diag_decay_exp = (n_seq_tokens != 1 ? llm_build_inp_diag_decay() : nullptr); + struct ggml_tensor * seq_ids = (n_seqs > 1 ? 
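
Note: the extra graph inputs above are only created when needed, which effectively yields two graph variants: a decode form (one token per sequence) that only needs the slopes, and a prefill form that also needs the three decay inputs; seq_ids is only added when more than one sequence shares the ubatch. Summarized as a sketch (helper name made up here):

def minimax_extra_inputs(n_seq_tokens: int, n_seqs: int) -> list[str]:
    inputs = ["slopes"]
    if n_seq_tokens != 1:
        inputs += ["q_decay_exp", "k_decay_exp", "diag_decay_exp"]
    if n_seqs > 1:
        inputs.append("seq_ids")
    return inputs

print(minimax_extra_inputs(1, 1))   # decode: ['slopes']
print(minimax_extra_inputs(32, 4))  # prefill: all five inputs
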
llm_build_inp_seq_ids() : nullptr); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + struct ggml_tensor * residual = cur; + + // self-attention + if (il % 8 == 7) { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, ggml_element_size(Qcur)*n_embd_head, ggml_element_size(Qcur)*n_embd_head*n_head, 0); + + Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, ggml_element_size(Kcur)*n_embd_head, ggml_element_size(Kcur)*n_embd_head*n_head_kv, 0); + + struct ggml_tensor * q_rope = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(q_rope, "q_rope", il); + + struct ggml_tensor * k_rope = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(k_rope, "k_rope", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + k_rope, Vcur, q_rope, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } else { + float slope_scale = 1.0 - 1.0 * il / (n_layer - 1) + 1e-5; + struct ggml_tensor * slope_rate = ggml_scale(ctx0, slopes, slope_scale); + cb(slope_rate, "slope_rate", il); + + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + struct ggml_tensor * QKVcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); + cb(QKVcur, "QKVcur", il); + + QKVcur = ggml_silu(ctx0, QKVcur); + cb(QKVcur, "QKVcur_silu", il); + + QKVcur = ggml_view_4d(ctx0, QKVcur, n_embd_head * 3, n_head, n_seq_tokens, n_seqs, ggml_element_size(QKVcur)*n_embd_head*3, ggml_element_size(QKVcur)*n_embd_head*3*n_head, ggml_element_size(QKVcur)*n_embd_head*3*n_head*n_seq_tokens, 0); + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_4d(ctx0, QKVcur, n_embd_head, n_head, n_seq_tokens, n_seqs, QKVcur->nb[1], QKVcur->nb[2], QKVcur->nb[3], 0*sizeof(float)*n_embd_head)); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_4d(ctx0, QKVcur, n_embd_head, n_head, n_seq_tokens, n_seqs, QKVcur->nb[1], QKVcur->nb[2], QKVcur->nb[3], 1*sizeof(float)*n_embd_head)); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_4d(ctx0, QKVcur, n_embd_head, n_head, n_seq_tokens, n_seqs, QKVcur->nb[1], QKVcur->nb[2], QKVcur->nb[3], 2*sizeof(float)*n_embd_head)); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + struct ggml_tensor * kv_old = ggml_view_2d(ctx0, kv_self.kv_l[il], n_embd_head*n_embd_head*n_head, n_seq_max, ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head*n_head, 0); + cb(kv_old, "kv_old_2d", il); + + // optimization for a single sequence + if (n_seqs > 1) { + kv_old = ggml_get_rows(ctx0, kv_old, seq_ids); + cb(kv_old, "kv_old_2d_sel", il); + } + + kv_old = ggml_view_4d(ctx0, kv_old, n_embd_head, n_embd_head, n_head, n_seqs, ggml_element_size(kv_self.kv_l[il])*n_embd_head, ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head, 
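
Note: slope_rate above scales every head's base slope by a layer-dependent factor, 1 - il / (n_layer - 1) + 1e-5, so early linear-attention layers decay their state quickly while the deepest ones keep an almost undecayed state. Small sketch of that schedule:

n_layer = 80
slope_scale = [1.0 - il / (n_layer - 1) + 1e-5 for il in range(n_layer)]
print(slope_scale[0], slope_scale[40], slope_scale[79])  # ~1.0, ~0.49, ~1e-5
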
ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head*n_head, 0); + cb(kv_old, "kv_old", il); + + struct ggml_tensor * qkv = nullptr; + struct ggml_tensor * kv_new = nullptr; + if (n_seq_tokens == 1) { + + struct ggml_tensor * slopes_neg = ggml_scale(ctx0, slope_rate, -1.0); + cb(slopes_neg, "slopes_neg", il); + + struct ggml_tensor * ratio = ggml_exp(ctx0, slopes_neg); + cb(ratio, "ratio", il); + + struct ggml_tensor * ratio_3d = ggml_view_3d(ctx0, ratio, 1, 1, n_head, ggml_element_size(ratio), ggml_element_size(ratio), 0); + cb(ratio_3d, "ratio3d", il); + + struct ggml_tensor * v_trans = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); + cb(v_trans, "v_trans", il); + + struct ggml_tensor * k_trans = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 1, 2, 0, 3)); + cb(k_trans, "k_trans", il); + + struct ggml_tensor * kv_cur = ggml_mul_mat(ctx0, v_trans, k_trans); + cb(kv_cur, "kv_cur", il); + + struct ggml_tensor * kv_old_s = ggml_mul(ctx0, kv_old, ratio_3d); + cb(kv_old_s, "kv_old_s", il); + + kv_new = ggml_add(ctx0, kv_old_s, kv_cur); + cb(kv_new, "kv_new", il); + + struct ggml_tensor * q_trans = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); + cb(q_trans, "q_trans", il); + + struct ggml_tensor * kv_new_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_new)); + cb(kv_new_trans, "kv_new_trans", il); + + qkv = ggml_mul_mat(ctx0, kv_new_trans, q_trans); + cb(qkv, "qkv", il); + } else if(n_seq_tokens > 1) { + struct ggml_tensor * q_decay = ggml_exp(ctx0, ggml_scale(ctx0, q_decay_exp, slope_scale)); + cb(q_decay, "q_decay", il); + struct ggml_tensor * k_decay = ggml_exp(ctx0, ggml_scale(ctx0, k_decay_exp, slope_scale)); + cb(k_decay, "k_decay", il); + struct ggml_tensor * diag_decay = ggml_exp(ctx0, ggml_scale(ctx0, diag_decay_exp, slope_scale)); + cb(diag_decay, "diag_decay", il); + + struct ggml_tensor * q_s = ggml_mul(ctx0, Qcur, q_decay); + cb(q_s, "q_s", il); + + struct ggml_tensor * q_s_trans = ggml_cont(ctx0, ggml_permute(ctx0, q_s, 0, 2, 1, 3)); + cb(q_s_trans, "q_s_trans", il); + + struct ggml_tensor * kv_old_trans = ggml_cont(ctx0, ggml_transpose(ctx0, kv_old)); + cb(kv_old_trans, "kv_old_trans", il); + + struct ggml_tensor * qkv_none_diag = ggml_mul_mat(ctx0, kv_old_trans, q_s_trans); + cb(qkv_none_diag, "qkv_none_diag", il); + + struct ggml_tensor * q_trans = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 0, 2, 1, 3)); + cb(q_trans, "q_trans", il); + + struct ggml_tensor * k_trans = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + cb(k_trans, "k_trans", il); + + struct ggml_tensor * qk = ggml_mul_mat(ctx0, k_trans, q_trans); + cb(qk, "qk", il); + + struct ggml_tensor * diag_decay_trans = ggml_cont(ctx0, ggml_permute(ctx0, diag_decay, 2, 0, 1, 3)); + + qk = ggml_mul(ctx0, qk, diag_decay_trans); + cb(qk, "qk_s", il); + + struct ggml_tensor * v_trans = ggml_cont(ctx0, ggml_permute(ctx0, Vcur, 1, 2, 0, 3)); + cb(v_trans, "v_trans", il); + + struct ggml_tensor * qkv_diag = ggml_mul_mat(ctx0, v_trans, qk); + cb(qkv_diag, "qkv_diag", il); + + qkv = ggml_add(ctx0, qkv_none_diag, qkv_diag); + cb(qkv, "qkv", il); + + ggml_build_forward_expand(gf, qkv); + + struct ggml_tensor * slopes_neg = ggml_scale(ctx0, slope_rate, -1.0*n_seq_tokens); + cb(slopes_neg, "slopes_neg", il); + + struct ggml_tensor * block_decay = ggml_exp(ctx0, slopes_neg); + cb(block_decay, "block_decay", il); + + struct ggml_tensor * block_decay_3d = ggml_view_3d(ctx0, block_decay, 1, 1, n_head, ggml_element_size(block_decay), ggml_element_size(block_decay), 0); + cb(block_decay_3d, "block_decay_3d", il); + + 
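
Note: on the decode path above (n_seq_tokens == 1) the per-head math is the plain decayed linear-attention recurrence; ignoring ggml's memory layout and transposes, the cached d x d state is scaled by exp(-rate), the rank-1 update from the new key/value is added, and the single query reads the updated state. Sketch (helper name made up here):

import numpy as np

def lightning_decode_step(S, q, k, v, rate):
    # S: (d, d) per-head state; q, k, v: (d,) vectors for the new token
    S_new = np.exp(-rate) * S + np.outer(k, v)   # kv_new = ratio * kv_old + k^T v
    o = q @ S_new                                # qkv = q . kv_new
    return o, S_new
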
struct ggml_tensor * kv_old_s = ggml_mul(ctx0, kv_old, block_decay_3d); + cb(kv_old_s, "kv_old_s", il); + + struct ggml_tensor * k_after_decay = ggml_mul(ctx0, Kcur, k_decay); + cb(k_after_decay, "k_after_decay", il); + + struct ggml_tensor * k_after_decay_trans = ggml_cont(ctx0, ggml_permute(ctx0, k_after_decay, 1, 2, 0, 3)); + cb(k_after_decay_trans, "k_after_decay_trans", il); + + struct ggml_tensor * kv_cur = ggml_mul_mat(ctx0, v_trans, k_after_decay_trans); + cb(kv_cur, "kv_cur", il); + + kv_new = ggml_add(ctx0, kv_old_s, kv_cur); + cb(kv_new, "kv_new", il); + } + + // store new kv states for each processed sequence + // TODO is there a prettier way to do it? + for (uint64_t s = 0; s < ubatch.n_seqs; s++) { + uint64_t seq_id = ubatch.seq_id ? ubatch.seq_id[s][0] : 0; + struct ggml_tensor * kv_old_seq_view = ggml_view_4d(ctx0, kv_self.kv_l[il], n_embd_head, n_embd_head, n_head, 1, ggml_element_size(kv_self.kv_l[il])*n_embd_head, ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head, ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head*n_head, ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head*n_head*seq_id); + struct ggml_tensor * kv_new_seq_view = ggml_view_4d(ctx0, kv_new, n_embd_head, n_embd_head, n_head, 1, ggml_element_size(kv_self.kv_l[il])*n_embd_head, ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head, ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head*n_head, ggml_element_size(kv_self.kv_l[il])*n_embd_head*n_embd_head*n_head*s); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, kv_new_seq_view, kv_old_seq_view)); + } + + qkv = ggml_cont(ctx0, ggml_permute(ctx0, qkv, 0, 2, 1, 3)); + cb(qkv, "qkv_permuted", il); + + qkv = ggml_view_3d(ctx0, qkv, qkv->ne[0]*qkv->ne[1], qkv->ne[2], qkv->ne[3], ggml_element_size(qkv)*qkv->ne[0]*qkv->ne[1], ggml_element_size(qkv)*qkv->ne[0]*qkv->ne[1]*qkv->ne[2], 0); + + // norm + struct ggml_tensor * qkv_norm = llm_build_norm(ctx0, qkv, hparams, + model.layers[il].attn_norm_2, NULL, + LLM_NORM_RMS, cb, il); + cb(qkv_norm, "qkv_norm", il); + + struct ggml_tensor * g = llm_build_lora_mm(lctx, ctx0, model.layers[il].wg, cur); + cb(g, "g", il); + + g = ggml_sigmoid(ctx0, g); + cb(g, "g_sigm", il); + + cur = ggml_mul(ctx0, g, qkv_norm); + + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); + cb(cur, "attn_out", il); + + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens*n_seqs); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + residual = ggml_scale(ctx0, residual, hparams.f_residual_scale); + cb(residual, "residual_scaled_attn", il); + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, residual); + cb(ffn_inp, "ffn_inp", il); + + // MoE + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + residual = cur; + + cur = llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + cb, il); + cb(cur, "ffn_moe_out", il); + + residual = ggml_scale(ctx0, residual, hparams.f_residual_scale); + cb(residual, 
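
Note: the multi-token (prefill) branch above computes the same recurrence in chunked form: the incoming state contributes through the q-decayed term, the intra-chunk part goes through the decay-masked Q K^T, and the state is advanced with the block decay plus the k-decayed K^T V. A per-head numpy sketch that mirrors those tensors (qkv_none_diag, qkv_diag, kv_old_s, kv_cur) up to ggml layout details, with a consistency check against the stepwise form; names and sizes are illustrative only:

import numpy as np

def lightning_chunk(S, Q, K, V, rate):
    # S: (d, d) state carried between chunks; Q, K, V: (n, d) for one chunk
    n = Q.shape[0]
    idx = np.arange(n)
    q_decay = np.exp(-rate * (idx + 1))[:, None]             # exp(inp_q_decay * slope_scale)
    k_decay = np.exp(-rate * (n - 1 - idx))[:, None]         # exp(inp_k_decay * slope_scale)
    diff = idx[:, None] - idx[None, :]
    D = np.exp(np.where(diff >= 0, -rate * diff, -np.inf))   # exp(inp_diag_decay)
    O = (Q * q_decay) @ S + ((Q @ K.T) * D) @ V              # qkv_none_diag + qkv_diag
    S_new = np.exp(-rate * n) * S + (K * k_decay).T @ V      # kv_old_s + kv_cur
    return O, S_new

rng = np.random.default_rng(0)
d, n, rate = 4, 6, 0.25
Q, K, V = rng.normal(size=(3, n, d))
O, S_new = lightning_chunk(np.zeros((d, d)), Q, K, V, rate)

# check against the token-by-token recurrence used on the decode path
S_ref, O_ref = np.zeros((d, d)), []
for q, k, v in zip(Q, K, V):
    S_ref = np.exp(-rate) * S_ref + np.outer(k, v)
    O_ref.append(q @ S_ref)
assert np.allclose(O, np.array(O_ref)) and np.allclose(S_new, S_ref)
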
"residual_scaled_ffn", il); + + cur = ggml_add(ctx0, cur, residual); + cb(cur, "ffn_out", il); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -8391,6 +8755,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_wavtokenizer_dec(); } break; + case LLM_ARCH_MINIMAX01: + { + result = llm.build_minimax01(); + } break; default: GGML_ABORT("fatal error"); } @@ -8515,7 +8883,7 @@ static int llama_decode_impl( } lctx.sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self.recurrent, + /* simple_split */ !(kv_self.recurrent || model.arch == LLM_ARCH_MINIMAX01), /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer @@ -8526,7 +8894,7 @@ static int llama_decode_impl( while (lctx.sbatch.n_tokens > 0) { llama_ubatch ubatch; - if (kv_self.recurrent) { + if (kv_self.recurrent || model.arch == LLM_ARCH_MINIMAX01) { if (embd_pooled) { // Pooled embeddings cannot be split across ubatches (yet) ubatch = lctx.sbatch.split_seq(n_ubatch);