Skip to content

Commit 55cf48d

Browse files
committed
cuda : fix multi-seq, quantized FA
ggml-ci
1 parent a856a56 commit 55cf48d

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,10 +745,14 @@ void launch_fattn(
745745
size_t nb23 = V ? V->nb[3] : nb13;
746746

747747
if (need_f16_K && K->type != GGML_TYPE_F16) {
748-
GGML_ASSERT(ggml_is_contiguously_allocated(K));
749-
K_f16.alloc(ggml_nelements(K));
748+
const int64_t n_seq = K->ne[3];
749+
const int64_t n_eps = (K->nb[3]/ggml_type_size(K->type))*ggml_blck_size(K->type); // elements per sequence
750+
751+
K_f16.alloc(n_seq*n_eps);
750752
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
751-
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
753+
for (int s = 0; s < n_seq; ++s) {
754+
to_fp16(K_data + s*K->nb[3], K_f16.ptr + s*n_eps, n_eps, main_stream);
755+
}
752756
K_data = (char *) K_f16.ptr;
753757

754758
const size_t bs = ggml_blck_size(K->type);
@@ -760,10 +764,14 @@ void launch_fattn(
760764
}
761765

762766
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
763-
GGML_ASSERT(ggml_is_contiguously_allocated(V));
764-
V_f16.alloc(ggml_nelements(V));
767+
const int64_t n_seq = V->ne[3];
768+
const int64_t n_eps = (V->nb[3]/ggml_type_size(V->type))*ggml_blck_size(V->type);
769+
770+
V_f16.alloc(n_seq*n_eps);
765771
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
766-
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
772+
for (int s = 0; s < n_seq; ++s) {
773+
to_fp16(V_data + s*V->nb[3], V_f16.ptr + s*n_eps, n_eps, main_stream);
774+
}
767775
V_data = (char *) V_f16.ptr;
768776

769777
const size_t bs = ggml_blck_size(V->type);

tests/test-backend-ops.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5525,6 +5525,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
55255525
test_cases.emplace_back(new test_timestep_embedding());
55265526
test_cases.emplace_back(new test_leaky_relu());
55275527

5528+
test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 3}, 512, 128, true, 0.0f, 0.0f, GGML_PREC_DEFAULT, GGML_TYPE_Q8_0));
5529+
55285530
for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
55295531
for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
55305532
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;

0 commit comments

Comments (0)