@@ -745,10 +745,14 @@ void launch_fattn(
     size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(K));
-        K_f16.alloc(ggml_nelements(K));
+        const int64_t n_seq = K->ne[3];
+        const int64_t n_eps = (K->nb[3]/ggml_type_size(K->type))*ggml_blck_size(K->type); // elements per sequence
+
+        K_f16.alloc(n_seq*n_eps);
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
-        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        for (int s = 0; s < n_seq; ++s) {
+            to_fp16(K_data + s*K->nb[3], K_f16.ptr + s*n_eps, n_eps, main_stream);
+        }
         K_data = (char *) K_f16.ptr;
 
         const size_t bs = ggml_blck_size(K->type);
@@ -760,10 +764,14 @@ void launch_fattn(
     }
 
     if (V && need_f16_V && V->type != GGML_TYPE_F16) {
-        GGML_ASSERT(ggml_is_contiguously_allocated(V));
-        V_f16.alloc(ggml_nelements(V));
+        const int64_t n_seq = V->ne[3];
+        const int64_t n_eps = (V->nb[3]/ggml_type_size(V->type))*ggml_blck_size(V->type);
+
+        V_f16.alloc(n_seq*n_eps);
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
-        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        for (int s = 0; s < n_seq; ++s) {
+            to_fp16(V_data + s*V->nb[3], V_f16.ptr + s*n_eps, n_eps, main_stream);
+        }
         V_data = (char *) V_f16.ptr;
 
         const size_t bs = ggml_blck_size(V->type);
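
Note on the change: ggml_nelements() counts only a tensor's logical values, so for a K/V cache whose per-sequence slabs are padded (nb[3] larger than the logical data of one sequence), a single conversion over ggml_nelements() values would not line up with the actual layout; hence the removed contiguity assert and the per-sequence loop. Below is a minimal, self-contained sketch of the n_eps arithmetic. FakeTensor and all concrete sizes are hypothetical stand-ins for ggml_tensor, ggml_type_size() and ggml_blck_size(), not ggml code:

// Hypothetical stand-in for ggml_tensor; only the fields used above.
#include <cstdint>
#include <cstdio>

struct FakeTensor {
    int64_t ne[4];       // logical extents (ggml: ne)
    size_t  nb[4];       // byte strides (ggml: nb); nb[3] = bytes between sequences
    size_t  type_size;   // bytes per block  (ggml_type_size)
    int64_t blck_size;   // values per block (ggml_blck_size)
};

int main() {
    // Example: a Q4-style type packing 32 values into 18 bytes, with each
    // sequence's slab padded from 1024 to 1152 blocks (numbers are made up).
    const FakeTensor K = {
        {128, 32, 8, 2},     // head dim, KV length, heads, sequences
        {0, 0, 0, 1152*18},  // nb[3] = 20736 bytes per sequence slab
        18, 32,
    };

    const int64_t n_seq = K.ne[3];
    const int64_t n_eps = (K.nb[3]/K.type_size)*K.blck_size; // as in the patch

    const int64_t logical = K.ne[0]*K.ne[1]*K.ne[2];         // per sequence
    printf("per sequence: %lld values allocated vs %lld logical\n",
           (long long) n_eps, (long long) logical);          // 36864 vs 32768
    printf("total F16 buffer: %lld values (ggml_nelements would say %lld)\n",
           (long long) (n_seq*n_eps), (long long) (n_seq*logical));
    return 0;
}

Converting whole slabs means any padding values get converted too; that appears to be harmless here, since the attention kernels address the F16 copy with the same per-sequence stride and should never read the pad region, and it keeps the cost at one to_fp16 launch per sequence rather than one per row.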