@@ -10964,17 +10964,13 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1096410964 };
1096510965
1096610966 // Allocate per-chunk arrays containing all sequences and heads
10967- float * temp_state = (float *) malloc (S_v * S_v * H_v * n_seqs * sizeof (float ));
1096810967 float * core_attn_out = (float *) malloc (chunk_size * S_v * H_v * n_seqs * sizeof (float ));
1096910968 float * attn_inter = (float *) malloc (chunk_size * S_v * H_v * n_seqs * sizeof (float ));
1097010969 float * v_new = (float *) malloc (chunk_size * S_v * H_v * n_seqs * sizeof (float ));
1097110970 float * v_prime = (float *) malloc (chunk_size * S_v * H_v * n_seqs * sizeof (float ));
1097210971 float * g_diff_exp = (float *) malloc (chunk_size * H_v * n_seqs * sizeof (float ));
1097310972 float * g_last = (float *) malloc (H_v * n_seqs * sizeof (float ));
1097410973
10975- // Initialize temp_state with zeros for all sequences and heads (state should be empty initially)
10976- memset (temp_state, 0 , S_v * S_v * H_v * n_seqs * sizeof (float ));
10977-
1097810974 // Create temporary arrays for entire chunk
1097910975 float * q_chunk_data = (float *) malloc (chunk_size * S_v * H_v * n_seqs * sizeof (float ));
1098010976 float * k_chunk_data = (float *) malloc (chunk_size * S_v * H_v * n_seqs * sizeof (float ));
@@ -11177,14 +11173,14 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1117711173 for (int64_t head = 0 ; head < H_v; head++) {
1117811174 for (int i = 0 ; i < S_v; i++) {
1117911175 for (int j = 0 ; j < S_v; j++) {
11180- temp_state [(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] =
11176+ new_state [(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] =
1118111177 state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * g_last[seq * H_v + head] +
1118211178 kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j];
1118311179 }
1118411180 }
1118511181 }
1118611182 }
11187- print_debug_info (temp_state , S_v * S_v * H_v * n_seqs, " temp_state " , chunk);
11183+ print_debug_info (new_state , S_v * S_v * H_v * n_seqs, " state_end_chunk " , chunk);
1118811184
1118911185 // Free temporary memory
1119011186 free (q_chunk_data);
@@ -11225,7 +11221,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1122511221 // }
1122611222 print_debug_info (new_state, S_v * S_v * H_v * n_seqs, " new_state" , chunk);
1122711223
11228- free (temp_state);
1122911224 free (core_attn_out);
1123011225 free (attn_inter);
1123111226 free (v_new);
0 commit comments