diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 2862405f806a9..82e61f6384566 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -7253,7 +7253,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
-        const int iq3 = ir/(neq2*neq1);
+        const int iq3 = ir / (neq2*neq1);
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
@@ -7425,6 +7425,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_with_state(
         memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
+
 void ggml_compute_forward_flash_attn_ext_mixed(
         const ggml_compute_params * params,
         const ggml_tensor * q,
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index d02ed4e7c0638..69ae443351518 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1630,6 +1630,22 @@ llm_graph_input_attn_kv_mixed * llm_graph_context::build_attn_inp_kv_mixed() con
         inp->self_kq_mask_quant_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_quant, GGML_TYPE_F16) : inp->self_kq_mask_quant;
     }
 
+    // Create state and result tensors for stateful flash attention
+    {
+        // State tensor: [2, n_heads * seq_len] for [M, S] pairs
+        const auto seq_len = n_tokens; // sequence length for current batch
+        inp->attn_state = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, n_head * seq_len);
+        ggml_set_input(inp->attn_state);
+        ggml_format_name(inp->attn_state, "attn_state");
+
+        // Result tensor: [head_dim, n_heads, seq_len, n_batch] - Fixed dimension order to match ggml_flash_attn_ext_with_state
+        const auto head_dim = n_embd_head_v;
+        const auto n_batch = 1;
+        inp->attn_result = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_dim, n_head, seq_len, n_batch);
+        ggml_set_input(inp->attn_result);
+        ggml_format_name(inp->attn_result, "attn_result");
+    }
+
     return (llm_graph_input_attn_kv_mixed *) res->add_input(std::move(inp));
 }
 
@@ -1754,6 +1770,8 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
         ggml_tensor * kq_mask_fp16,
         ggml_tensor * kq_mask_quant,
         ggml_tensor * v_mla,
+        ggml_tensor * state,
+        ggml_tensor * result,
         float kq_scale) const {
     GGML_UNUSED(gf);
 
@@ -1845,16 +1863,11 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
     const auto n_kv_fp16 = k_fp16->ne[1]; // number of keys/values
     const auto n_kv_quant = k_quant->ne[1]; // number of keys/values
 
-    // TODO : Modify these tensors to be non-dynamic alloc.
-    // Create state tensor: [2, n_heads * seq_len] for [M, S] pairs
-    ggml_tensor * state = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2, n_head * seq_len);
-    ggml_set_input(state);
-    cb(state, "state", -1);
+    // Use pre-allocated state and result tensors to avoid dynamic allocation issues
+    GGML_ASSERT(state != nullptr && "State tensor must be pre-allocated");
+    GGML_ASSERT(result != nullptr && "Result tensor must be pre-allocated");
 
-    // Create output tensor
-    ggml_tensor * result = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,
-                q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
-    ggml_set_input(result);
+    cb(state, "state", -1);
     cb(result, "result", -1);
 
     // Cast to F16 if needed for flash attention
@@ -1905,10 +1918,11 @@ ggml_tensor * llm_graph_context::build_attn_mha_with_state(
     cb(result_quant, "attn_result_quant", -1);
 
     cur = result;
-    // Reshape to final output format
-    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, seq_len);
+    // Reshape to final output format: [head_dim * n_heads, seq_len]
+    // cur has dimensions [head_dim, n_heads, seq_len, n_batch], so we flatten the first two dimensions
+    cur = ggml_reshape_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2]);
 
-    ggml_build_forward_expand(gf, cur);
+    ggml_build_forward_expand(gf, result);
 
     return cur;
 }
@@ -1957,7 +1971,8 @@ ggml_tensor * llm_graph_context::build_attn_mixed_with_state(
     // Use the segmented flash attention with state
     ggml_tensor * cur = build_attn_mha_with_state(
             gf, q_cur, k_fp16, v_fp16, k_quant, v_quant,
-            kq_b, kq_mask_fp16, kq_mask_quant, v_mla, kq_scale
+            kq_b, kq_mask_fp16, kq_mask_quant, v_mla,
+            inp->get_state(), inp->get_result(), kq_scale
     );
 
     cb(cur, "kqv_out", il);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 6a9245fadda53..17d20913c3936 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -313,11 +313,17 @@ class llm_graph_input_attn_kv_mixed : public llm_graph_input_i {
 
     ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
     ggml_tensor * get_kq_mask_quant() const { return self_kq_mask_quant_cnv; }
+    ggml_tensor * get_state() const { return attn_state; }
+    ggml_tensor * get_result() const { return attn_result; }
 
     ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
     ggml_tensor * self_kq_mask_quant = nullptr; // F32 [n_kv, n_batch]
     ggml_tensor * self_kq_mask_quant_cnv = nullptr; // [n_kv, n_batch]
+
+    // State and result tensors for stateful flash attention
+    ggml_tensor * attn_state = nullptr; // F32 [2, n_heads * seq_len] for [M, S] pairs
+    ggml_tensor * attn_result = nullptr; // F32 [head_dim, n_heads, seq_len, n_batch] output tensor
 
     const llama_hparams & hparams;
     const llama_cparams & cparams;
@@ -654,6 +660,8 @@ struct llm_graph_context {
             ggml_tensor * kq_mask_fp16,
             ggml_tensor * kq_mask_quant,
             ggml_tensor * v_mla,
+            ggml_tensor * state,
+            ggml_tensor * result,
             float kq_scale) const;
 
     llm_graph_input_attn_cross * build_attn_inp_cross() const;
diff --git a/src/llama-kv-cache-mixed.cpp b/src/llama-kv-cache-mixed.cpp
index 6c3a1d3486719..7edb0773430dd 100644
--- a/src/llama-kv-cache-mixed.cpp
+++ b/src/llama-kv-cache-mixed.cpp
@@ -1003,7 +1003,7 @@ bool llama_kv_cache_mixed::find_slot(const llama_ubatch & ubatch) {
     n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad))); //> Virtual head of kv cache.
     n_quantized = std::min(size, std::max(n_pad, GGML_PAD(cell_max_quantized(), n_pad))); //> Virtual head of quantized kv cache.
 
-    LLAMA_LOG_INFO("\n[mixed-kv] successfully allocated slot: head=%u, used=%u, n=%u, n_quantized=%u, cell_max=%u, cell_max_quantized=%u\n", head, used, n, n_quantized, cell_max(), cell_max_quantized());
+    // LLAMA_LOG_INFO("\n[mixed-kv] successfully allocated slot: head=%u, used=%u, n=%u, n_quantized=%u, cell_max=%u, cell_max_quantized=%u\n", head, used, n, n_quantized, cell_max(), cell_max_quantized());
 
     return true;
 }
@@ -1029,16 +1029,45 @@ void llama_kv_cache_mixed::set_input_kq_mask(ggml_tensor * dst, const llama_ubat
     const int64_t n_seq_tokens = ubatch->n_seq_tokens;
     const int64_t n_seqs = ubatch->n_seqs;
 
-    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+    // TODO: the early-return guards below let llama-cli run, but whenever one of them triggers the mask is simply not computed.
+
+    // Basic tensor validation
+    if (!dst) {
+        LLAMA_LOG_INFO("[mixed-kv] dst tensor is null in set_input_kq_mask, skipping\n");
+        return;
+    }
+
+    // Check if buffer and data are allocated
+    // This can happen during graph building phase before allocation
+    if (!dst->buffer || !dst->data) {
+        LLAMA_LOG_INFO("[mixed-kv] Buffer or data not allocated yet for %s mask, skipping mask setting\n",
+                is_quantized ? "quantized" : "FP16");
+        return;
+    }
+
+    // Verify buffer is host-accessible before proceeding
+    if (!ggml_backend_buffer_is_host(dst->buffer)) {
+        LLAMA_LOG_INFO("[mixed-kv] Buffer for %s mask is not host-accessible\n",
+                is_quantized ? "quantized" : "FP16");
+        return;
+    }
+
     float * data = (float *) dst->data;
 
+    // Final safety check for data pointer
+    if (!data) {
+        LLAMA_LOG_INFO("[mixed-kv] Data pointer is null for %s mask tensor\n",
+                is_quantized ? "quantized" : "FP16");
+        return;
+    }
+
     // Choose the correct KV length based on whether we're setting mask for quantized or FP16 part
     // - For FP16 mask (is_quantized=false): use n (covers recent tokens)
     // - For quantized mask (is_quantized=true): use n_quantized (covers older tokens)
     const int64_t n_kv = is_quantized ? n_quantized : n;
 
-    LLAMA_LOG_DEBUG("[mixed-kv] Setting %s mask: n_kv=%ld (n=%u, n_quantized=%u)\n",
-            is_quantized ? "quantized" : "FP16", n_kv, n, n_quantized);
+    LLAMA_LOG_DEBUG("[mixed-kv] Setting %s mask: n_kv=%lld (n=%u, n_quantized=%u)\n",
+            is_quantized ? "quantized" : "FP16", (long long)n_kv, n, n_quantized);
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -1057,53 +1086,53 @@ void llama_kv_cache_mixed::set_input_kq_mask(ggml_tensor * dst, const llama_ubat
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
             for (int j = 0; j < n_seq_tokens; ++j) {
-                // Current query token's position
                 const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
 
-                // Loop through all tokens in KV cache
                 for (int i = 0; i < n_kv; ++i) {
-                    // Current key token's position
-                    const llama_pos p0 = cells[i].pos; //> kv_cache idx.
-
-                    bool masked = false;
-
-                    // Rule 0: For mixed cache, check if cell belongs to the part we're masking
-                    // - For FP16 mask: only include non-quantized cells
-                    // - For quantized mask: only include quantized cells
+                    // Get position from appropriate cell range
+                    llama_pos p0;
                     if (is_quantized) {
-                        masked = masked || (!cells[i].is_quantized()); // Skip non-quantized cells for quantized mask
+                        // For quantized mask, use cells [0, n_quantized)
+                        p0 = (i < cells.size()) ? cells[i].pos : -1;
                     } else {
-                        masked = masked || (cells[i].is_quantized()); // Skip quantized cells for FP16 mask
+                        // For FP16 mask, use cells [n_quantized, n_quantized + n)
+                        uint32_t cell_idx = n_quantized + i;
+                        p0 = (cell_idx < cells.size()) ? cells[cell_idx].pos : -1;
                     }
 
-                    // Rule 1: If key token not in current query token's sequence, mask.
-                    masked = masked || (!cells[i].has_seq_id(seq_id)); //> This cell is not in the current query token's sequence.
+                    bool masked = false;
 
-                    // Rule 2: If causal attention and key token after query token (future), mask.
-                    masked = masked || (causal_attn && p0 > p1); //> p0 in SEQ_LEN > p1 in KV_LEN.
+                    // mask if invalid cell
+                    masked = masked || (p0 < 0);
+
+                    // mask the token if not the same sequence
+                    if (!masked && i < (int)cells.size()) {
+                        uint32_t cell_idx = is_quantized ? i : (n_quantized + i);
+                        if (cell_idx < cells.size()) {
+                            masked = masked || (!cells[cell_idx].has_seq_id(seq_id));
+                        }
+                    }
+
+                    // mask future tokens
+                    masked = masked || (causal_attn && p0 > p1);
 
                     float f = 0.0f;
 
                     if (masked) {
-                        // For masked tokens, set attention score to negative infinity
                         f = -INFINITY;
                     } else if (hparams.use_alibi) {
-                        // Rule 3: If using ALiBi, compute penalty based on query-key distance
                        f = -std::abs(p0 - p1);
                     }
 
-                    // Write computed mask value to destination tensor
                     data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                 }
             }
         }
 
-        // Rule 4: Mask padding tokens in batch (adapted for mixed KV cache)
-        if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (int j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                }
+        // mask padded tokens
+        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+            for (int j = 0; j < n_kv; ++j) {
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
             }
         }
     }
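
Note (illustrative, not part of the patch): the [M, S] pairs carried by the attn_state tensor are the running maximum and running exp-sum of an online softmax. Keeping one such pair per (head, query) row between the FP16-segment and quantized-segment flash-attention calls is what allows the two partial results to be merged into the same output as a single pass over the whole KV cache. The sketch below shows only that merge arithmetic for one query row; the names (attn_state, accumulate_segment) and the scalar layout are assumptions for illustration and do not reflect the actual ggml kernel.

// Illustrative only -- not the ggml implementation.
//   M = running max of attention logits seen so far
//   S = running sum of exp(logit - M)
// Calling accumulate_segment() once per KV segment while preserving (acc, M, S)
// gives the same row as attention over the concatenated cache.
#include <algorithm>
#include <cmath>
#include <vector>

struct attn_state {
    float M = -INFINITY; // running max of logits
    float S = 0.0f;      // running sum of exp(logit - M)
};

// Fold one KV segment into the accumulator for a single query row.
// logits[i] = scaled q.k_i plus mask (-INFINITY for masked keys);
// values[i] = the matching V row (head_dim floats).
static void accumulate_segment(std::vector<float> & acc, attn_state & st,
                               const std::vector<float> & logits,
                               const std::vector<std::vector<float>> & values) {
    for (size_t i = 0; i < logits.size(); ++i) {
        if (logits[i] == -INFINITY) {
            continue; // masked key contributes nothing
        }
        const float m_new = std::max(st.M, logits[i]);
        const float r = std::exp(st.M - m_new);      // rescales what was accumulated before
        const float w = std::exp(logits[i] - m_new); // weight of the new key
        for (size_t d = 0; d < acc.size(); ++d) {
            acc[d] = acc[d]*r + w*values[i][d];
        }
        st.S = st.S*r + w;
        st.M = m_new;
    }
}

// Usage: run accumulate_segment() over the FP16 segment, then over the quantized
// segment, and finally divide acc by st.S to obtain the attention output row.
// In the patch, one (M, S) pair per (head, query) position is what the
// [2, n_heads * seq_len] attn_state tensor carries between the two calls.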