@@ -9861,6 +9861,14 @@ void ggml_compute_forward_unary(
98619861 {
98629862 ggml_compute_forward_exp (params, dst);
98639863 } break ;
9864+ case GGML_UNARY_OP_EXPM1:
9865+ {
9866+ ggml_compute_forward_expm1 (params, dst);
9867+ } break ;
9868+ case GGML_UNARY_OP_SOFTPLUS:
9869+ {
9870+ ggml_compute_forward_softplus (params, dst);
9871+ } break ;
98649872 default :
98659873 {
98669874 GGML_ABORT (" fatal error" );
@@ -10874,6 +10882,200 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1087410882 }
1087510883}
1087610884
10885+ static void print_debug_info (float * data, size_t size, const char * name, int64_t token) {
10886+ GGML_LOG_INFO (" \n ggml-debug: %s (%ld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n " ,
10887+ name, token, data[0 ], data[1 ], data[2 ], data[3 ], data[4 ]);
10888+ double sum = 0.0 ;
10889+ for (unsigned int i = 0 ; i < size; i++) {
10890+ sum += data[i];
10891+ }
10892+ GGML_LOG_INFO (" sum = %.10f\n " , sum);
10893+ }
10894+
10895+ void ggml_compute_forward_delta_net_recurrent_f32 (const ggml_compute_params * params, ggml_tensor * dst) {
10896+ const struct ggml_tensor * src0 = dst->src [0 ]; // q_tokens
10897+ const struct ggml_tensor * src1 = dst->src [1 ]; // k_tokens
10898+ const struct ggml_tensor * src2 = dst->src [2 ]; // v_tokens
10899+ const struct ggml_tensor * src3 = dst->src [3 ]; // g_tokens_exp
10900+ const struct ggml_tensor * src4 = dst->src [4 ]; // beta_tokens
10901+ const struct ggml_tensor * src5 = dst->src [5 ]; // state
10902+ // src6, src7, src8 are nullptr in recurrent version
10903+
10904+ const int64_t H_v = (int64_t ) dst->op_params [0 ];
10905+ const int64_t S_k = (int64_t ) dst->op_params [1 ];
10906+ const int64_t S_v = (int64_t ) dst->op_params [2 ];
10907+ const int64_t original_n_tokens = (int64_t ) dst->op_params [3 ]; // Get original sequence length
10908+ const int64_t n_tokens = original_n_tokens; // Use the original sequence length
10909+ const int64_t n_seqs = src0->ne [3 ]; // q tensor has n_seqs in dim 3
10910+
10911+ // Add assertions to verify tensor dimensions
10912+ GGML_ASSERT (src0->ne [3 ] == n_seqs); // q tensor
10913+ GGML_ASSERT (src1->ne [3 ] == n_seqs); // k tensor
10914+ GGML_ASSERT (src2->ne [3 ] == n_seqs); // v tensor
10915+ GGML_ASSERT (src3->ne [3 ] == n_seqs); // g tensor
10916+ GGML_ASSERT (src4->ne [3 ] == n_seqs); // beta tensor
10917+ GGML_ASSERT (src5->ne [3 ] == n_seqs); // state tensor
10918+
10919+ float * dst_data = (float *) dst->data ;
10920+ // Output is first part, state is second part
10921+ float * output = dst_data; // [S_v * H_v * n_tokens * n_seqs]
10922+ float * final_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v * S_v * H_v * n_seqs]
10923+
10924+ const int ith = params->ith ;
10925+ // const int nth = params->nth;
10926+
10927+ // Clear output and new state section
10928+ if (ith == 0 ) {
10929+ memset (output, 0 , ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof (float ));
10930+ } else {
10931+ return ; // only calculate on one thread
10932+ }
10933+
10934+ float * state_data = (float *) src5->data ; // state is now src5
10935+
10936+ GGML_ASSERT (ggml_is_contiguous (src0));
10937+ GGML_ASSERT (ggml_is_contiguous (src1));
10938+ GGML_ASSERT (ggml_is_contiguous (src2));
10939+ GGML_ASSERT (ggml_is_contiguous (src3));
10940+ GGML_ASSERT (ggml_is_contiguous (src4));
10941+ GGML_ASSERT (ggml_is_contiguous (src5));
10942+
10943+ const auto state_ptr = [state_data, src5] (int64_t seq, int64_t head, int64_t i, int64_t j) {
10944+ return state_data + (j * src5->nb [0 ] / sizeof (float )) + (i * src5->nb [1 ] / sizeof (float )) +
10945+ (head * src5->nb [2 ] / sizeof (float )) + (seq * src5->nb [3 ] / sizeof (float ));
10946+ };
10947+
10948+ // Process each token sequentially across all sequences and heads (recurrent processing)
10949+ // Following the PyTorch reference: for each token i, process all sequences and heads
10950+ for (int64_t token = 0 ; token < n_tokens; token++) {
10951+ const auto q_t = [token, src0] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd (src0, token, i, head, seq); };
10952+ const auto k_t = [token, src1] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd (src1, token, i, head, seq); };
10953+ const auto v_t = [token, src2] (int64_t seq, int64_t head, int64_t i) { return ggml_get_f32_nd (src2, token, i, head, seq); };
10954+ const auto g_exp_t = [token, src3] (int64_t seq, int64_t head) { return ggml_get_f32_nd (src3, token, 0 , head, seq); };
10955+ const auto beta_t = [token, src4] (int64_t seq, int64_t head) { return ggml_get_f32_nd (src4, token, 0 , head, seq); };
10956+
10957+ float * delta = (float *)malloc (S_v * H_v * n_seqs * sizeof (float ));
10958+ float * kv_mem = (float *)malloc (S_v * H_v * n_seqs * sizeof (float ));
10959+ float * attn_out_t = (float *)malloc (S_v * H_v * n_seqs * sizeof (float ));
10960+
10961+ // Create temporary arrays for processing all sequences and heads at once
10962+ float * temp_state = (float *) malloc (S_v * S_v * H_v * n_seqs * sizeof (float ));
10963+
10964+ // Initialize temp_state with current state values for all sequences and heads
10965+ for (int64_t seq = 0 ; seq < n_seqs; seq++) {
10966+ for (int64_t head = 0 ; head < H_v; head++) {
10967+ for (int64_t i = 0 ; i < S_v; i++) {
10968+ for (int64_t j = 0 ; j < S_v; j++) {
10969+ int64_t idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
10970+ temp_state[idx] = *(state_ptr (seq, head, i, j));
10971+ }
10972+ }
10973+ }
10974+ }
10975+ print_debug_info (temp_state, n_seqs * H_v * S_v * S_v, " temp_state_copy" , token);
10976+
10977+ // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
10978+ for (int64_t seq = 0 ; seq < n_seqs; seq++) {
10979+ for (int64_t head = 0 ; head < H_v; head++) {
10980+ float g_exp = g_exp_t (seq, head);
10981+ for (int64_t i = 0 ; i < S_v; i++) {
10982+ for (int64_t j = 0 ; j < S_v; j++) {
10983+ int64_t idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
10984+ temp_state[idx] *= g_exp;
10985+ }
10986+ }
10987+ }
10988+ }
10989+ print_debug_info (temp_state, n_seqs * H_v * S_v * S_v, " temp_state_times_g_t" , token);
10990+
10991+ // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
10992+ for (int64_t seq = 0 ; seq < n_seqs; seq++) {
10993+ for (int64_t head = 0 ; head < H_v; head++) {
10994+ for (int64_t j = 0 ; j < S_v; j++) {
10995+ kv_mem[seq * H_v * S_v + head * S_v + j] = 0 .0f ;
10996+ for (int64_t i = 0 ; i < S_v; i++) {
10997+ int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
10998+ // This implements: (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
10999+ kv_mem[seq * H_v * S_v + head * S_v + j] += temp_state[state_idx] * k_t (seq, head, i);
11000+ }
11001+ }
11002+ }
11003+ }
11004+ print_debug_info (kv_mem, n_seqs * H_v * S_v, " kv_mem" , token);
11005+
11006+ // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
11007+ for (int64_t seq = 0 ; seq < n_seqs; seq++) {
11008+ for (int64_t head = 0 ; head < H_v; head++) {
11009+ float beta_val = beta_t (seq, head);
11010+ for (int64_t j = 0 ; j < S_v; j++) {
11011+ delta[seq * H_v * S_v + head * S_v + j] =
11012+ (v_t (seq, head, j) - kv_mem[seq * H_v * S_v + head * S_v + j]) * beta_val;
11013+ }
11014+ }
11015+ }
11016+ print_debug_info (delta, n_seqs * H_v * S_v, " delta" , token);
11017+
11018+ // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
11019+ for (int64_t seq = 0 ; seq < n_seqs; seq++) {
11020+ for (int64_t head = 0 ; head < H_v; head++) {
11021+ for (int64_t i = 0 ; i < S_v; i++) {
11022+ for (int64_t j = 0 ; j < S_v; j++) {
11023+ int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
11024+ // k_t[i] * delta[j] (where delta is treated as column vector)
11025+ temp_state[state_idx] += k_t (seq, head, i) * delta[seq * H_v * S_v + head * S_v + j];
11026+ }
11027+ }
11028+ }
11029+ }
11030+ print_debug_info (temp_state, n_seqs * H_v * S_v * S_v, " temp_state" , token);
11031+
11032+ // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
11033+ for (int64_t seq = 0 ; seq < n_seqs; seq++) {
11034+ for (int64_t head = 0 ; head < H_v; head++) {
11035+ for (int64_t j = 0 ; j < S_v; j++) {
11036+ attn_out_t [seq * H_v * S_v + head * S_v + j] = 0 .0f ;
11037+ for (int64_t i = 0 ; i < S_v; i++) {
11038+ int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
11039+ attn_out_t [seq * H_v * S_v + head * S_v + j] += temp_state[state_idx] * q_t (seq, head, i);
11040+ }
11041+ }
11042+ }
11043+ }
11044+ print_debug_info (attn_out_t , n_seqs * S_v * H_v, " attn_out_t" , token);
11045+
11046+ // Store the output for this token (for all seqs and heads)
11047+ for (int64_t seq = 0 ; seq < n_seqs; seq++) {
11048+ for (int64_t head = 0 ; head < H_v; head++) {
11049+ for (int64_t d = 0 ; d < S_v; d++) {
11050+ int64_t output_idx = d + head * S_v + token * (S_v * H_v) + seq * (S_v * H_v * n_tokens);
11051+ output[output_idx] = attn_out_t [seq * H_v * S_v + head * S_v + d];
11052+ }
11053+ }
11054+ }
11055+
11056+ // Update the working state for next token iteration (in the state tensor for all seqs and heads)
11057+ for (int64_t seq = 0 ; seq < n_seqs; seq++) {
11058+ for (int64_t head = 0 ; head < H_v; head++) {
11059+ for (int64_t i = 0 ; i < S_v; i++) {
11060+ for (int64_t j = 0 ; j < S_v; j++) {
11061+ int64_t state_idx = seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
11062+ *(state_ptr (seq, head, i, j)) = temp_state[state_idx];
11063+
11064+ // Store the final state for this head and sequence (for output)
11065+ int64_t final_state_idx = i + j * S_v + head * (S_v * S_v) + seq * (S_v * S_v * H_v);
11066+ final_state[final_state_idx] = temp_state[state_idx];
11067+ }
11068+ }
11069+ }
11070+ }
11071+
11072+ free (temp_state);
11073+ free (delta);
11074+ free (kv_mem);
11075+ free (attn_out_t );
11076+ }
11077+ }
11078+
1087711079// ggml_compute_forward_rwkv_wkv7
1087811080static void ggml_compute_forward_rwkv_wkv7_f32 (
1087911081 const ggml_compute_params * params,
0 commit comments