Merged
Changes from all commits (36 commits)
69943f8
opencl: support ne3 in get_rows (llama/15866)
lhez Sep 30, 2025
3a5a354
ggml webgpu: support for rope,div,sub,glu,scale,cont operators (llama…
reeselevine Sep 30, 2025
a57c9f6
opencl: support pad_ext (llama/15888)
lhez Sep 30, 2025
032abbc
vulkan: make ggml_vk_default_dispatcher support older vulkan headers …
netrunnereve Oct 1, 2025
4d08090
HIP: Disable ROCWMMA fattn on CDNA when compiled against ROCWMMA 2.0.…
IMbackK Oct 1, 2025
3b2df32
musa: update compile flags (llama/16265)
yeahdongcn Oct 2, 2025
5cd3424
model : Apertus model implementation (llama/15852)
pwilkin Oct 2, 2025
cc6dc14
ggml webgpu: add support for soft_max, optimize rms_norm (llama/16357)
reeselevine Oct 2, 2025
3f2ecff
vulkan: in flash attention, bounds check against nem1 (don't rely on …
jeffbolznv Oct 3, 2025
fe538c2
vulkan: Fix FA coopmat1 invalid array indexing (llama/16365)
jeffbolznv Oct 3, 2025
2e23591
vulkan: Replace uses of maxMemoryAllocationSize and VK_WHOLE_SIZE (ll…
jeffbolznv Oct 3, 2025
3d3000f
ggml : fix graph reallocation with multiple chunks (llama/16396)
Acly Oct 3, 2025
5f89599
metal : fix loop bound in ggml_mem_ranges (llama/16412)
ggerganov Oct 3, 2025
75159b5
vulkan : incremental shader builds (llama/16341)
Acly Oct 11, 2025
0c56ec3
rpc : add support for multiple devices (llama/16276)
rgerganov Oct 4, 2025
98b549d
rpc : check src buffer when copying tensor (llama/16421)
rgerganov Oct 4, 2025
6e7e1b8
vulkan: use a more appropriate amount of threads when generating shad…
netrunnereve Oct 4, 2025
72b9fa0
ggml webgpu: actually add softmax, fix rms_norm offset (llama/16400)
reeselevine Oct 5, 2025
73265c0
ggml-cpu : fix leftover handling in ggml_vec_scale_f32 for SVE (llama…
danbev Oct 6, 2025
352a07a
ggml : fix unaligned access in AMX code (llama/16315)
ggerganov Oct 6, 2025
389681e
metal : various optimizations + refactoring (llama/16446)
ggerganov Oct 7, 2025
091a5c1
tests : add -INF blocks to the KQ mask in the FA tests (llama/16380)
ggerganov Oct 7, 2025
d75f9ae
metal : add support for non-padded FA KV (llama/16148)
ggerganov Oct 7, 2025
c8d88fc
ggml webgpu: profiling, CI updates, reworking of command submission (…
reeselevine Oct 7, 2025
1b7b120
metal : mark FA blocks (llama/16372)
ggerganov Oct 8, 2025
57d8e6b
Disable CUDA host buffers on integrated GPUs (llama/16308)
ai-fonsi Oct 8, 2025
73b3339
refactor soft_max, add soft_max_back (llama/16472)
NeoZhangJianyu Oct 9, 2025
ba2e955
kleidiai: kernel interface refactoring (llama/16460)
chaxu01 Oct 9, 2025
910395c
CANN: Improve ACL graph matching (llama/16166)
noemotiovon Oct 9, 2025
779ca59
cpu : optimize the ggml NORM operation (llama/15953)
duduta Oct 9, 2025
667e364
cmake : Dont define XOPENSOURCE on AIX (llama/16481)
mehendarkarprajwal Oct 10, 2025
d477505
cuda : avoid initializing unused devices (llama/16510)
slaren Oct 11, 2025
33f7862
metal : fix mul-mm condition + fix mul-mv permuted kernels (llama/16494)
ggerganov Oct 11, 2025
4f77668
sync : ggml
ggerganov Oct 12, 2025
2ad7a69
talk-llama : sync llama.cpp
ggerganov Oct 12, 2025
55d8f01
bench : update [no ci]
ggerganov Oct 12, 2025
62 changes: 62 additions & 0 deletions examples/talk-llama/llama-arch.cpp
@@ -93,12 +93,14 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SMOLLM3, "smollm3" },
{ LLM_ARCH_OPENAI_MOE, "gpt-oss" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_LFM2MOE, "lfm2moe" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
{ LLM_ARCH_LLADA_MOE, "llada-moe" },
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_GROVEMOE, "grovemoe" },
{ LLM_ARCH_APERTUS, "apertus" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -217,6 +219,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },

{ LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
// sentence-transformers dense modules feature dims
{ LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" },
{ LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" },
{ LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" },
{ LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" },

{ LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
{ LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
@@ -256,6 +263,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" },
{ LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" },

{ LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" },
{ LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" },
{ LLM_KV_XIELU_BETA, "xielu.beta" },
{ LLM_KV_XIELU_EPS, "xielu.eps" },

// deprecated
{ LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" },
{ LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
@@ -1064,6 +1076,8 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_DENSE_2_OUT, "dense_2" },
{ LLM_TENSOR_DENSE_3_OUT, "dense_3" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
@@ -2098,6 +2112,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_OUTPUT, "output" },
}
},
{
LLM_ARCH_LFM2MOE,
{
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
}
},
{
LLM_ARCH_SMALLTHINKER,
{
@@ -2119,6 +2159,25 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }
},
},
{
LLM_ARCH_APERTUS,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_DREAM,
{
@@ -2229,6 +2288,8 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
@@ -2468,6 +2529,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_PLAMO2:
case LLM_ARCH_GRANITE_HYBRID:
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_NEMOTRON_H:
return true;
default:
15 changes: 15 additions & 0 deletions examples/talk-llama/llama-arch.h
@@ -97,12 +97,14 @@ enum llm_arch {
LLM_ARCH_SMOLLM3,
LLM_ARCH_OPENAI_MOE,
LLM_ARCH_LFM2,
LLM_ARCH_LFM2MOE,
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_LLADA,
LLM_ARCH_LLADA_MOE,
LLM_ARCH_SEED_OSS,
LLM_ARCH_GROVEMOE,
LLM_ARCH_APERTUS,
LLM_ARCH_UNKNOWN,
};

@@ -260,17 +262,30 @@ enum llm_kv {

LLM_KV_SHORTCONV_L_CACHE,

LLM_KV_XIELU_ALPHA_N,
LLM_KV_XIELU_ALPHA_P,
LLM_KV_XIELU_BETA,
LLM_KV_XIELU_EPS,

// deprecated:
LLM_KV_TOKENIZER_PREFIX_ID,
LLM_KV_TOKENIZER_SUFFIX_ID,
LLM_KV_TOKENIZER_MIDDLE_ID,

// sentence-transformers dense layers in and out features
LLM_KV_DENSE_2_FEAT_IN,
LLM_KV_DENSE_2_FEAT_OUT,
LLM_KV_DENSE_3_FEAT_IN,
LLM_KV_DENSE_3_FEAT_OUT,
};

enum llm_tensor {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_TOKEN_EMBD_NORM,
LLM_TENSOR_TOKEN_TYPES,
LLM_TENSOR_POS_EMBD,
LLM_TENSOR_DENSE_2_OUT,
LLM_TENSOR_DENSE_3_OUT,
LLM_TENSOR_OUTPUT,
LLM_TENSOR_OUTPUT_NORM,
LLM_TENSOR_ROPE_FREQS,
2 changes: 1 addition & 1 deletion examples/talk-llama/llama-chat.cpp
@@ -590,7 +590,7 @@ int32_t llm_chat_apply_template(
ss << message->content << "<|end_of_text|>\n";
}
if (add_ass) {
ss << "<|start_of_role|>assistant<|end_of_role|>\n";
ss << "<|start_of_role|>assistant<|end_of_role|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) {
// GigaChat template
6 changes: 6 additions & 0 deletions examples/talk-llama/llama-context.cpp
@@ -2346,6 +2346,12 @@ llama_context * llama_init_from_model(
return nullptr;
}

if (params.pooling_type != model->hparams.pooling_type) {
//user-specified pooling-type is different from the model default
LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__,
model->hparams.pooling_type, params.pooling_type);
}

try {
auto * ctx = new llama_context(*model, params);
return ctx;
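For context, the new check only logs a warning when the caller overrides the model's default pooling type; context creation still proceeds. A minimal usage sketch (the model path and the pooling choice are placeholders, error handling trimmed):

```cpp
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (model == nullptr) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN; // if this differs from the GGUF default, the warning above is logged

    llama_context * ctx = llama_init_from_model(model, cparams); // context is still created

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
```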
17 changes: 17 additions & 0 deletions examples/talk-llama/llama-graph.cpp
@@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
}

void llm_graph_context::build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const {
if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) {
return;
}
ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd;
GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd");

cur = ggml_mul_mat(ctx0, dense_2, cur);
cur = ggml_mul_mat(ctx0, dense_3, cur);
cb(cur, "result_embd_pooled", -1);
res->t_embd_pooled = cur;
ggml_build_forward_expand(gf, cur);
}


void llm_graph_context::build_pooling(
ggml_tensor * cls,
ggml_tensor * cls_b,
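The two ggml_mul_mat calls in build_dense_out chain the sentence-transformers 2_Dense and 3_Dense projections over the pooled embedding. A plain-C++ sketch of the same arithmetic, with illustrative sizes standing in for the dense_2/dense_3 feature hparams (this is not the ggml code path):

```cpp
#include <cassert>
#include <vector>

// One dense projection: W is stored as feat_out rows of length feat_in,
// x has length feat_in, the result has length feat_out.
static std::vector<float> dense(const std::vector<std::vector<float>> & W,
                                const std::vector<float> & x) {
    std::vector<float> y(W.size(), 0.0f);
    for (size_t o = 0; o < W.size(); ++o) {
        assert(W[o].size() == x.size());
        for (size_t i = 0; i < x.size(); ++i) {
            y[o] += W[o][i] * x[i];
        }
    }
    return y;
}

int main() {
    // Illustrative sizes: n_embd = 4, dense_2: 4 -> 3, dense_3: 3 -> 2.
    std::vector<float> pooled = {0.1f, 0.2f, 0.3f, 0.4f};                    // stands in for t_embd_pooled
    std::vector<std::vector<float>> dense_2(3, std::vector<float>(4, 0.5f));
    std::vector<std::vector<float>> dense_3(2, std::vector<float>(3, 0.25f));

    // Same chain as build_dense_out: out = dense_3 * (dense_2 * pooled)
    std::vector<float> out = dense(dense_3, dense(dense_2, pooled));
    assert(out.size() == 2);
    return 0;
}
```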
8 changes: 8 additions & 0 deletions examples/talk-llama/llama-graph.h
@@ -814,6 +814,14 @@ struct llm_graph_context {
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b) const;

//
// dense (out)
//

void build_dense_out(
ggml_tensor * dense_2,
ggml_tensor * dense_3) const;
};

// TODO: better name
6 changes: 5 additions & 1 deletion examples/talk-llama/llama-hparams.cpp
@@ -140,7 +140,11 @@ uint32_t llama_hparams::n_embd_s() const {
}

bool llama_hparams::is_recurrent(uint32_t il) const {
return recurrent_layer_arr[il];
if (il < n_layer) {
return recurrent_layer_arr[il];
}

GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer);
}

uint32_t llama_hparams::n_pos_per_embd() const {
14 changes: 13 additions & 1 deletion examples/talk-llama/llama-hparams.h
@@ -42,7 +42,7 @@ struct llama_hparams {
uint32_t n_embd;
uint32_t n_embd_features = 0;
uint32_t n_layer;
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache
uint32_t n_rot;
uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
Expand Down Expand Up @@ -169,6 +169,18 @@ struct llama_hparams {
uint32_t laurel_rank = 64;
uint32_t n_embd_altup = 256;

// needed for sentence-transformers dense layers
uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense
uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense
uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense
uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense

// xIELU
std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_n;
std::array<float, LLAMA_MAX_LAYERS> xielu_alpha_p;
std::array<float, LLAMA_MAX_LAYERS> xielu_beta;
std::array<float, LLAMA_MAX_LAYERS> xielu_eps;

// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
4 changes: 2 additions & 2 deletions examples/talk-llama/llama-kv-cache-iswa.cpp
@@ -220,15 +220,15 @@ bool llama_kv_cache_iswa::get_can_shift() const {
}

void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
kv_base->state_write(io, seq_id, flags);
}

kv_swa->state_write(io, seq_id, flags);
}

void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
kv_base->state_read(io, seq_id, flags);
}

7 changes: 2 additions & 5 deletions examples/talk-llama/llama-kv-cache.cpp
@@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache(
throw std::runtime_error("failed to create ggml context for kv cache");
}

ggml_tensor * k;
ggml_tensor * v;

k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);
ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream);
ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream);

ggml_format_name(k, "cache_k_l%d", il);
ggml_format_name(v, "cache_v_l%d", il);
20 changes: 11 additions & 9 deletions examples/talk-llama/llama-memory-hybrid.cpp
@@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch, false);
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
}

if (ubatch.n_tokens == 0) {
@@ -175,17 +177,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
}

void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
GGML_UNUSED(flags);

mem_attn->state_write(io, seq_id);
mem_recr->state_write(io, seq_id);
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
mem_attn->state_write(io, seq_id, flags);
}
mem_recr->state_write(io, seq_id, flags);
}

void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
GGML_UNUSED(flags);

mem_attn->state_read(io, seq_id);
mem_recr->state_read(io, seq_id);
if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
mem_attn->state_read(io, seq_id, flags);
}
mem_recr->state_read(io, seq_id, flags);
}

llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {
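The state_write/state_read changes in llama-kv-cache-iswa.cpp and here follow the same pattern: when the renamed LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY flag is set, only the partial state (the SWA or recurrent part) is serialized. A toy sketch of that gating with stand-in types (the real flag and cache classes live in llama.h and the files above):

```cpp
#include <cstdint>
#include <iostream>

// Stand-in for llama_state_seq_flags; the actual flag value is defined in llama.h.
using state_seq_flags = uint32_t;
constexpr state_seq_flags STATE_SEQ_FLAGS_PARTIAL_ONLY = 1u << 0;

struct toy_cache {
    const char * name;
    void state_write(state_seq_flags flags) const {
        std::cout << "writing " << name << " (flags=" << flags << ")\n";
    }
};

// Mirrors llama_kv_cache_iswa::state_write / llama_memory_hybrid::state_write:
// the "full" part (base KV / attention cache) is skipped when only the partial
// state (SWA window / recurrent state) is requested.
static void state_write(const toy_cache & full, const toy_cache & partial, state_seq_flags flags) {
    if ((flags & STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
        full.state_write(flags);
    }
    partial.state_write(flags);
}

int main() {
    toy_cache base{"base/attn cache"};
    toy_cache swa{"swa/recurrent cache"};

    state_write(base, swa, 0);                             // writes both parts
    state_write(base, swa, STATE_SEQ_FLAGS_PARTIAL_ONLY);  // writes only the partial state
    return 0;
}
```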
14 changes: 11 additions & 3 deletions examples/talk-llama/llama-memory-recurrent.cpp
@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
}

bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
//printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
uint32_t new_head = size;

if (p0 < 0) {
@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
if (tail_id >= 0) {
const auto & cell = cells[tail_id];
// partial intersection is invalid
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
return false;
}
// invalidate tails which will be cleared
@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
} else {
// seq_id is negative, then the range should include everything or nothing
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
//printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
return false;
}
}
@@ -379,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
// if all tokens are output, split by sequence
ubatch = balloc.split_seq(n_ubatch);
} else {
ubatch = balloc.split_equal(n_ubatch, false);
// TODO: non-sequential equal split can be done if using unified KV cache
// for simplicity, we always use sequential equal split for now
ubatch = balloc.split_equal(n_ubatch, true);
}

if (ubatch.n_tokens == 0) {
@@ -856,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
if (dest_seq_id != -1) {
// single sequence

seq_rm(dest_seq_id, -1, -1);

if (cell_count == 0) {
return true;
}

llama_batch_allocr balloc(hparams.n_pos_per_embd());

llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
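The seq_rm change above relaxes the partial-intersection test from `p0 <= cell.pos` to `p0 < cell.pos`, so removing a range that starts exactly at the tail cell's position is no longer rejected. A small standalone check of that boundary (cell position and ranges are illustrative):

```cpp
#include <cassert>
#include <climits>

// Stand-in for the partial-intersection test in llama_memory_recurrent::seq_rm.
// cell_pos is the tail cell position; the removal range is [p0, p1).
static bool rejects_partial(int p0, int p1, int cell_pos) {
    return (0 < p0 && p0 < cell_pos) || (0 < p1 && p1 <= cell_pos);
}

int main() {
    const int cell_pos = 7;

    // Removing from the tail position onward (p0 == cell_pos) is now accepted;
    // with the previous `p0 <= cell.pos` check it was rejected.
    assert(!rejects_partial(7, INT_MAX, cell_pos));

    // Ranges that cut strictly inside the sequence are still rejected.
    assert(rejects_partial(3, INT_MAX, cell_pos)); // starts mid-sequence
    assert(rejects_partial(0, 5, cell_pos));       // ends before the tail

    return 0;
}
```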