@@ -102,4 +102,50 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
102102 on_no_fattn_vec_case (Q->ne [0 ], V->ne [0 ]);
103103}
104104
105-
105+ bool ggml_cuda_fattn_vec_f16_is_supported ([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
106+ auto K = dst->src [1 ];
107+ auto V = dst->src [2 ];
108+ if (K->ne [0 ] != V->ne [0 ]) {
109+ if (K->ne [0 ] != 192 || V->ne [2 ] != 128 ) return false ;
110+ if (K->type != V->type ) return false ;
111+ return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
112+ }
113+ #ifdef GGML_CUDA_FA_ALL_QUANTS
114+ if (K->ne [0 ] == 64 ) {
115+ return K->type == GGML_TYPE_F16 &&
116+ (V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 ||
117+ V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 || V->type == GGML_TYPE_Q8_0);
118+ }
119+ if (K->ne [0 ] == 256 ) {
120+ return K->type == V->type && (K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0);
121+ }
122+ if (K->ne [0 ] != 128 || V->ne [0 ] != 128 ) return false ;
123+ if ((K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q4_1 || K->type == GGML_TYPE_Q5_0 || K->type == GGML_TYPE_Q5_1 ||
124+ K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16) &&
125+ (V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 || V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 ||
126+ V->type == GGML_TYPE_Q8_0 || V->type == GGML_TYPE_F16)) return true ;
127+ return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) ||
128+ (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) ||
129+ (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q6_0) ||
130+ (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) ||
131+ (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL);
132+ #else
133+ if (K->ne [0 ] == 128 ) {
134+ if (K->type == V->type ) {
135+ return K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_IQ4_NL;
136+ }
137+ return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) ||
138+ (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) ||
139+ (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) ||
140+ (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL);
141+ }
142+ if (K->type != V->type ) return false ;
143+ if (K->ne [0 ] == 64 ) {
144+ return K->type == GGML_TYPE_F16;
145+ }
146+ if (K->ne [0 ] == 256 ) {
147+ return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
148+ }
149+ return false ;
150+ #endif
151+ }
0 commit comments