Commit d13d6ff

support dynamic wkv
1 parent 42c0aa2 commit d13d6ff

1 file changed (+87 -2)

src/llama-model.cpp

Lines changed: 87 additions & 2 deletions
@@ -1422,7 +1422,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
         try {
             info = llm_tensor_info_for(tn_tensor);
         } catch (const std::out_of_range & e) {
-            throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str()));
+            LLAMA_LOG_WARN("missing tensor info mapping for %s -- ignoring\n", tn.str().c_str());
+            return nullptr;
         }
 
         // skip unused tensors
@@ -2911,9 +2912,93 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
 
                     layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
-                    layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
                     layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
                     layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);
+                    if (!layer.wk_b || !layer.wv_b) {
+                        auto wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
+                        if (!wkv_b) {
+                            throw std::runtime_error("wkv_b must be defined without wk_b and wv_b");
+                        }
+
+                        // select the buffer type for this tensor
+                        buft_list_t * buft_list = pimpl->dev_input.buft_list;
+
+                        ggml_backend_buffer_type_t buft = nullptr;
+
+                        // check overrides
+                        if (ml.tensor_buft_overrides) {
+                            std::string tensor_name = "blk." + std::to_string(i) + ".attn_kv_b.weight";
+                            for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
+                                std::regex pattern(overrides->pattern);
+                                if (std::regex_search(tensor_name, pattern)) {
+                                    LLAMA_LOG_DEBUG("tensor %s buffer type overriden to %s\n", tensor_name.c_str(), ggml_backend_buft_name(overrides->buft));
+                                    buft = overrides->buft;
+                                    break;
+                                }
+                            }
+                        }
+
+                        // avoid using a host buffer when using mmap
+                        auto * buft_dev = ggml_backend_buft_get_device(buft);
+                        if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) {
+                            auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+                            buft = ggml_backend_dev_buffer_type(cpu_dev);
+                        }
+
+                        ggml_context * ctx = ctx_for_buft(buft);
+                        layer.wk_b = ggml_new_tensor_2d(ctx,
+                            wkv_b->type,
+                            n_head_kv * kv_lora_rank,
+                            n_embd_head_qk_nope
+                        );
+                        {
+                            float * src = (float *) wkv_b->data;
+                            float * dst = (float *) layer.wk_b->data;
+                            int src_stride = wkv_b->ne[0]; // number of elements per row of the source tensor (F32 data assumed)
+
+                            for (int h = 0; h < n_head_kv; ++h) {
+                                int k_start = h * (n_embd_head_qk_nope + n_embd_head_v);
+                                for (int row = 0; row < kv_lora_rank; ++row) {
+                                    for (int col = 0; col < n_embd_head_qk_nope; ++col) {
+                                        int src_idx = row * src_stride + k_start + col;
+                                        GGML_ASSERT(src_idx < ggml_nelements(wkv_b));
+
+                                        int dst_row = h * kv_lora_rank + row;
+                                        int dst_col = col;
+                                        dst[dst_row * n_embd_head_qk_nope + dst_col] = src[src_idx];
+                                    }
+                                }
+                            }
+                        }
+
+                        layer.wv_b = ggml_new_tensor_2d(
+                            ctx,
+                            wkv_b->type,
+                            n_head_kv * n_embd_head_v, // rows: merged head and feature dimensions
+                            kv_lora_rank               // cols: LoRA rank
+                        );
+                        {
+                            float * src = (float *) wkv_b->data;
+                            float * dst = (float *) layer.wv_b->data;
+                            int src_stride = wkv_b->ne[0]; // number of elements per row of the source tensor (F32 data assumed)
+
+                            for (int h = 0; h < n_head_kv; ++h) {
+                                int v_start = h * (n_embd_head_qk_nope + n_embd_head_v) + n_embd_head_qk_nope;
+                                for (int row = 0; row < kv_lora_rank; ++row) {
+                                    for (int col = 0; col < n_embd_head_v; ++col) {
+                                        // source index
+                                        int src_idx = row * src_stride + v_start + col;
+                                        GGML_ASSERT(src_idx < ggml_nelements(wkv_b));
+
+                                        // destination index
+                                        int dst_row = h * n_embd_head_v + col; // merged head and feature dimensions
+                                        int dst_col = row;                     // LoRA rank dimension
+                                        dst[dst_row * kv_lora_rank + dst_col] = src[src_idx];
+                                    }
+                                }
+                            }
+                        }
+                    }
                     layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);
 
                     layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
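
For reference, here is a minimal standalone sketch of the per-head split this change performs: a fused wkv_b matrix is sliced into a K part (wk_b) and a V part (wv_b) for each head. The dimensions, the plain float buffers, and the assumed row-major layout (kv_lora_rank rows, each holding the concatenated per-head [K | V] columns) are illustrative assumptions only, not taken from the commit; the commit itself operates on ggml tensors inside load_tensors.

#include <cassert>
#include <vector>

int main() {
    // hypothetical dimensions, chosen only for this illustration
    const int n_head       = 2;
    const int kv_lora_rank = 4;
    const int d_nope       = 3; // stands in for n_embd_head_qk_nope
    const int d_v          = 5; // stands in for n_embd_head_v
    const int row_len      = n_head * (d_nope + d_v); // fused row length

    // fused source: kv_lora_rank rows of [K_0 | V_0 | K_1 | V_1 | ...]
    std::vector<float> wkv_b(kv_lora_rank * row_len);
    for (size_t i = 0; i < wkv_b.size(); ++i) {
        wkv_b[i] = (float) i;
    }

    // destinations, shaped like the create_tensor calls above:
    // wk_b has (n_head * kv_lora_rank) rows of d_nope, wv_b has (n_head * d_v) rows of kv_lora_rank
    std::vector<float> wk_b(n_head * kv_lora_rank * d_nope);
    std::vector<float> wv_b(n_head * d_v * kv_lora_rank);

    for (int h = 0; h < n_head; ++h) {
        const int k_start = h * (d_nope + d_v); // K block offset within a fused row
        const int v_start = k_start + d_nope;   // V block follows the K block
        for (int r = 0; r < kv_lora_rank; ++r) {
            for (int c = 0; c < d_nope; ++c) {
                wk_b[(h * kv_lora_rank + r) * d_nope + c] = wkv_b[r * row_len + k_start + c];
            }
            for (int c = 0; c < d_v; ++c) {
                wv_b[(h * d_v + c) * kv_lora_rank + r] = wkv_b[r * row_len + v_start + c];
            }
        }
    }

    // first K element of head 0, row 0 comes from the start of the fused row
    assert(wk_b[0] == wkv_b[0]);
    return 0;
}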

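The buffer-type override lookup in the new code walks a null-terminated array of {regex pattern, buffer type} entries and keeps the first pattern that matches the tensor name. A standalone sketch of that matching loop follows; the override struct and the patterns are hypothetical stand-ins for the real tensor_buft_overrides entries.

#include <cstdio>
#include <regex>
#include <string>

// hypothetical stand-in for the real override entry (pattern -> buffer type name)
struct buft_override {
    const char * pattern;   // nullptr marks the end of the list
    const char * buft_name;
};

int main() {
    const buft_override overrides[] = {
        { "blk\\.[0-9]+\\.attn_kv_b\\.weight", "GPU0" },
        { "ffn_.*",                            "CPU"  },
        { nullptr,                             nullptr },
    };

    const std::string tensor_name = "blk.3.attn_kv_b.weight";

    const char * buft = nullptr;
    for (const auto * o = overrides; o->pattern != nullptr; ++o) {
        if (std::regex_search(tensor_name, std::regex(o->pattern))) {
            buft = o->buft_name; // first match wins, then stop
            break;
        }
    }

    std::printf("%s -> %s\n", tensor_name.c_str(), buft ? buft : "(default)");
    return 0;
}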
0 commit comments
