@@ -1060,7 +1060,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
10601060 bool has_gating = layer->time_mix_g1 && layer->time_mix_g2 ;
10611061
10621062 struct ggml_tensor * sx = ggml_sub (ctx, x_prev, cur);
1063- struct ggml_tensor * dummy = ggml_new_tensor_3d (ctx, GGML_TYPE_F32, n_embd, n_tokens, layer-> time_mix_lerp_fused -> ne [ 2 ] );
1063+ struct ggml_tensor * dummy = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, n_embd, n_seq_tokens, n_seqs, has_gating ? 6 : 5 );
10641064 sx = ggml_repeat (ctx, sx, dummy);
10651065
10661066 struct ggml_tensor * xxx = ggml_add (ctx, ggml_mul (ctx, sx, layer->time_mix_lerp_fused ), cur);
@@ -1149,7 +1149,7 @@ static struct ggml_tensor * llm_build_rwkv7_time_mix(
11491149 }
11501150 cur = llm_build_lora_mm (lctx, ctx, layer->time_mix_output , cur);
11511151
1152- return cur;
1152+ return ggml_reshape_3d (ctx, cur, n_embd, n_seq_tokens, n_seqs) ;
11531153}
11541154
11551155static struct ggml_tensor * llm_build_rwkv7_channel_mix (
@@ -7768,9 +7768,9 @@ struct llm_build_context {
77687768 if (il == n_layer - 1 ) {
77697769 // skip computing output for unused tokens
77707770 struct ggml_tensor * inp_out_ids = build_inp_out_ids ();
7771- inp_ffn = ggml_get_rows (ctx0, x_norm_ffn, inp_out_ids);
7772- x_prev = ggml_get_rows (ctx0, x_prev, inp_out_ids);
7773- cur = ggml_get_rows (ctx0, cur, inp_out_ids);
7771+ inp_ffn = ggml_get_rows (ctx0, ggml_reshape_2d (ctx0, x_norm_ffn, n_embd, n_tokens) , inp_out_ids);
7772+ x_prev = ggml_get_rows (ctx0, ggml_reshape_2d (ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
7773+ cur = ggml_get_rows (ctx0, ggml_reshape_2d (ctx0, cur, n_embd, n_tokens), inp_out_ids);
77747774 }
77757775
77767776 cur = ggml_add (ctx0, cur, llm_build_rwkv6_channel_mix (lctx, ctx0, layer, inp_ffn, x_prev));
@@ -8002,9 +8002,9 @@ struct llm_build_context {
80028002 if (il == n_layer - 1 ) {
80038003 // skip computing output for unused tokens
80048004 struct ggml_tensor * inp_out_ids = build_inp_out_ids ();
8005- inp_ffn = ggml_get_rows (ctx0, x_norm_ffn, inp_out_ids);
8006- x_prev = ggml_get_rows (ctx0, x_prev, inp_out_ids);
8007- cur = ggml_get_rows (ctx0, cur, inp_out_ids);
8005+ inp_ffn = ggml_get_rows (ctx0, ggml_reshape_2d (ctx0, x_norm_ffn, n_embd, n_tokens) , inp_out_ids);
8006+ x_prev = ggml_get_rows (ctx0, ggml_reshape_2d (ctx0, x_prev, n_embd, n_tokens), inp_out_ids);
8007+ cur = ggml_get_rows (ctx0, ggml_reshape_2d (ctx0, cur, n_embd, n_tokens), inp_out_ids);
80088008 }
80098009
80108010 cur = ggml_add (ctx0, cur, llm_build_rwkv7_channel_mix (lctx, ctx0, layer, inp_ffn, x_prev));
0 commit comments