@@ -10865,7 +10865,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     float * dst_data = (float *) dst->data;
     // Following GLA pattern: output is first part, state is second part
     float * output = dst_data; // [S_v * H_v, n_tokens, 1, n_seqs] - only real sequence length, not padded
-    float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v * H_v, S_v * n_seqs, 1, 1]
+    float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v, S_v * H_v, 1, n_seqs]

     const int ith = params->ith;
     // const int nth = params->nth; // nth is unused
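Note on the corrected shape comment: viewed as [S_v, S_v * H_v, 1, n_seqs], the state part of dst is, per sequence, H_v row-major S_v x S_v matrices stored back to back. A minimal sketch of an indexing helper consistent with that layout; new_state_at is hypothetical and not part of the patch:

// Hypothetical helper: matches the flat indexing the init loop below uses.
static inline float * new_state_at(float * new_state, int64_t S_v, int64_t H_v,
                                   int64_t seq, int64_t head, int64_t i, int64_t j) {
    return new_state + seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j;
}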
@@ -10884,6 +10884,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml

     float * state_data = (float *) src4->data;

+    // Init new state with the initial state (will probably be zeroes)
+    for (int64_t seq = 0; seq < n_seqs; seq++) {
+        for (int64_t head = 0; head < H_v; head++) {
+            for (int64_t i = 0; i < S_v; i++) {
+                for (int64_t j = 0; j < S_v; j++) {
+                    new_state[seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j] =
+                        state_data[seq * src4->nb[3] / sizeof(float) + (head * S_v + i) * src4->nb[1] / sizeof(float) + j * src4->nb[0] / sizeof(float)];
+                }
+            }
+        }
+    }
+    print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "init_state", -1);
+
+
     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
     GGML_ASSERT(ggml_is_contiguous(src2));
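The init loop above addresses src4 through ggml's per-dimension byte strides (tensor->nb), dividing by sizeof(float) to get element offsets. A sketch of the same addressing done directly in bytes, assuming src4 holds F32 data; src4_state_get is illustrative only:

// Illustrative only: nb[d] is the byte stride of dimension d, and dim 1 is
// indexed by the fused row (head * S_v + i), mirroring the copy above.
static inline float src4_state_get(const struct ggml_tensor * src4, int64_t S_v,
                                   int64_t seq, int64_t head, int64_t i, int64_t j) {
    const char * base = (const char *) src4->data;
    return *(const float *) (base + seq * src4->nb[3] + (head * S_v + i) * src4->nb[1] + j * src4->nb[0]);
}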
@@ -10896,12 +10910,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml

     // int64_t total_params = n_seqs * H_v * num_chunks;
     // int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
-
-    // Create helper lambda for state tensor access
-    const auto state_ptr = [state_data, src4](int64_t seq, int64_t head, int64_t i, int64_t j) {
-        return state_data + (j * src4->nb[0] / sizeof(float)) + (i * src4->nb[1] / sizeof(float)) +
-               (head * src4->nb[2] / sizeof(float)) + (seq * src4->nb[3] / sizeof(float));
-    };

     float * attn = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
     float * value = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
@@ -11048,15 +11056,15 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml

         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         // k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
-        delta_matmul_state_chunk_f32(k_cumdecay, state_data, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
+        delta_matmul_state_chunk_f32(k_cumdecay, new_state, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
         print_debug_info(v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);

         // v_new = v_i - v_prime
         delta_tensor_subtract_chunk_f32(value, v_prime, v_new, chunk_size * S_v, n_seqs, H_v);
         print_debug_info(v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);

         // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
-        delta_matmul_state_chunk_f32(q_g_exp, state_data, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
+        delta_matmul_state_chunk_f32(q_g_exp, new_state, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
         print_debug_info(attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);

         // core_attn_out[:, :, i] = attn_inter + attn @ v_new
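Both call sites now read the running state from new_state rather than the read-only src4 buffer, so each chunk sees the state written by the previous one. From the shape comments above, delta_matmul_state_chunk_f32 is taken to be a batched per-(seq, head) matmul of a [chunk_size, S_v] chunk against an [S_v, S_v] state; a minimal sketch of that assumed contract, not the actual implementation:

// Sketch under the assumption stated above: one plain row-major matmul per
// (seq, head) batch, with all three buffers laid out batch-contiguously.
static void delta_matmul_state_sketch(const float * a, const float * s, float * out,
                                      int64_t rows, int64_t inner, int64_t cols,
                                      int64_t n_seqs, int64_t H_v) {
    for (int64_t b = 0; b < n_seqs * H_v; b++) {
        const float * ab = a   + b * rows * inner;
        const float * sb = s   + b * inner * cols;
        float       * ob = out + b * rows * cols;
        for (int64_t r = 0; r < rows; r++) {
            for (int64_t c = 0; c < cols; c++) {
                float acc = 0.0f;
                for (int64_t k = 0; k < inner; k++) {
                    acc += ab[r * inner + k] * sb[k * cols + c];
                }
                ob[r * cols + c] = acc;
            }
        }
    }
}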
@@ -11203,19 +11211,16 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }
         print_debug_info(output, S_v * H_v * n_tokens * n_seqs, "output", chunk);
-
-        // Update state tensor (all sequences and heads)
+        GGML_LOG_INFO("\nFull output tensor:\n\n");
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
-                float * temp_state_ptr = temp_state + seq * (S_v * S_v * H_v) + head * (S_v * S_v);
-
-                for (int64_t i = 0; i < S_v; i++) {
-                    for (int64_t j = 0; j < S_v; j++) {
-                        int64_t state_idx = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
-                        new_state[state_idx] = temp_state_ptr[i * S_v + j];
-                        *(state_ptr(seq, head, i, j)) = temp_state_ptr[i * S_v + j];
+                GGML_LOG_INFO("\n[ ");
+                for (int64_t i = 0; i < n_tokens; i++) {
+                    for (int64_t d = 0; d < S_v; d++) {
+                        GGML_LOG_INFO("%.4f ", output[seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d]);
                     }
                 }
+                GGML_LOG_INFO("]");
             }
         }
         print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
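With new_state serving as both the running recurrent state and the second part of dst (see the first hunk), the old explicit write-back through state_ptr is no longer needed: after the chunk loop, the final state is expected to already sit in dst right after the output part. A sketch of how a caller could locate it; the variable names are illustrative:

// Illustrative: dst packs the output first, then the final recurrent state.
float * out_part   = (float *) dst->data;
float * state_part = out_part + S_v * H_v * n_tokens * n_seqs; // [S_v, S_v * H_v, 1, n_seqs]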