diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index c3c8e954b12c6..1b6a3f211e018 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2222,19 +2222,6 @@ extern "C" { GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads); GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1); - // Enhanced flash attention with state tensor for S/M values - // s_m_state: [2, n_heads * q_len] tensor containing [M, S] pairs for each head/position - GGML_API struct ggml_tensor * ggml_flash_attn_ext_with_state( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor * mask, - struct ggml_tensor * s_m_state, // State tensor for S and M values - float scale, - float max_bias, - float logit_softcap); - #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index a614b2001bf64..3a8547729ba23 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -656,6 +656,7 @@ static void ggml_compute_forward_dup_bf16( GGML_ABORT("fatal error"); // TODO: implement } } + static void ggml_compute_forward_dup_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -1215,7 +1216,7 @@ static void ggml_compute_forward_add_q_f32( GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); - const int nr = ggml_nrows(src1); + const int nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS @@ -1425,6 +1426,7 @@ static void ggml_compute_forward_add1_f16_f32( } } } + static void ggml_compute_forward_add1_f16_f16( const ggml_compute_params * params, ggml_tensor * dst) { @@ -2189,7 +2191,9 @@ void ggml_compute_forward_count_equal( } } } + // ggml_compute_forward_repeat + static void ggml_compute_forward_repeat_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -2976,6 +2980,7 @@ static void ggml_compute_forward_silu_f16( #endif } } + static void ggml_compute_forward_silu( const ggml_compute_params * params, ggml_tensor * dst) { @@ -2997,6 +3002,7 @@ static void ggml_compute_forward_silu( } } } +// ggml_compute_forward_leaky_relu static void ggml_compute_forward_leaky_relu_f32( const ggml_compute_params * params, @@ -3080,6 +3086,8 @@ void ggml_compute_forward_leaky_relu( } } +// ggml_compute_forward_silu_back + static void ggml_compute_forward_silu_back_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -3189,6 +3197,8 @@ void ggml_compute_forward_silu_back( } } +// ggml_compute_forward_norm + static void ggml_compute_forward_norm_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -3258,6 +3268,8 @@ void ggml_compute_forward_norm( } } +// ggml_compute_forward_group_rms_norm + static void ggml_compute_forward_rms_norm_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -3429,12 +3441,13 @@ static void ggml_compute_forward_rms_norm_back_f32( // grad[#02] = repeat(scale(grad[#07],#04), #02) // grad[#02] = repeat(scale(mul(grad[#08], div(0.5, #08)),#04), #02) // grad[#02] = repeat(scale(mul(neg(mul(grad[#09], div(#09,#08))), div(0.5, #08)),#04), #02) - // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00))), div(#09,#08)), div(0.5, #08)),#04), #02) + // grad[#02] = repeat(scale(mul(neg(mul(sum(mul(grad[#10],#00)), div(#09,#08))), div(0.5, #08)),#04), #02) // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(#09,#08) * div(0.5, #08) * (1/N)), #02) // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(div(#01,#08),#08) * div(0.5, #08) * (1/N)), #02) // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#08*#08) * div(0.5, #08) * (1/N)), #02) // grad[#02] = repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02) // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, grad[#02]), 2.0) + // grad[#00] = scale(grad(#10), #09) + scale(mul(#00, repeat(-(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N)), #02)), 2.0) // grad[#00] = scale(grad(#10), #09) + scale(scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(0.5, #08) * (1/N))), 2.0) // grad[#00] = scale(grad(#10), #09) + scale(#00, -(sum(mul(grad[#10],#00)) * div(1,#07) * div(1,#08) * (1/N))) // grad[#00] = scale(grad(#10), #09) + scale(#00, sum(mul(grad[#10],#00)) * div(1,#07*#08) * (-1/N)) @@ -3498,6 +3511,8 @@ void ggml_compute_forward_rms_norm_back( } } +// ggml_compute_forward_group_norm + static void ggml_compute_forward_group_norm_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -3591,6 +3606,8 @@ void ggml_compute_forward_group_norm( } } +// ggml_compute_forward_l2_norm + static void ggml_compute_forward_l2_norm_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -3652,6 +3669,8 @@ void ggml_compute_forward_l2_norm( } } +// ggml_compute_forward_out_prod + static void ggml_compute_forward_out_prod_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -4521,6 +4540,7 @@ static void ggml_compute_forward_get_rows_back_f32( (float *) ((char *) src0->data + i*src0->nb[1])); } } + void ggml_compute_forward_get_rows_back( const ggml_compute_params * params, ggml_tensor * dst) { @@ -4888,6 +4908,8 @@ static void ggml_compute_forward_soft_max_ext_back_f32( // dxk = sum_i(-yk*yi * dyi) + yk*dyk // dxk = -yk * sum_i(yi * dyi) + yk*dyk // dxk = -yk * dot(y, dy) + yk*dyk + // dxk = yk * (- dot(y, dy) + dyk) + // dxk = yk * (dyk - dot(y, dy)) // // post-order: // dot_y_dy := dot(y, dy) @@ -5166,6 +5188,7 @@ static void ggml_mrope_cache_init( theta_e *= theta_scale; } } + static void ggml_compute_forward_rope_f32( const ggml_compute_params * params, ggml_tensor * dst, @@ -5953,7 +5976,9 @@ void ggml_compute_forward_im2col( } } } + // ggml_compute_forward_im2col_back_f32 + void ggml_compute_forward_im2col_back_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -6746,7 +6771,9 @@ void ggml_compute_forward_pad( } } } + // ggml_compute_forward_pad_reflect_1d + void ggml_compute_forward_pad_reflect_1d( const ggml_compute_params * params, ggml_tensor * dst) { @@ -7156,251 +7183,6 @@ static void ggml_compute_forward_flash_attn_ext_f16( } } -static void ggml_compute_forward_flash_attn_ext_f16_with_state( - const ggml_compute_params * params, - const ggml_tensor * q, - const ggml_tensor * k, - const ggml_tensor * v, - const ggml_tensor * mask, - const ggml_tensor * state, - ggml_tensor * dst) { - - GGML_TENSOR_LOCALS(int64_t, neq, q, ne) - GGML_TENSOR_LOCALS(size_t, nbq, q, nb) - GGML_TENSOR_LOCALS(int64_t, nek, k, ne) - GGML_TENSOR_LOCALS(size_t, nbk, k, nb) - GGML_TENSOR_LOCALS(int64_t, nev, v, ne) - GGML_TENSOR_LOCALS(size_t, nbv, v, nb) - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - - // Validate state tensor format: [2, n_heads * q_len] - GGML_ASSERT(state != NULL); - GGML_ASSERT(state->ne[0] == 2); // [M, S] pairs - GGML_ASSERT(state->ne[1] == neq2 * neq1); // n_heads * q_len - GGML_ASSERT(state->type == GGML_TYPE_F32); - - const int ith = params->ith; - const int nth = params->nth; - - const int64_t DK = nek0; //> head_dim - const int64_t DV = nev0; //> head_dim - const int64_t N = neq1; //> q_len - - GGML_ASSERT(ne0 == DV); //> dst -> ne[0] == head_dim - GGML_ASSERT(ne2 == N); //> dst -> ne[2] == q_len - - // input tensor rows must be contiguous - //> QKV cannot do transpose. - GGML_ASSERT(nbq0 == ggml_type_size(q->type)); - GGML_ASSERT(nbk0 == ggml_type_size(k->type)); - GGML_ASSERT(nbv0 == ggml_type_size(v->type)); - - //> V donot transpose before. - GGML_ASSERT(neq0 == DK); //> q -> ne[0] == head_dim - GGML_ASSERT(nek0 == DK); //> k -> ne[0] == head_dim - GGML_ASSERT(nev0 == DV); //> v -> ne[0] == head_dim - - GGML_ASSERT(neq1 == N); //> q -> ne[1] == q_len - - // dst cannot be transposed or permuted - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb0 <= nb1); - GGML_ASSERT(nb1 <= nb2); - GGML_ASSERT(nb2 <= nb3); - - // broadcast factors - const int64_t rk2 = neq2/nek2; //> n_q_head / n_kv_head - const int64_t rk3 = neq3/nek3; //> n_q_batch / n_kv_batch - - const int64_t rv2 = neq2/nev2; //> n_q_head / n_v_head - const int64_t rv3 = neq3/nev3; //> n_q_batch / n_v_batch - - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int nr = neq1*neq2*neq3; //> number of rows, one row is one head_dim. - - // NOTE: Parallelize by q rows. - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - - const uint32_t n_head = neq2; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - - ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type; - ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float; - ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot; - ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float; - - GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type"); - GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type"); - - // loop over n_batch and n_head - for (int ir = ir0; ir < ir1; ++ir) { - // q indices - const int iq3 = ir/(neq2*neq1); - const int iq2 = (ir - iq3*neq2*neq1)/neq1; - const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); - - const uint32_t h = iq2; // head index - const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; - - // Calculate state tensor offset for this head/position - const int64_t state_idx = iq2 * neq1 + iq1; // head * q_len + position - float * state_data = (float *)state->data; - - // Read initial S and M values from state tensor - // State format: [M, S] for each head/position - float S = state_data[state_idx * 2 + 1]; // sum (index 1) - float M = state_data[state_idx * 2 + 0]; // maximum KQ value (index 0) - - // If this is the first call (indicated by M == -INFINITY), initialize properly - if (M == -INFINITY) { - S = 0.0f; - } - - float * VKQ32 = (float *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator - float * V32 = (VKQ32 + 1*DV); // (temporary) FP32 V buffer - ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator - ggml_fp16_t * Q_q = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16 - - if (v->type == GGML_TYPE_F16) { - memset(VKQ16, 0, DV*sizeof(ggml_fp16_t)); - } else { - memset(VKQ32, 0, DV*sizeof(float)); - } - - const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; - - // k indices - const int ik3 = iq3 / rk3; - const int ik2 = iq2 / rk2; - - // v indices - const int iv3 = iq3 / rv3; - const int iv2 = iq2 / rv2; - - const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); - q_to_vec_dot(pq, Q_q, DK); - - // online softmax / attention - // loop over n_kv and n_head_kv - // ref: https://arxiv.org/pdf/2112.05682.pdf - for (int64_t ic = 0; ic < nek1; ++ic) { - const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f; - if (mv == -INFINITY) { - continue; - } - - float s; // KQ value - - //> k_data: [head_dim, kv_len, n_kv_head, n_kv_batch] - const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3); - kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1); - - s = s*scale; // scale KQ value - - if (logit_softcap != 0.0f) { - s = logit_softcap*tanhf(s); - } - - s += mv; // apply mask - - const float Mold = M; - - float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value - float vs = 1.0f; // post-softmax KQ value, expf(s - M) - - const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); - - if (v->type == GGML_TYPE_F16) { - if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f - M = s; - ms = expf(Mold - M); - - // V = V*expf(Mold - M) - ggml_vec_scale_f16(DV, VKQ16, ms); - } else { - // no new maximum, ms == 1.0f, vs != 1.0f - vs = expf(s - M); - } - - // V += v*expf(s - M) - //> VKQ16 = VKQ16 + v_data * expf(s - M) - ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs); - } else { - if (s > M) { - // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f - M = s; - ms = expf(Mold - M); - - // V = V*expf(Mold - M) - ggml_vec_scale_f32(DV, VKQ32, ms); - } else { - // no new maximum, ms == 1.0f, vs != 1.0f - vs = expf(s - M); - } - - // V += v*expf(s - M) - if (v_to_float) { - v_to_float(v_data, V32, DV); - ggml_vec_mad_f32(DV, VKQ32, V32, vs); - } else { - // V is F32 - ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs); - } - } - - S = S*ms + vs; // scale and increment sum with partial sum - } - - // Write updated S and M values back to state tensor - state_data[state_idx * 2 + 0] = M; // maximum KQ value (index 0) - state_data[state_idx * 2 + 1] = S; // sum (index 1) - - if (v->type == GGML_TYPE_F16) { - for (int64_t d = 0; d < DV; ++d) { - VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]); - } - } - - // V /= S - const float S_inv = 1.0f / S; - ggml_vec_scale_f32(DV, VKQ32, S_inv); - - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; - - // original - // memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); - - // permute(0, 2, 1, 3) - memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); - } -} void ggml_compute_forward_flash_attn_ext_mixed( const ggml_compute_params * params, const ggml_tensor * q, @@ -7754,14 +7536,7 @@ void ggml_compute_forward_flash_attn_ext( case GGML_PREC_F32: { // uses F32 accumulators - // Check if we have additional sources beyond the required ones for state tensor - if (dst->src[6] != nullptr) { - // State tensor is provided as src[6] - use enhanced function with S/M state - ggml_compute_forward_flash_attn_ext_f16_with_state(params, q, k, v, mask, dst->src[6], dst); - } else { - // Standard function without state tensor - ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); - } + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); } break; case GGML_PREC_MIXED: { @@ -7773,6 +7548,7 @@ void ggml_compute_forward_flash_attn_ext( } } } + // ggml_compute_forward_flash_attn_back static void ggml_compute_forward_flash_attn_back_f32( @@ -8178,7 +7954,9 @@ void ggml_compute_forward_ssm_conv( } } } + // ggml_compute_forward_ssm_scan + static void ggml_compute_forward_ssm_scan_f32( const ggml_compute_params * params, ggml_tensor * dst) { @@ -8515,6 +8293,7 @@ void ggml_compute_forward_get_rel_pos( } } } + // ggml_compute_forward_add_rel_pos static void ggml_compute_forward_add_rel_pos_f32( @@ -8969,6 +8748,8 @@ static void ggml_compute_forward_gla_f32( } #endif } + + void ggml_compute_forward_gla( const ggml_compute_params * params, ggml_tensor * dst) { @@ -9311,6 +9092,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( dp[0] *= -1.0f / (float) nr; } } + void ggml_compute_forward_cross_entropy_loss( const ggml_compute_params * params, ggml_tensor * dst) { @@ -9494,4 +9276,4 @@ void ggml_compute_forward_opt_step_adamw( GGML_ABORT("fatal error"); } } -} \ No newline at end of file +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 39e78d7052ac2..2c87371c6dc75 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1261,6 +1261,7 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) { nbytes += (tensor->ne[i] - 1)*tensor->nb[i]; } } + #ifdef GGML_USE_TMAC if (tensor->type == GGML_TYPE_TMAC_BN_0) { // One scale will not exceed one alignment boundary, so we can just add one alignment to the size. @@ -1902,6 +1903,7 @@ struct ggml_tensor * ggml_get_next_tensor(const struct ggml_context * ctx, struc return NULL; } + struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { struct ggml_object * obj = ctx->objects_begin; @@ -2547,6 +2549,7 @@ struct ggml_tensor * ggml_elu_inplace( struct ggml_tensor * a) { return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_ELU); } + // ggml_relu struct ggml_tensor * ggml_relu( @@ -3186,6 +3189,7 @@ struct ggml_tensor * ggml_reshape( return result; } + struct ggml_tensor * ggml_reshape_1d( struct ggml_context * ctx, struct ggml_tensor * a, @@ -3828,6 +3832,7 @@ struct ggml_tensor * ggml_rope_custom( ext_factor, attn_factor, beta_fast, beta_slow, false ); } + struct ggml_tensor * ggml_rope_custom_inplace( struct ggml_context * ctx, struct ggml_tensor * a, @@ -4467,6 +4472,7 @@ struct ggml_tensor * ggml_timestep_embedding( return result; } + // ggml_argsort struct ggml_tensor * ggml_argsort( @@ -4590,57 +4596,6 @@ struct ggml_tensor * ggml_flash_attn_mixed( return result; } -struct ggml_tensor * ggml_flash_attn_ext_with_state( - struct ggml_context * ctx, - struct ggml_tensor * q, - struct ggml_tensor * k, - struct ggml_tensor * v, - struct ggml_tensor * mask, - struct ggml_tensor * s_m_state, - float scale, - float max_bias, - float logit_softcap) { - GGML_ASSERT(ggml_can_mul_mat(k, q)); - // TODO: check if vT can be multiplied by (k*qT) - - if (mask) { - GGML_ASSERT(ggml_is_contiguous(mask)); - GGML_ASSERT(mask->ne[2] == 1); - GGML_ASSERT(mask->ne[3] == 1); - GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && - "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); - //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); - } - - if (max_bias > 0.0f) { - GGML_ASSERT(mask); - } - - // Validate state tensor format: [2, n_heads * q_len] - GGML_ASSERT(s_m_state != NULL); - GGML_ASSERT(s_m_state->ne[0] == 2); // [M, S] pairs - GGML_ASSERT(s_m_state->ne[1] == q->ne[2] * q->ne[1]); // n_heads * q_len - GGML_ASSERT(s_m_state->type == GGML_TYPE_F32); - - // permute(0, 2, 1, 3) - int64_t ne[4] = { v->ne[0], q->ne[2], q->ne[1], q->ne[3] }; - struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); - - float params[] = { scale, max_bias, logit_softcap }; - ggml_set_op_params(result, params, sizeof(params)); - - result->op = GGML_OP_FLASH_ATTN_EXT; - result->src[0] = q; - result->src[1] = k; - result->src[2] = v; - result->src[3] = mask; - result->src[4] = NULL; // k_quant not used in this variant - result->src[5] = NULL; // v_quant not used in this variant - result->src[6] = s_m_state; // State tensor for S and M values - - return result; -} - void ggml_flash_attn_ext_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { @@ -5164,6 +5119,7 @@ static struct ggml_tensor * ggml_map_custom2_impl( return result; } + struct ggml_tensor * ggml_map_custom2( struct ggml_context * ctx, struct ggml_tensor * a, @@ -5500,6 +5456,7 @@ static void ggml_sub_or_set( ggml_format_name(cgraph->grads[isrc], "grad for %s", src->name); ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } + static void ggml_compute_backward( struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, const bool * grads_needed) { struct ggml_tensor * tensor = cgraph->nodes[i]; @@ -6151,6 +6108,7 @@ size_t ggml_graph_overhead_custom(size_t size, bool grads) { size_t ggml_graph_overhead(void) { return ggml_graph_overhead_custom(GGML_DEFAULT_GRAPH_SIZE, false); } + struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads) { const size_t obj_size = ggml_graph_nbytes(size, grads); struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_TYPE_GRAPH, obj_size); @@ -6747,4 +6705,4 @@ bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, cons if (p0->poll != p1->poll ) return false; if (p0->strict_cpu != p1->strict_cpu ) return false; return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; -} \ No newline at end of file +} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index bff92ccbdba1e..18a1c8f05dd49 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -229,7 +229,6 @@ if (NOT GGML_BACKEND_DL) llama_build_and_test(test-mul-mat.cpp) llama_build_and_test(test-flash-attn.cpp) llama_build_and_test(test-flash-decoding-custom-op.cpp) - llama_build_and_test(test-flash-attn-state.cpp) llama_build_and_test(test_ggml_mul_mat.cpp) endif() diff --git a/tests/test-flash-attn-state.cpp b/tests/test-flash-attn-state.cpp deleted file mode 100644 index 7d1be7f02551f..0000000000000 --- a/tests/test-flash-attn-state.cpp +++ /dev/null @@ -1,411 +0,0 @@ -#include "ggml.h" -#include "ggml-cpu.h" -#include "../ggml/src/ggml-impl.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// Use fixed seed for reproducible results -static std::mt19937 g_rng(42); - -static void fill_tensor_f32(ggml_tensor * dst, float min_val = -1.0f, float max_val = 1.0f) { - float* data = (float*)dst->data; - size_t n_elements = ggml_nelements(dst); - std::uniform_real_distribution dis(min_val, max_val); - - for (size_t i = 0; i < n_elements; i++) { - data[i] = dis(g_rng); - } -} - -static void fill_tensor_f16(ggml_tensor * dst, float min_val = -1.0f, float max_val = 1.0f) { - ggml_fp16_t* data = (ggml_fp16_t*)dst->data; - size_t n_elements = ggml_nelements(dst); - std::uniform_real_distribution dis(min_val, max_val); - - for (size_t i = 0; i < n_elements; i++) { - data[i] = ggml_fp32_to_fp16(dis(g_rng)); - } -} - -static void print_tensor_info(const char* name, ggml_tensor* tensor) { - printf("%s: [%ld, %ld, %ld, %ld] type=%s, elements=%ld\n", - name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], - ggml_type_name(tensor->type), ggml_nelements(tensor)); -} - -static void print_f32_sample(const char* name, ggml_tensor* tensor, int max_elements = 10) { - if (tensor->type != GGML_TYPE_F32) { - printf("%s: Not F32 tensor (type=%s)\n", name, ggml_type_name(tensor->type)); - return; - } - - float* data = (float*)tensor->data; - size_t n_elements = ggml_nelements(tensor); - size_t elements_to_print = std::min((size_t)max_elements, n_elements); - - printf("%s sample values: ", name); - for (size_t i = 0; i < elements_to_print; i++) { - printf("%.6f ", data[i]); - } - if (elements_to_print < n_elements) { - printf("... (total %ld elements)", n_elements); - } - printf("\n"); -} - -static float tensor_max_diff(ggml_tensor* a, ggml_tensor* b) { - if (ggml_nelements(a) != ggml_nelements(b) || a->type != b->type) { - printf("ERROR: Tensors have different sizes or types\n"); - return -1.0f; - } - - if (a->type != GGML_TYPE_F32) { - printf("ERROR: Only F32 tensors supported for comparison\n"); - return -1.0f; - } - - float* data_a = (float*)a->data; - float* data_b = (float*)b->data; - size_t n_elements = ggml_nelements(a); - - float max_diff = 0.0f; - for (size_t i = 0; i < n_elements; i++) { - float diff = std::abs(data_a[i] - data_b[i]); - max_diff = std::max(max_diff, diff); - } - - return max_diff; -} - -static void reset_state_tensor(ggml_tensor* state) { - float* state_data = (float*)state->data; - size_t n_pairs = ggml_nelements(state) / 2; - - for (size_t i = 0; i < n_pairs; i++) { - state_data[i * 2 + 0] = -INFINITY; // M (max KQ value) - state_data[i * 2 + 1] = 0.0f; // S (sum) - } -} - -int main() { - printf("=== Flash Attention State Tensor - Comprehensive Test ===\n"); - - // Test parameters - const int head_dim = 32; - const int n_heads = 8; - const int n_kv_heads = 4; - const int seq_len = 2; - const int kv_len = 4; // Will be split into segments - const int n_threads = 4; - const int kv_segments = 2; // Split KV into 2 segments - const int kv_segment_len = kv_len / kv_segments; - - printf("Test Parameters:\n"); - printf(" head_dim=%d, n_heads=%d, n_kv_heads=%d\n", head_dim, n_heads, n_kv_heads); - printf(" seq_len=%d, kv_len=%d\n", seq_len, kv_len); - printf(" kv_segments=%d, kv_segment_len=%d\n", kv_segments, kv_segment_len); - - // Initialize ggml context - const size_t ctx_size = 1024*1024*1024; // 1GB - struct ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ false, - }; - - struct ggml_context * ctx = ggml_init(params); - if (!ctx) { - fprintf(stderr, "Failed to initialize ggml context\n"); - return 1; - } - - // ============================================================================ - // Create and initialize tensors with FIXED data - // ============================================================================ - printf("\n--- Creating Fixed Test Data ---\n"); - - // Create tensors for flash attention - // Format: [head_dim, seq_len, n_heads, 1] for Q - // Format: [head_dim, kv_len, n_kv_heads, 1] for K, V - ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, seq_len, n_heads, 1); - ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, kv_len, n_kv_heads, 1); - ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, head_dim, kv_len, n_kv_heads, 1); - - // Create mask tensor with proper padding - const int padded_kv_len = GGML_PAD(kv_len, 64); - const int padded_seq_len = GGML_PAD(seq_len, GGML_KQ_MASK_PAD); - ggml_tensor * mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, padded_kv_len, padded_seq_len); - - // Create state tensor: [2, n_heads * seq_len] for [M, S] pairs - ggml_tensor * state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2, n_heads * seq_len); - - print_tensor_info("Q", q); - print_tensor_info("K", k); - print_tensor_info("V", v); - print_tensor_info("Mask", mask); - print_tensor_info("State", state); - - // Fill with FIXED reproducible data - printf("\nGenerating fixed test data (seed=42)...\n"); - fill_tensor_f32(q, -0.8f, 0.8f); - fill_tensor_f16(k, -0.6f, 0.6f); - fill_tensor_f16(v, -0.7f, 0.7f); - - // Initialize mask (no causal mask - all positions can see all KV) - ggml_fp16_t* mask_data = (ggml_fp16_t*)mask->data; - memset(mask_data, 0, ggml_nbytes(mask)); - for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < kv_len; j++) { - // No masking - all positions can see all KV tokens - mask_data[i * padded_kv_len + j] = ggml_fp32_to_fp16(0.0f); - } - } - - printf("Fixed test data generated successfully\n"); - - // ============================================================================ - // Test 1: Standard Flash Attention (Reference Result) - // ============================================================================ - printf("\n--- Test 1: Standard Flash Attention (Reference) ---\n"); - - ggml_tensor * result_standard = ggml_flash_attn_ext( - ctx, q, k, v, mask, - 1.0f / std::sqrt(head_dim), // scale - 0.0f, // max_bias - 0.0f // logit_softcap - ); - ggml_flash_attn_ext_set_prec(result_standard, GGML_PREC_F32); - - if (!result_standard) { - printf("ERROR: Failed to create standard flash attention operation\n"); - ggml_free(ctx); - return 1; - } - - struct ggml_cgraph * graph_standard = ggml_new_graph(ctx); - ggml_build_forward_expand(graph_standard, result_standard); - - printf("Computing standard flash attention...\n"); - enum ggml_status status_standard = ggml_graph_compute_with_ctx(ctx, graph_standard, n_threads); - - if (status_standard != GGML_STATUS_SUCCESS) { - printf("ERROR: Standard flash attention failed with status: %d\n", status_standard); - ggml_free(ctx); - return 1; - } - - printf("Standard flash attention computation successful\n"); - print_f32_sample("Standard result", result_standard, 8); - - // ============================================================================ - // Test 2: Segmented Flash Attention with State Accumulation - // ============================================================================ - printf("\n--- Test 2: Segmented Flash Attention with State ---\n"); - - // Reset state tensor - reset_state_tensor(state); - - // Create result tensor for accumulation (same shape as standard result) - ggml_tensor * result_segmented = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, - head_dim, seq_len, n_heads, 1); - - // Initialize segmented result to zero - memset(result_segmented->data, 0, ggml_nbytes(result_segmented)); - - printf("Processing %d segments of KV cache (segment_len=%d)...\n", kv_segments, kv_segment_len); - - for (int seg = 0; seg < kv_segments; seg++) { - printf("\n Segment %d/%d (kv_pos %d-%d):\n", - seg + 1, kv_segments, seg * kv_segment_len, (seg + 1) * kv_segment_len - 1); - - // Print state before this segment - printf(" State before segment %d: ", seg + 1); - float* state_data = (float*)state->data; - for (int i = 0; i < std::min(4, n_heads * seq_len); i++) { - printf("[M=%.3f,S=%.3f] ", state_data[i * 2 + 0], state_data[i * 2 + 1]); - } - printf("...\n"); - - // Create views of K and V for this segment using ggml_view_4d - ggml_tensor * k_segment = ggml_view_4d(ctx, k, - head_dim, kv_segment_len, n_kv_heads, 1, // ne - k->nb[1], k->nb[2], k->nb[3], // nb (strides) - seg * kv_segment_len * k->nb[1]); // offset - - ggml_tensor * v_segment = ggml_view_4d(ctx, v, - head_dim, kv_segment_len, n_kv_heads, 1, // ne - v->nb[1], v->nb[2], v->nb[3], // nb (strides) - seg * kv_segment_len * v->nb[1]); // offset - - // Create mask for this segment - const int padded_segment_len = GGML_PAD(kv_segment_len, 64); - ggml_tensor * mask_segment = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, - padded_segment_len, padded_seq_len); - - // Fill segment mask - ggml_fp16_t* mask_seg_data = (ggml_fp16_t*)mask_segment->data; - memset(mask_seg_data, 0, ggml_nbytes(mask_segment)); - - for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < kv_segment_len; j++) { - int global_j = seg * kv_segment_len + j; - // No masking for segment - all positions can see all KV tokens in this segment - mask_seg_data[i * padded_segment_len + j] = ggml_fp32_to_fp16(0.0f); - } - } - - // Debug: Print mask information for first segment - if (seg == 0) { - printf(" Debug - Global mask (first 4 seq positions, first 20 kv positions):\n"); - for (int i = 0; i < std::min(4, seq_len); i++) { - printf(" seq[%d]: ", i); - for (int j = 0; j < std::min(20, kv_len); j++) { - float mask_val = GGML_FP16_TO_FP32(mask_data[i * padded_kv_len + j]); - printf("%.0f ", mask_val == -INFINITY ? -1.0f : mask_val); - } - printf("...\n"); - } - - printf(" Debug - Segment mask (first 4 seq positions, all segment positions):\n"); - for (int i = 0; i < std::min(4, seq_len); i++) { - printf(" seq[%d]: ", i); - for (int j = 0; j < kv_segment_len; j++) { - float mask_val = GGML_FP16_TO_FP32(mask_seg_data[i * padded_segment_len + j]); - printf("%.0f ", mask_val == -INFINITY ? -1.0f : mask_val); - } - printf("\n"); - } - } - - print_tensor_info(" K segment", k_segment); - print_tensor_info(" V segment", v_segment); - - // Compute flash attention with state for this segment - ggml_tensor * result_seg = ggml_flash_attn_ext_with_state( - ctx, q, k_segment, v_segment, mask_segment, state, - 1.0f / std::sqrt(head_dim), // scale - 0.0f, // max_bias - 0.0f // logit_softcap - ); - ggml_flash_attn_ext_set_prec(result_seg, GGML_PREC_F32); - - if (!result_seg) { - printf("ERROR: Failed to create segmented flash attention for segment %d\n", seg); - ggml_free(ctx); - return 1; - } - - struct ggml_cgraph * graph_seg = ggml_new_graph(ctx); - ggml_build_forward_expand(graph_seg, result_seg); - - enum ggml_status status_seg = ggml_graph_compute_with_ctx(ctx, graph_seg, n_threads); - - if (status_seg != GGML_STATUS_SUCCESS) { - printf("ERROR: Segmented flash attention failed for segment %d with status: %d\n", seg, status_seg); - ggml_free(ctx); - return 1; - } - - printf(" Segment %d computed successfully\n", seg + 1); - print_f32_sample(" Segment result", result_seg, 6); - - // Print state after this segment - printf(" State after segment %d: ", seg + 1); - for (int i = 0; i < std::min(4, n_heads * seq_len); i++) { - printf("[M=%.3f,S=%.3f] ", state_data[i * 2 + 0], state_data[i * 2 + 1]); - } - printf("...\n"); - - // For the final segment, copy the result (this contains the accumulated result of all segments) - if (seg == kv_segments - 1) { - memcpy(result_segmented->data, result_seg->data, ggml_nbytes(result_seg)); - printf(" Final accumulated result copied from segment %d\n", seg + 1); - } - } - - printf("\nSegmented computation completed\n"); - print_f32_sample("Final segmented result", result_segmented, 8); - - // ============================================================================ - // Test 3: Compare Results - // ============================================================================ - printf("\n--- Test 3: Comparing Results ---\n"); - - float max_diff = tensor_max_diff(result_standard, result_segmented); - - printf("Comparison between standard and segmented results:\n"); - printf(" Maximum absolute difference: %.2e\n", max_diff); - - const float tolerance = 1e-4; // Reasonable tolerance for F16/F32 precision - - if (max_diff < tolerance) { - printf(" ✅ PASS: Results match within tolerance (%.2e)\n", tolerance); - } else { - printf(" ❌ FAIL: Results differ beyond tolerance (%.2e)\n", tolerance); - - // Print detailed comparison for debugging - printf("\nDetailed comparison:\n"); - print_f32_sample("Standard", result_standard, 20); - print_f32_sample("Segmented", result_segmented, 20); - } - - // ============================================================================ - // Test 4: State Tensor Analysis - // ============================================================================ - printf("\n--- Test 4: State Tensor Analysis ---\n"); - - printf("Final state tensor values:\n"); - print_f32_sample("Final state", state, 16); - - float* state_data = (float*)state->data; - float min_m = INFINITY, max_m = -INFINITY; - float min_s = INFINITY, max_s = -INFINITY; - - for (int i = 0; i < n_heads * seq_len; i++) { - float m_val = state_data[i * 2 + 0]; - float s_val = state_data[i * 2 + 1]; - - if (m_val != -INFINITY) { - min_m = std::min(min_m, m_val); - max_m = std::max(max_m, m_val); - } - - min_s = std::min(min_s, s_val); - max_s = std::max(max_s, s_val); - } - - printf("State tensor statistics:\n"); - printf(" M values: min=%.6f, max=%.6f\n", min_m, max_m); - printf(" S values: min=%.6f, max=%.6f\n", min_s, max_s); - - // ============================================================================ - // Final Results - // ============================================================================ - printf("\n=== Final Test Results ===\n"); - - if (max_diff < tolerance) { - printf("🎉 ALL TESTS PASSED!\n"); - printf("✅ Segmented flash attention with state produces identical results\n"); - printf("✅ State tensor correctly accumulates across segments\n"); - printf("✅ Implementation is working correctly\n"); - } else { - printf("❌ TESTS FAILED!\n"); - printf("❌ Results differ beyond acceptable tolerance\n"); - printf("❌ Implementation needs debugging\n"); - } - - printf("\nMax difference: %.2e (tolerance: %.2e)\n", max_diff, tolerance); - - // Cleanup - ggml_free(ctx); - return (max_diff < tolerance) ? 0 : 1; -} \ No newline at end of file