2 files changed: +14 -5 lines changed
Original file line number | Diff line number | Diff line change
@@ -6769,8 +6769,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
67696769 ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu (k->type )->vec_dot ;
67706770 ggml_to_float_t const v_to_float = ggml_get_type_traits (v->type )->to_float ;
67716771
6772- GGML_ASSERT (q_to_vec_dot && " fattn: unsupported K-type" );
6773- GGML_ASSERT (v_to_float && " fattn: unsupported V-type" );
6772+ GGML_ASSERT (( q_to_vec_dot) && " fattn: unsupported K-type" );
6773+ GGML_ASSERT ((v-> type == GGML_TYPE_F32 || v_to_float ) && " fattn: unsupported V-type" );
67746774
67756775 // loop over n_batch and n_head
67766776 for (int ir = ir0; ir < ir1; ++ir) {
@@ -6866,10 +6866,14 @@ static void ggml_compute_forward_flash_attn_ext_f16(
68666866 vs = expf (s - M);
68676867 }
68686868
6869- v_to_float (v_data, V32, DV);
6870-
68716869 // V += v*expf(s - M)
6872- ggml_vec_mad_f32 (DV, VKQ32, V32, vs);
6870+ if (v_to_float) {
6871+ v_to_float (v_data, V32, DV);
6872+ ggml_vec_mad_f32 (DV, VKQ32, V32, vs);
6873+ } else {
6874+ // V is F32
6875+ ggml_vec_mad_f32 (DV, VKQ32, (const float *) v_data, vs);
6876+ }
68736877 }
68746878
68756879 S = S*ms + vs; // scale and increment sum with partial sum
Original file line number | Diff line number | Diff line change
@@ -1346,6 +1346,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
13461346 case GGML_OP_ARANGE:
13471347 return true ;
13481348 case GGML_OP_FLASH_ATTN_EXT:
1349+ if (op->src [0 ]->ne [0 ] == 32 ) {
1350+ // head size == 32 (e.g. bert-bge-small)
1351+ // TODO: not sure if it is worth adding kernels for this size
1352+ return false ;
1353+ }
13491354 if (op->src [1 ]->type != op->src [2 ]->type ) {
13501355 return false ;
13511356 }
You can’t perform that action at this time.
0 commit comments