
Commit a2a8109

rwkv: fix llama-parallel
Signed-off-by: Molly Sophia <[email protected]>
1 parent 1a9c263 commit a2a8109

File tree: 3 files changed, +15 -12 lines


convert_hf_to_gguf.py

Lines changed: 6 additions & 3 deletions
@@ -3555,20 +3555,23 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             # ignore them all since they are not used
             return
 
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
+
         if bid is not None and "attention.x_" in name:
             if "attention.x_x" in name:
                 # already concatenated
                 new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
-                data = data_torch.reshape(6, 1, -1)
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
                 yield (new_name, data)
             else:
                 try:
                     self.lerp_weights[bid][name] = data_torch
                 except KeyError:
                     self.lerp_weights[bid] = {name: data_torch}
-                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in ["r", "w", "k", "v", "a", "g"]):
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
                     new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
-                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"].squeeze(0) for i in ["r", "w", "k", "v", "a", "g"]], dim=0)
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
                     yield (new_name, data)
             return
         else:
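
For context, a minimal sketch of what the packing change means (assumed shapes, not the converter's actual code path): the per-channel token-shift coefficients x_r, x_w, x_k, x_v, x_a and, when present, x_g are now stacked with their unit dimensions kept, giving a (len(lerp_list), 1, 1, n_embd) tensor instead of the old (6, 1, n_embd) one, and gateless models contribute only five channels.

# Sketch of the fused-lerp packing, assuming per-channel weights of shape
# (1, 1, n_embd) as in RWKV7 checkpoints; names and sizes are illustrative.
import torch

n_embd = 8
wkv_has_gate = True
lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]

# stand-in for the model.layers.{bid}.attention.x_{i} tensors collected per layer
lerp_weights = {i: torch.randn(1, 1, n_embd) for i in lerp_list}

old = torch.stack([lerp_weights[i].squeeze(0) for i in lerp_list], dim=0)  # old layout
new = torch.stack([lerp_weights[i] for i in lerp_list], dim=0)             # new layout

print(old.shape)  # torch.Size([6, 1, 8])
print(new.shape)  # torch.Size([6, 1, 1, 8]) -> stored as {n_embd, 1, 1, 6} on the GGML side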

src/llama-model.cpp

Lines changed: 1 addition & 1 deletion
@@ -3356,7 +3356,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0);
                         layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0);
 
-                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 6}, 0);
+                        layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0);
 
                         layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0);
                         layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0);

src/llama.cpp

Lines changed: 8 additions & 8 deletions
@@ -1060,7 +1060,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
     bool has_gating = layer->time_mix_g1 && layer->time_mix_g2;
 
     struct ggml_tensor * sx = ggml_sub(ctx, x_prev, cur);
-    struct ggml_tensor * dummy = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd, n_tokens, layer->time_mix_lerp_fused->ne[2]);
+    struct ggml_tensor * dummy = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5);
     sx = ggml_repeat(ctx, sx, dummy);
 
     struct ggml_tensor * xxx = ggml_add(ctx, ggml_mul(ctx, sx, layer->time_mix_lerp_fused), cur);
@@ -1149,7 +1149,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
     }
     cur = llm_build_lora_mm(lctx, ctx, layer->time_mix_output, cur);
 
-    return cur;
+    return ggml_reshape_3d(ctx, cur, n_embd, n_seq_tokens, n_seqs);
 }
 
 static struct ggml_tensor * llm_build_rwkv7_channel_mix(
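
Roughly what the two hunks above change, sketched in torch rather than ggml (shapes and names are illustrative): with parallel sequences the activations stay shaped as {n_embd, n_seq_tokens, n_seqs}, so the broadcast target for the token-shift delta becomes 4D with its outer dimension matching the 6 (or 5, without a gate) lerp channels, and the time-mix result is reshaped back to three dimensions before returning.

# Torch sketch of the token-shift broadcast; not the ggml implementation.
import torch

n_embd, n_seq_tokens, n_seqs = 8, 4, 3
n_lerp = 6  # 5 when the model has no gate (has_gating ? 6 : 5)

cur = torch.randn(n_seqs, n_seq_tokens, n_embd)      # ggml: {n_embd, n_seq_tokens, n_seqs}
x_prev = torch.randn(n_seqs, n_seq_tokens, n_embd)
lerp_fused = torch.randn(n_lerp, 1, 1, n_embd)       # ggml: {n_embd, 1, 1, n_lerp}

# counterpart of ggml_repeat onto the 4D dummy: one copy of the delta per channel
sx = (x_prev - cur).expand(n_lerp, n_seqs, n_seq_tokens, n_embd)
xxx = sx * lerp_fused + cur                          # broadcasts over tokens and sequences
print(xxx.shape)                                     # torch.Size([6, 3, 4, 8])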
@@ -7768,9 +7768,9 @@ struct llm_build_context {
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
                 struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids);
-                x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inp_ffn = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_norm_ffn, n_embd, n_tokens), inp_out_ids);
+                x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
+                cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
             }
 
             cur = ggml_add(ctx0, cur, llm_build_rwkv6_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev));
@@ -8002,9 +8002,9 @@ struct llm_build_context {
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
                 struct ggml_tensor * inp_out_ids = build_inp_out_ids();
-                inp_ffn = ggml_get_rows(ctx0, x_norm_ffn, inp_out_ids);
-                x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
-                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inp_ffn = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_norm_ffn, n_embd, n_tokens), inp_out_ids);
+                x_prev = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
+                cur = ggml_get_rows(ctx0, ggml_reshape_2d(ctx0, cur, n_embd, n_tokens), inp_out_ids);
             }
 
             cur = ggml_add(ctx0, cur, llm_build_rwkv7_channel_mix(lctx, ctx0, layer, inp_ffn, x_prev));
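
And the reason for the ggml_reshape_2d wrappers in the last two hunks, again sketched in torch with made-up values: the per-layer tensors are now laid out per sequence, while inp_out_ids indexes flat token rows, so they are flattened back to {n_embd, n_tokens} before the row gather.

# Torch sketch of flattening before the row gather; not the ggml implementation,
# and the inp_out_ids values are made up for illustration.
import torch

n_embd, n_seq_tokens, n_seqs = 8, 4, 3
n_tokens = n_seq_tokens * n_seqs

cur = torch.randn(n_seqs, n_seq_tokens, n_embd)    # ggml: {n_embd, n_seq_tokens, n_seqs}
inp_out_ids = torch.tensor([3, 7, 11])             # e.g. the last token of each sequence

rows = cur.reshape(n_tokens, n_embd)[inp_out_ids]  # counterpart of ggml_reshape_2d + ggml_get_rows
print(rows.shape)                                  # torch.Size([3, 8])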
