
Commit ef49b83

Merge branch 'ikawrakow:main' into main
2 parents 3a55c53 + c029d97 commit ef49b83

16 files changed: +223 / -33 lines

examples/server/utils.hpp

Lines changed: 11 additions & 0 deletions
@@ -591,6 +591,17 @@ static json oaicompat_chat_params_parse(
 
     // Apply chat template to the list of messages
     auto chat_params = common_chat_templates_apply(opt.tmpls, inputs);
+
+    /* Append assistant prefilled message */
+    if (prefill_assistant_message) {
+        if (!last_message.content_parts.empty()) {
+            for (auto & p : last_message.content_parts) {
+                chat_params.prompt += p.text;
+            }
+        } else {
+            chat_params.prompt += last_message.content;
+        }
+    }
 
     llama_params["chat_format"] = static_cast<int>(chat_params.format);
     llama_params["prompt"] = chat_params.prompt;

ggml/src/ggml-cuda.cu

Lines changed: 1 addition & 25 deletions
@@ -4386,31 +4386,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
             return (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) || op->src[0]->ne[0] == 128;
 #else
-            if (op->src[0]->ne[0] == 128) {
-                return true;
-            }
-            if (op->src[1]->ne[0] == 256 && op->src[2]->ne[0] == 256 &&
-                (op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_Q8_0) &&
-                (op->src[2]->type == GGML_TYPE_F16 || op->src[2]->type == GGML_TYPE_Q8_0)) {
-                return true;
-            }
-            if (op->src[1]->ne[0] == 192 && op->src[2]->ne[0] == 128) {
-                return (op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) ||
-                       (op->src[1]->type == GGML_TYPE_Q8_0 && op->src[2]->type == GGML_TYPE_Q8_0);
-            }
-            if (op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512) {
-                const int cc = ggml_cuda_info().devices[cuda_ctx->device].cc;
-                int gqa = op->src[0]->ne[2]/op->src[1]->ne[2];
-                return (new_mma_available(cc) && cc >= CC_AMPERE && op->src[3] && gqa%16 == 0);
-            }
-            if (op->src[1]->ne[0] > 256) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+            return ggml_cuda_fattn_is_supported(*cuda_ctx, op);
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         default:
             return false;
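
The definition of ggml_cuda_fattn_is_supported is not part of the hunks shown here; the per-kernel *_is_supported helpers added in the files below suggest it forwards to them. A rough sketch under that assumption (the real routine presumably also covers the vec-f32 path and the 576/512 head-size case handled by the removed code):

// Hypothetical composition of the per-kernel checks declared in this commit.
bool ggml_cuda_fattn_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * op) {
    return ggml_cuda_fattn_mma_f16_is_supported (ctx, op) ||
           ggml_cuda_fattn_tile_f16_is_supported(ctx, op) ||
           ggml_cuda_fattn_tile_f32_is_supported(ctx, op) ||
           ggml_cuda_fattn_vec_f16_is_supported (ctx, op);
}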

ggml/src/ggml-cuda/fattn-mma-f16-interface.cuh

Lines changed: 2 additions & 0 deletions
@@ -3,3 +3,5 @@
 #include "common.cuh"
 
 void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+bool ggml_cuda_fattn_mma_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);

ggml/src/ggml-cuda/fattn-mma-f16.cu

Lines changed: 7 additions & 0 deletions
@@ -83,3 +83,10 @@ void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tens
 
     ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
 }
+
+bool ggml_cuda_fattn_mma_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
+    auto K = dst->src[1];
+    auto V = dst->src[2];
+    if (K->ne[0] != V->ne[0]) return false;
+    return K->ne[0] == 64 || K->ne[0] == 80 || K->ne[0] == 96 || K->ne[0] == 112 || K->ne[0] == 128 || K->ne[0] == 256;
+}

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 7 additions & 0 deletions
@@ -353,3 +353,10 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
         launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, false>(ctx, dst);
     }
 }
+
+bool ggml_cuda_fattn_tile_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
+    auto K = dst->src[1];
+    auto V = dst->src[2];
+    if (K->ne[0] != V->ne[0]) return false;
+    return K->ne[0] == 64 || K->ne[0] == 128;
+}

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 #include "common.cuh"
 
 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+bool ggml_cuda_fattn_tile_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);

ggml/src/ggml-cuda/fattn-tile-f32.cu

Lines changed: 7 additions & 0 deletions
@@ -352,3 +352,10 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
         launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, true>(ctx, dst);
     }
 }
+
+bool ggml_cuda_fattn_tile_f32_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
+    auto K = dst->src[1];
+    auto V = dst->src[2];
+    if (K->ne[0] != V->ne[0]) return false;
+    return K->ne[0] == 64 || K->ne[0] == 128;
+}

Lines changed: 2 additions & 0 deletions
@@ -1,3 +1,5 @@
 #include "common.cuh"
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+bool ggml_cuda_fattn_tile_f32_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);

ggml/src/ggml-cuda/fattn-vec-f16-interface.cuh

Lines changed: 2 additions & 0 deletions
@@ -3,3 +3,5 @@
 #include "common.cuh"
 
 void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+bool ggml_cuda_fattn_vec_f16_is_supported(ggml_backend_cuda_context & ctx, const ggml_tensor * dst);

ggml/src/ggml-cuda/fattn-vec-f16.cu

Lines changed: 47 additions & 1 deletion
@@ -102,4 +102,50 @@ void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tens
     on_no_fattn_vec_case(Q->ne[0], V->ne[0]);
 }
 
-
+bool ggml_cuda_fattn_vec_f16_is_supported([[maybe_unused]] ggml_backend_cuda_context & ctx, const ggml_tensor * dst) {
+    auto K = dst->src[1];
+    auto V = dst->src[2];
+    if (K->ne[0] != V->ne[0]) {
+        if (K->ne[0] != 192 || V->ne[0] != 128) return false;
+        if (K->type != V->type) return false;
+        return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
+    }
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+    if (K->ne[0] == 64) {
+        return K->type == GGML_TYPE_F16 &&
+              (V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 ||
+               V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 || V->type == GGML_TYPE_Q8_0);
+    }
+    if (K->ne[0] == 256) {
+        return K->type == V->type && (K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0);
+    }
+    if (K->ne[0] != 128 || V->ne[0] != 128) return false;
+    if ((K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q4_1 || K->type == GGML_TYPE_Q5_0 || K->type == GGML_TYPE_Q5_1 ||
+         K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16) &&
+        (V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q4_1 || V->type == GGML_TYPE_Q5_0 || V->type == GGML_TYPE_Q5_1 ||
+         V->type == GGML_TYPE_Q8_0 || V->type == GGML_TYPE_F16)) return true;
+    return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) ||
+           (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) ||
+           (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q6_0) ||
+           (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) ||
+           (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL);
+#else
+    if (K->ne[0] == 128) {
+        if (K->type == V->type) {
+            return K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_IQ4_NL;
+        }
+        return (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL) ||
+               (K->type == GGML_TYPE_Q6_0 && V->type == GGML_TYPE_Q5_0) ||
+               (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_Q6_0) ||
+               (K->type == GGML_TYPE_Q8_0 && V->type == GGML_TYPE_IQ4_NL);
+    }
+    if (K->type != V->type) return false;
+    if (K->ne[0] == 64) {
+        return K->type == GGML_TYPE_F16;
+    }
+    if (K->ne[0] == 256) {
+        return K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q8_0;
+    }
+    return false;
+#endif
+}
