File tree Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Expand file tree Collapse file tree 1 file changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -6721,8 +6721,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
67216721 ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu (k->type )->vec_dot ;
67226722 ggml_to_float_t const v_to_float = ggml_get_type_traits (v->type )->to_float ;
67236723
6724- GGML_ASSERT (q_to_vec_dot && " fattn: unsupported K-type" );
6725- GGML_ASSERT (v_to_float && " fattn: unsupported V-type" );
6724+ GGML_ASSERT ( q_to_vec_dot && " fattn: unsupported K-type" );
6725+ GGML_ASSERT (v-> type == GGML_TYPE_F32 || v_to_float && " fattn: unsupported V-type" );
67266726
67276727 // loop over n_batch and n_head
67286728 for (int ir = ir0; ir < ir1; ++ir) {
@@ -6818,10 +6818,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
68186818 vs = expf (s - M);
68196819 }
68206820
6821- v_to_float (v_data, V32, DV);
6822-
68236821 // V += v*expf(s - M)
6824- ggml_vec_mad_f32 (DV, VKQ32, V32, vs);
6822+ if (v_to_float) {
6823+ v_to_float (v_data, V32, DV);
6824+ ggml_vec_mad_f32 (DV, VKQ32, V32, vs);
6825+ } else {
6826+ // V is F32
6827+ ggml_vec_mad_f32 (DV, VKQ32, (const float *) v_data, vs);
6828+ }
68256829 }
68266830
68276831 S = S*ms + vs; // scale and increment sum with partial sum
You can’t perform that action at this time.
0 commit comments