
Commit 770184d

tmp
1 parent 93674de commit 770184d

1 file changed: 6 additions, 62 deletions

src/llama-model.cpp

Lines changed: 6 additions & 62 deletions
@@ -2913,11 +2913,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
             }
 
             layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0);
+            layer.wkv_b     = create_tensor(tn(LLM_TENSOR_ATTN_KV_B,     "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
             layer.wk_b      = create_tensor(tn(LLM_TENSOR_ATTN_K_B,      "weight", i), {n_embd_head_qk_nope, n_head * kv_lora_rank}, 0);
             layer.wv_b      = create_tensor(tn(LLM_TENSOR_ATTN_V_B,      "weight", i), {kv_lora_rank, n_head * n_embd_head_v}, 0);
             if (!layer.wk_b || !layer.wv_b) {
-                auto wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0);
-                if (!wkv_b) {
+                if (!layer.wkv_b) {
                     throw std::runtime_error("wkv_b must be defined without wk_b and wv_b");
                 }
 
@@ -2946,69 +2946,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     buft = ggml_backend_dev_buffer_type(cpu_dev);
                 }
 
-                LLAMA_LOG_INFO("wkv_b shape: [%d, %d], type: %d\n", wkv_b->ne[0], wkv_b->ne[1], int(wkv_b->type));
+                LLAMA_LOG_INFO("wkv_b shape: [%d, %d], type: %d\n", layer.wkv_b->ne[0], layer.wkv_b->ne[1], int(layer.wkv_b->type));
                 LLAMA_LOG_INFO("n_head_kv: %d, kv_lora_rank: %d, n_embd_head_qk_nope: %d\n", n_head_kv, kv_lora_rank, n_embd_head_qk_nope);
                 ggml_context * ctx = ctx_for_buft(buft);
-                layer.wk_b = ggml_new_tensor_2d(ctx,
-                    wkv_b->type,
-                    n_head_kv * kv_lora_rank,
-                    n_embd_head_qk_nope
-                );
-                LLAMA_LOG_INFO("wk_b shape: [%d, %d]\n", layer.wk_b->ne[0], layer.wk_b->ne[1]);
-                {
-                    float *src = (float *)wkv_b->data;
-                    float *dst = (float *)layer.wk_b->data;
-                    int src_stride = wkv_b->ne[0]; // number of elements per row of the original tensor
-
-                    for (int h = 0; h < n_head_kv; ++h) {
-                        int k_start = h * (n_embd_head_qk_nope + n_embd_head_v);
-                        for (int row = 0; row < kv_lora_rank; ++row) {
-                            for (int col = 0; col < n_embd_head_qk_nope; ++col) {
-                                LLAMA_LOG_INFO("wk_b row: %d, col: %d\n", row, col);
-                                int src_idx = row * src_stride + k_start + col;
-                                LLAMA_LOG_INFO("src_idx: %d\n", src_idx);
-                                GGML_ASSERT(src_idx < ggml_nelements(wkv_b));
-
-                                int dst_row = h * kv_lora_rank + row;
-                                int dst_col = col;
-                                LLAMA_LOG_INFO("wk_b dst_row: %d, dst_col: %d\n", dst_row, dst_col);
-                                dst[dst_row * n_embd_head_qk_nope + dst_col] = src[src_idx];
-                            }
-                        }
-                    }
-                }
 
-                layer.wv_b = ggml_new_tensor_2d(
-                    ctx,
-                    wkv_b->type,
-                    n_head_kv * n_embd_head_v, // rows: merged head and feature dimensions
-                    kv_lora_rank               // cols: LoRA rank
-                );
-                LLAMA_LOG_INFO("wv_b shape: [%d, %d]\n", layer.wv_b->ne[0], layer.wv_b->ne[1]);
-                {
-                    float *src = (float *)wkv_b->data;
-                    float *dst = (float *)layer.wv_b->data;
-                    int src_stride = wkv_b->ne[0]; // number of elements per row of the original tensor
-
-                    for (int h = 0; h < n_head_kv; ++h) {
-                        int v_start = h * (n_embd_head_qk_nope + n_embd_head_v) + n_embd_head_qk_nope;
-                        for (int row = 0; row < kv_lora_rank; ++row) {
-                            for (int col = 0; col < n_embd_head_v; ++col) {
-                                LLAMA_LOG_INFO("wv_b row: %d, col: %d\n", row, col);
-                                // compute the source index
-                                int src_idx = row * src_stride + v_start + col;
-                                LLAMA_LOG_INFO("src_idx: %d\n", src_idx);
-                                GGML_ASSERT(src_idx < ggml_nelements(wkv_b));
-
-                                // compute the destination index
-                                int dst_row = h * n_embd_head_v + col; // merged head and feature dimensions
-                                int dst_col = row;                     // LoRA rank dimension
-                                LLAMA_LOG_INFO("wv_b dst_row: %d, dst_col: %d\n", dst_row, dst_col);
-                                dst[dst_row * kv_lora_rank + dst_col] = src[src_idx];
-                            }
-                        }
-                    }
-                }
+                auto trans_wkv_b = ggml_transpose(ctx, layer.wkv_b);
+                layer.wk_b = ggml_view_2d(ctx, trans_wkv_b, trans_wkv_b->ne[0], n_embd_head_qk_nope, n_head, 0);
+                layer.wv_b = ggml_view_2d(ctx, trans_wkv_b, trans_wkv_b->ne[0], n_embd_head_v, n_head, n_embd_head_qk_nope * n_head);
             }
             layer.wo        = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0);

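For context on the view-based replacement in the second hunk: ggml_view_2d(ctx, a, ne0, ne1, nb1, offset) takes the row stride nb1 and the starting offset in bytes, which is the bookkeeping the removed element-by-element copy did by hand. The standalone sketch below is not part of this commit; the sizes, the single head h, and the direct slicing of wkv_b (without the transpose used in the committed lines) are illustrative assumptions. It shows how per-head K/V blocks of a combined {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} tensor can be exposed as views without copying data.

    // minimal sketch, assuming ggml.h from the llama.cpp tree is available;
    // sizes below are made up for illustration, not taken from the commit
    #include "ggml.h"
    #include <cstdio>

    int main() {
        const int64_t kv_lora_rank        = 4;
        const int64_t n_head              = 2;
        const int64_t n_embd_head_qk_nope = 3;
        const int64_t n_embd_head_v       = 5;

        struct ggml_init_params params = {
            /*.mem_size   =*/ 16u*1024u*1024u,
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // combined projection: ne[0] = kv_lora_rank is the contiguous dimension,
        // ne[1] enumerates n_head * (n_embd_head_qk_nope + n_embd_head_v) output rows
        struct ggml_tensor * wkv_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,
                kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v));

        // no-copy views for one head h: nb1 and the offset are in BYTES
        const int64_t h = 0;
        struct ggml_tensor * k_b_h = ggml_view_2d(ctx, wkv_b,
                kv_lora_rank, n_embd_head_qk_nope,
                wkv_b->nb[1],
                (size_t) (h*(n_embd_head_qk_nope + n_embd_head_v)) * wkv_b->nb[1]);
        struct ggml_tensor * v_b_h = ggml_view_2d(ctx, wkv_b,
                kv_lora_rank, n_embd_head_v,
                wkv_b->nb[1],
                (size_t) (h*(n_embd_head_qk_nope + n_embd_head_v) + n_embd_head_qk_nope) * wkv_b->nb[1]);

        printf("k_b_h: [%lld, %lld], v_b_h: [%lld, %lld]\n",
               (long long) k_b_h->ne[0], (long long) k_b_h->ne[1],
               (long long) v_b_h->ne[0], (long long) v_b_h->ne[1]);

        ggml_free(ctx);
        return 0;
    }

Compiled and linked against ggml, this prints k_b_h: [4, 3], v_b_h: [4, 5] for the chosen head, i.e. the no-PE K block and the V block of head 0 viewed in place.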
0 commit comments