Commit 7eef0bd

Rewrite recurrent delta + softmax to separate ops
1 parent 554593d commit 7eef0bd

13 files changed: +535 −94 lines changed

ggml/include/ggml.h

Lines changed: 35 additions & 0 deletions
@@ -544,6 +544,7 @@ extern "C" {
         GGML_OP_GATED_LINEAR_ATTN,
         GGML_OP_RWKV_WKV7,
         GGML_OP_DELTA_NET,
+        GGML_OP_DELTA_NET_RECURRENT,

         GGML_OP_UNARY,

@@ -578,6 +579,8 @@ extern "C" {
         GGML_UNARY_OP_HARDSWISH,
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
+        GGML_UNARY_OP_EXPM1,
+        GGML_UNARY_OP_SOFTPLUS,
         GGML_UNARY_OP_GELU_ERF,

         GGML_UNARY_OP_COUNT,
@@ -961,6 +964,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);

+    GGML_API struct ggml_tensor * ggml_expm1(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_expm1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     GGML_API struct ggml_tensor * ggml_sin(
             struct ggml_context * ctx,
             struct ggml_tensor * a);
@@ -1164,6 +1183,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor * a);

+    GGML_API struct ggml_tensor * ggml_expm1(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_expm1_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
+    GGML_API struct ggml_tensor * ggml_softplus_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor * a);
+
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
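For context, the two new unary ops plug into ggml like any other elementwise op. Below is a minimal sketch, not part of this commit, of how ggml_softplus and ggml_expm1 might be exercised on the CPU backend; it assumes the usual ggml context/graph helpers (ggml_init, ggml_new_graph, ggml_graph_compute_with_ctx from ggml-cpu.h) and is illustrative only.

    #include "ggml.h"
    #include "ggml-cpu.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16 * 1024 * 1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // small input vector with negative and positive values
        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
        for (int i = 0; i < 4; i++) {
            ggml_set_f32_1d(x, i, (float) i - 1.5f);
        }

        // y = expm1(softplus(x))
        struct ggml_tensor * y = ggml_expm1(ctx, ggml_softplus(ctx, x));

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/ 1);

        for (int i = 0; i < 4; i++) {
            printf("y[%d] = %f\n", i, ggml_get_f32_1d(y, i));
        }

        ggml_free(ctx);
        return 0;
    }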

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 7 additions & 0 deletions
@@ -2010,6 +2010,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_delta_net_f32(params, tensor);
             } break;
+        case GGML_OP_DELTA_NET_RECURRENT:
+            {
+                ggml_compute_forward_delta_net_recurrent_f32(params, tensor);
+            } break;
         case GGML_OP_MAP_CUSTOM1:
             {
                 ggml_compute_forward_map_custom1(params, tensor);
@@ -2193,6 +2197,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_EXP:
+                case GGML_UNARY_OP_SOFTPLUS:
+                case GGML_UNARY_OP_EXPM1:
                     {
                         n_tasks = 1;
                     } break;
@@ -2288,6 +2294,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_POOL_1D:
         case GGML_OP_POOL_2D:
         case GGML_OP_POOL_2D_BACK:
+        case GGML_OP_DELTA_NET_RECURRENT:
             {
                 n_tasks = 1;
             } break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 202 additions & 0 deletions
@@ -9861,6 +9861,14 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_EXPM1:
+            {
+                ggml_compute_forward_expm1(params, dst);
+            } break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            {
+                ggml_compute_forward_softplus(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -10874,6 +10882,200 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     }
 }

+static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
+    GGML_LOG_INFO("\nggml-debug: %s (%ld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n",
+        name, token, data[0], data[1], data[2], data[3], data[4]);
+    double sum = 0.0;
+    for (unsigned int i = 0; i < size; i++) {
+        sum += data[i];
+    }
+    GGML_LOG_INFO("sum = %.10f\n", sum);
+}
+
+void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0]; // q_tokens
+    const struct ggml_tensor * src1 = dst->src[1]; // k_tokens
+    const struct ggml_tensor * src2 = dst->src[2]; // v_tokens
+    const struct ggml_tensor * src3 = dst->src[3]; // g_tokens_exp
+    const struct ggml_tensor * src4 = dst->src[4]; // beta_tokens
+    const struct ggml_tensor * src5 = dst->src[5]; // state
+    // src6, src7, src8 are nullptr in recurrent version
+
+    const int64_t H_v = (int64_t) dst->op_params[0];
+    const int64_t S_k = (int64_t) dst->op_params[1];
+    const int64_t S_v = (int64_t) dst->op_params[2];
+    const int64_t original_n_tokens = (int64_t) dst->op_params[3]; // Get original sequence length
+    const int64_t n_tokens = original_n_tokens; // Use the original sequence length
+    const int64_t n_seqs = src0->ne[3]; // q tensor has n_seqs in dim 3
+
+    // Add assertions to verify tensor dimensions
+    GGML_ASSERT(src0->ne[3] == n_seqs); // q tensor
+    GGML_ASSERT(src1->ne[3] == n_seqs); // k tensor
+    GGML_ASSERT(src2->ne[3] == n_seqs); // v tensor
+    GGML_ASSERT(src3->ne[3] == n_seqs); // g tensor
+    GGML_ASSERT(src4->ne[3] == n_seqs); // beta tensor
+    GGML_ASSERT(src5->ne[3] == n_seqs); // state tensor
+
+    float * dst_data = (float *) dst->data;
+    // Output is first part, state is second part
+    float * output = dst_data; // [S_v * H_v * n_tokens * n_seqs]
+    float * final_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v * S_v * H_v * n_seqs]
+
+    const int ith = params->ith;
+    // const int nth = params->nth;
+
+    // Clear output and new state section
+    if (ith == 0) {
+        memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
+    } else {
+        return; // only calculate on one thread
+    }
+
+    float * state_data = (float *) src5->data; // state is now src5
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src1));
+    GGML_ASSERT(ggml_is_contiguous(src2));
+    GGML_ASSERT(ggml_is_contiguous(src3));
+    GGML_ASSERT(ggml_is_contiguous(src4));
+    GGML_ASSERT(ggml_is_contiguous(src5));
+
+    const auto state_ptr = [state_data, src5] (int64_t seq, int64_t head, int64_t i, int64_t j) {
+        return state_data + (j * src5->nb[0] / sizeof(float)) + (i * src5->nb[1] / sizeof(float)) +
+               (head * src5->nb[2] / sizeof(float)) + (seq * src5->nb[3] / sizeof(float));
+    };
+
+    // Process each token sequentially across all sequences and heads (recurrent processing)
+    // Following the PyTorch reference: for each token i, process all sequences and heads
+    for (int64_t token = 0; token < n_tokens; token++) {
+        const auto q_t = [token, src0] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src0, token, i, head, seq); };
+        const auto k_t = [token, src1] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src1, token, i, head, seq); };
+        const auto v_t = [token, src2] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd(src2, token, i, head, seq); };
+        const auto g_exp_t = [token, src3] (int64_t seq, int64_t head) { return ggml_get_f32_nd(src3, token, 0, head, seq); };
+        const auto beta_t = [token, src4] (int64_t seq, int64_t head) { return ggml_get_f32_nd(src4, token, 0, head, seq); };
+
+        float * delta = (float *)malloc(S_v * H_v * n_seqs * sizeof(float));
+        float * kv_mem = (float *)malloc(S_v * H_v * n_seqs * sizeof(float));
+        float * attn_out_t = (float *)malloc(S_v * H_v * n_seqs * sizeof(float));
+
+        // Create temporary arrays for processing all sequences and heads at once
+        float * temp_state = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
+
+        // Initialize temp_state with current state values for all sequences and heads
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        temp_state[idx] = *(state_ptr(seq, head, i, j));
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
+
+        // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float g_exp = g_exp_t(seq, head);
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        temp_state[idx] *= g_exp;
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
+
+        // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t j = 0; j < S_v; j++) {
+                    kv_mem[seq * H_v * S_v + head * S_v + j] = 0.0f;
+                    for (int64_t i = 0; i < S_v; i++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        // This implements: (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+                        kv_mem[seq * H_v * S_v + head * S_v + j] += temp_state[state_idx] * k_t(seq, head, i);
+                    }
+                }
+            }
+        }
+        print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
+
+        // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float beta_val = beta_t(seq, head);
+                for (int64_t j = 0; j < S_v; j++) {
+                    delta[seq * H_v * S_v + head * S_v + j] =
+                        (v_t(seq, head, j) - kv_mem[seq * H_v * S_v + head * S_v + j]) * beta_val;
+                }
+            }
+        }
+        print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
+
+        // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        // k_t[i] * delta[j] (where delta is treated as column vector)
+                        temp_state[state_idx] += k_t(seq, head, i) * delta[seq * H_v * S_v + head * S_v + j];
+                    }
+                }
+            }
+        }
+        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
+
+        // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t j = 0; j < S_v; j++) {
+                    attn_out_t[seq * H_v * S_v + head * S_v + j] = 0.0f;
+                    for (int64_t i = 0; i < S_v; i++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        attn_out_t[seq * H_v * S_v + head * S_v + j] += temp_state[state_idx] * q_t(seq, head, i);
+                    }
+                }
+            }
+        }
+        print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
+
+        // Store the output for this token (for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t d = 0; d < S_v; d++) {
+                    int64_t output_idx = d + head * S_v + token * (S_v * H_v) + seq * (S_v * H_v * n_tokens);
+                    output[output_idx] = attn_out_t[seq * H_v * S_v + head * S_v + d];
+                }
+            }
+        }
+
+        // Update the working state for next token iteration (in the state tensor for all seqs and heads)
+        for (int64_t seq = 0; seq < n_seqs; seq++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                for (int64_t i = 0; i < S_v; i++) {
+                    for (int64_t j = 0; j < S_v; j++) {
+                        int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
+                        *(state_ptr(seq, head, i, j)) = temp_state[state_idx];

+                        // Store the final state for this head and sequence (for output)
+                        int64_t final_state_idx = i + j * S_v + head * (S_v * S_v) + seq * (S_v * S_v * H_v);
+                        final_state[final_state_idx] = temp_state[state_idx];
+                    }
+                }
+            }
+        }
+
+        free(temp_state);
+        free(delta);
+        free(kv_mem);
+        free(attn_out_t);
+    }
+}
+
 // ggml_compute_forward_rwkv_wkv7
 static void ggml_compute_forward_rwkv_wkv7_f32(
     const ggml_compute_params * params,
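To make the five numbered steps in the kernel above easier to follow, here is a simplified restatement of one recurrent step for a single sequence and head. This is a hypothetical helper, not part of the commit: it assumes a flat row-major state[i * S_v + j] layout and q/k/v vectors of length S_v, whereas the real op works over all sequences and heads and walks the tensor strides as shown in the diff.

    #include <stdint.h>

    void delta_net_recurrent_step(
            int64_t S_v,
            float * state,   // [S_v][S_v], updated in place
            const float * q, // [S_v]
            const float * k, // [S_v]
            const float * v, // [S_v]
            float g_exp,     // per-head decay factor (already exponentiated)
            float beta,      // per-head write strength
            float * out) {   // [S_v], attention output for this token
        // 1. decay the running state: state <- g_exp * state
        for (int64_t i = 0; i < S_v; i++) {
            for (int64_t j = 0; j < S_v; j++) {
                state[i * S_v + j] *= g_exp;
            }
        }
        for (int64_t j = 0; j < S_v; j++) {
            // 2. what the state currently predicts for key k: kv_mem = k^T * state
            float kv_mem = 0.0f;
            for (int64_t i = 0; i < S_v; i++) {
                kv_mem += state[i * S_v + j] * k[i];
            }
            // 3. delta rule: correct toward the new value, scaled by beta
            const float delta = (v[j] - kv_mem) * beta;
            // 4. rank-1 update: state <- state + k * delta^T
            for (int64_t i = 0; i < S_v; i++) {
                state[i * S_v + j] += k[i] * delta;
            }
            // 5. read out with the query against the updated state: out = q^T * state
            float o = 0.0f;
            for (int64_t i = 0; i < S_v; i++) {
                o += state[i * S_v + j] * q[i];
            }
            out[j] = o;
        }
    }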

ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params,
 void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_delta_net_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_delta_net_recurrent_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst);

ggml/src/ggml-cpu/unary-ops.cpp

Lines changed: 16 additions & 0 deletions
@@ -64,6 +64,14 @@ static inline float op_log(float x) {
     return logf(x);
 }

+static inline float op_expm1(float x) {
+    return expf(x) - 1.0f;
+}
+
+static inline float op_softplus(float x) {
+    return (x > 20.0f) ? x : logf(1.0f + expf(x));
+}
+
 template <float (*op)(float), typename src0_t, typename dst_t>
 static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
     constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -184,3 +192,11 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor *
 void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
     unary_op<op_log>(params, dst);
 }
+
+void ggml_compute_forward_expm1(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_expm1>(params, dst);
+}
+
+void ggml_compute_forward_softplus(const ggml_compute_params * params, ggml_tensor * dst) {
+    unary_op<op_softplus>(params, dst);
+}
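Worth noting: op_expm1 computes expf(x) - 1.0f directly, which loses precision for |x| close to zero, and op_softplus guards only against overflow for large positive x. A sketch of numerically safer variants using the standard C99 expm1f/log1pf functions (illustrative only, not what this commit uses):

    #include <math.h>

    // expm1f avoids the cancellation of expf(x) - 1.0f for small |x|
    static inline float op_expm1_stable(float x) {
        return expm1f(x);
    }

    // softplus(x) = log(1 + exp(x)); saturates to x for large x,
    // and log1pf keeps precision when expf(x) is small
    static inline float op_softplus_stable(float x) {
        return (x > 20.0f) ? x : log1pf(expf(x));
    }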

ggml/src/ggml-cpu/unary-ops.h

Lines changed: 2 additions & 0 deletions
@@ -22,6 +22,8 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_expm1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_softplus(const struct ggml_compute_params * params, struct ggml_tensor * dst);

 #ifdef __cplusplus
 }

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 8 additions & 0 deletions
@@ -2333,6 +2333,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_ELU:
             ggml_cuda_op_elu(ctx, dst);
             break;
+        case GGML_UNARY_OP_EXPM1:
+            ggml_cuda_op_expm1(ctx, dst);
+            break;
+        case GGML_UNARY_OP_SOFTPLUS:
+            ggml_cuda_op_softplus(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3314,6 +3320,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_EXP:
+        case GGML_UNARY_OP_EXPM1:
+        case GGML_UNARY_OP_SOFTPLUS:
         case GGML_UNARY_OP_ELU:
             return ggml_is_contiguous(op->src[0]);
         default:
