Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ggml/src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -4119,6 +4119,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
return true;
}
if (src0_type == GGML_TYPE_Q8_0 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
return true;
}
Expand Down
35 changes: 27 additions & 8 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8367,17 +8367,36 @@ struct llm_build_context {
const int64_t n_head_kv = hparams.n_head_kv(il);
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
struct ggml_tensor * rope_factors = build_rope_factors(il);
struct ggml_tensor * tmp =
struct ggml_tensor * k =
ggml_view_3d(ctx0, kv_self.k_l[il],
n_embd_head_k, n_head_kv, n_ctx,
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
0);

struct ggml_tensor * tmp;
if (ggml_is_quantized(k->type)) {
// dequantize to f32 -> RoPE -> quantize back
tmp = ggml_cast(ctx0, k, GGML_TYPE_F32);
cb(tmp, "K_f32", il);
for (auto * backend : lctx.backends) {
// Figure out which backend KV cache belongs to
if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft)) {
ggml_backend_sched_set_tensor_backend(lctx.sched, tmp, backend);
break;
}
}
tmp = ggml_rope_ext_inplace(ctx0, tmp,
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted_f32", il);
tmp = ggml_cpy(ctx0, tmp, k);
} else {
// we rotate only the first n_rot dimensions
ggml_rope_ext_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k_l[il],
n_embd_head_k, n_head_kv, n_ctx,
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
0),
tmp = ggml_rope_ext_inplace(ctx0, k,
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);

}
cb(tmp, "K_shifted", il);
ggml_build_forward_expand(gf, tmp);
}
Expand Down