Merged
222 changes: 222 additions & 0 deletions src/llama-context.cpp
@@ -1970,6 +1970,228 @@ ggml_tensor * llama_context::build_mamba_layer(
}


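// Load the per-sequence RWKV token-shift state from this layer's slot of the
// recurrent KV cache (kv_self.k_l[il]) and reshape it to
// [n_embd, token_shift_count, n_seqs] for use by the mix blocks below.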
ggml_tensor * llama_context::build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
const auto & hparams = model.hparams;

const auto token_shift_count = hparams.token_shift_count;

const auto & n_tokens = ubatch.n_tokens;
const int64_t n_seqs = ubatch.n_seqs;

struct ggml_tensor * token_shift_all = kv_self.k_l[il];

struct ggml_tensor * token_shift = build_copy_mask_state(
ctx0, graph, token_shift_all, state_copy, state_mask,
n_tokens, hparams.n_embd_k_s(), n_seqs, worst_case);

token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);

return token_shift;
}


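// Flatten the updated token-shift state and copy it back into this layer's
// slot of the recurrent KV cache at kv_head, so the next ubatch can reload it.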
ggml_tensor * llama_context::build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
const auto & hparams = model.hparams;

const auto token_shift_count = hparams.token_shift_count;
const auto n_embd = hparams.n_embd;

const auto & n_tokens = ubatch.n_tokens;
const int64_t n_seqs = ubatch.n_seqs;

const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;

return ggml_cpy(
ctx0,
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il]))
);
}


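// Build the RWKV6 time-mix block: data-dependent token-shift interpolation,
// low-rank r/k/v/g/w projections, the WKV6 recurrence (or gated linear
// attention for QRWKV-style models), and the write-back of the WKV state.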
ggml_tensor * llama_context::build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case) {
const auto & hparams = model.hparams;

const auto n_tokens = ubatch.n_tokens;
const auto n_seqs = ubatch.n_seqs;
const auto n_embd = hparams.n_embd;
const auto head_size = hparams.wkv_head_size;
const auto n_head = n_embd / head_size;
const auto n_head_kv = hparams.n_head_kv(il);

const auto kv_head = worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head;

const auto layer = &model.layers[il];

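// QRWKV-style variants (e.g. RWKV6-based Qwen2 hybrids) ship no
// time_mix_first ("bonus") tensor, so its absence identifies them here.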
bool is_qrwkv = layer->time_mix_first == nullptr;

struct ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
struct ggml_tensor * xxx = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->time_mix_lerp_x), cur);

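// project the shifted input through the low-rank time_mix_w1/time_mix_w2 pair
// and split the result into 5 per-token mixing coefficients (for w, k, v, r, g)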
xxx = ggml_reshape_4d(
ctx0,
ggml_tanh(
ctx0,
ggml_mul_mat(ctx0, layer->time_mix_w1, xxx)
),
layer->time_mix_w1->ne[1] / 5, 1, 5, n_tokens
);

xxx = ggml_cont(ctx0, ggml_permute(ctx0, xxx, 0, 1, 3, 2));

xxx = ggml_mul_mat(
ctx0,
ggml_reshape_4d(
ctx0,
layer->time_mix_w2,
layer->time_mix_w2->ne[0], layer->time_mix_w2->ne[1], 1, 5
),
xxx
);

struct ggml_tensor *xw, *xk, *xv, *xr, *xg;
if (layer->time_mix_lerp_fused) {
// fusing these lerp weights into a single tensor yields a small performance improvement
sx = ggml_reshape_3d(ctx0, sx, n_embd, 1, n_tokens);
cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
xxx = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xxx, layer->time_mix_lerp_fused), sx), cur);
xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));
} else {
// for backward compatibility
xw = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], 0);
xk = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * sizeof(float));
xv = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 2 * sizeof(float));
xr = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 3 * sizeof(float));
xg = ggml_view_2d(ctx0, xxx, n_embd, n_tokens, xxx->nb[1], n_embd * n_tokens * 4 * sizeof(float));

xw = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xw, layer->time_mix_lerp_w), sx), cur);
xk = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xk, layer->time_mix_lerp_k), sx), cur);
xv = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xv, layer->time_mix_lerp_v), sx), cur);
xr = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xr, layer->time_mix_lerp_r), sx), cur);
xg = ggml_add(ctx0, ggml_mul(ctx0, ggml_add(ctx0, xg, layer->time_mix_lerp_g), sx), cur);
}

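// receptance, key and value projections; the biases below are optional and
// only present in some model variants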
struct ggml_tensor * r = build_lora_mm(ctx0, layer->time_mix_receptance, xr);
struct ggml_tensor * k = build_lora_mm(ctx0, layer->time_mix_key, xk);
struct ggml_tensor * v = build_lora_mm(ctx0, layer->time_mix_value, xv);
if (layer->time_mix_receptance_b) {
r = ggml_add(ctx0, r, layer->time_mix_receptance_b);
}
if (layer->time_mix_key_b) {
k = ggml_add(ctx0, k, layer->time_mix_key_b);
}
if (layer->time_mix_value_b) {
v = ggml_add(ctx0, v, layer->time_mix_value_b);
}

struct ggml_tensor * g = build_lora_mm(ctx0, layer->time_mix_gate, xg);
if (is_qrwkv) {
g = ggml_sigmoid(ctx0, g);
} else {
g = ggml_silu(ctx0, g);
}

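// when k/v were projected with fewer heads than n_head (GQA-style), repeat
// them so that each of the n_head heads gets its own k/v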
if (n_head_kv != 0 && n_head_kv != n_head) {
GGML_ASSERT(n_head % n_head_kv == 0);
k = ggml_reshape_4d(ctx0, k, head_size, 1, n_head_kv, n_tokens);
v = ggml_reshape_4d(ctx0, v, head_size, 1, n_head_kv, n_tokens);
struct ggml_tensor * tmp = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, head_size, n_head / n_head_kv, n_head_kv, n_tokens);
k = ggml_repeat(ctx0, k, tmp);
v = ggml_repeat(ctx0, v, tmp);
}

k = ggml_reshape_3d(ctx0, k, head_size, n_head, n_tokens);
v = ggml_reshape_3d(ctx0, v, head_size, n_head, n_tokens);
r = ggml_reshape_3d(ctx0, r, head_size, n_head, n_tokens);

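// time-decay: low-rank projection through time_mix_decay_w1/w2 plus the
// time_mix_decay bias, then w = exp(-exp(w)), which keeps the per-channel
// decay strictly inside (0, 1)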
struct ggml_tensor * w = ggml_mul_mat(
ctx0,
layer->time_mix_decay_w2,
ggml_tanh(
ctx0,
ggml_mul_mat(ctx0, layer->time_mix_decay_w1, xw)
)
);

w = ggml_add(ctx0, w, layer->time_mix_decay);
w = ggml_exp(ctx0, ggml_neg(ctx0, ggml_exp(ctx0, w)));
w = ggml_reshape_3d(ctx0, w, head_size, n_head, n_tokens);

if (is_qrwkv) {
// k = k * (1 - w)
k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
}

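// load the per-sequence WKV state for this layer from the recurrent cache
// (stored in kv_self.v_l[il])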
struct ggml_tensor * wkv_state = build_copy_mask_state(
ctx0, graph, kv_self.v_l[il], state_copy, state_mask,
n_tokens, hparams.n_embd_v_s(), n_seqs, worst_case);

struct ggml_tensor * wkv_output;
if (is_qrwkv) {
wkv_output = ggml_gated_linear_attn(ctx0, k, v, r, w, wkv_state, pow(head_size, -0.5f));
} else {
wkv_output = ggml_rwkv_wkv6(ctx0, k, v, r, layer->time_mix_first, w, wkv_state);
}
cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);
wkv_state = ggml_view_1d(ctx0, wkv_output, n_embd * head_size * n_seqs, n_embd * n_tokens * sizeof(float));

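// the first n_embd * n_tokens floats of wkv_output are the attention output;
// the remainder is the updated WKV state, written back to the cache at kv_head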
ggml_build_forward_expand(
graph,
ggml_cpy(
ctx0,
wkv_state,
ggml_view_1d(
ctx0,
kv_self.v_l[il],
hparams.n_embd_v_s() * n_seqs,
hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il])
)
)
);

if (!is_qrwkv) {
// group norm with head_count groups
cur = ggml_reshape_3d(ctx0, cur, n_embd / n_head, n_head, n_tokens);
cur = ggml_norm(ctx0, cur, 64e-5f);

// Convert back to regular vectors.
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer->time_mix_ln), layer->time_mix_ln_b);
} else {
cur = ggml_reshape_2d(ctx0, cur, n_embd, n_tokens);
}

cur = ggml_mul(ctx0, cur, g);
cur = build_lora_mm(ctx0, layer->time_mix_output, cur);

return cur;
}

// llama output

size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
27 changes: 27 additions & 0 deletions src/llama-context.h
@@ -248,6 +248,33 @@ struct llama_context {
int il,
bool worst_case);

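// RWKV graph-building helpers: token-shift load/store and the RWKV6 time-mix block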
ggml_tensor * build_rwkv_token_shift_load(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case);

ggml_tensor * build_rwkv_token_shift_store(
ggml_context * ctx0,
ggml_tensor * token_shift,
const llama_ubatch & ubatch,
int il,
bool worst_case);

ggml_tensor * build_rwkv6_time_mix(
ggml_context * ctx0,
ggml_cgraph * graph,
ggml_tensor * cur,
ggml_tensor * x_prev,
ggml_tensor * state_copy,
ggml_tensor * state_mask,
const llama_ubatch & ubatch,
int il,
bool worst_case);

struct ggml_tensor * inp_s_copy; // I32 [kv_size]
struct ggml_tensor * inp_s_mask; // F32 [1, n_kv]
