Skip to content

Commit 9de7244

Browse files
committed
Fix memory corruption
1 parent 75586ea commit 9de7244

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10865,7 +10865,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1086510865
float * dst_data = (float *) dst->data;
1086610866
// Following the GLA pattern: dst holds the output first, followed by the new state
1086710867
float * output = dst_data; // [S_v * H_v, n_tokens, 1, n_seqs] - only real sequence length, not padded
10868-
float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v * H_v, S_v * n_seqs, 1, 1]
10868+
float * new_state = dst_data + (S_v * H_v * n_tokens * n_seqs); // [S_v, S_v * H_v, 1, n_seqs]
1086910869

1087010870
const int ith = params->ith;
1087110871
// const int nth = params->nth; // nth is unused
@@ -10884,6 +10884,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1088410884

1088510885
float * state_data = (float *) src4->data;
1088610886

10887+
// Initialize new_state with a copy of the initial state from src4 (will probably be zeroes)
10888+
for (int64_t seq = 0; seq < n_seqs; seq++) {
10889+
for (int64_t head = 0; head < H_v; head++) {
10890+
for (int64_t i = 0; i < S_v; i++) {
10891+
for (int64_t j = 0; j < S_v; j++) {
10892+
new_state[seq * (S_v * S_v * H_v) + head * (S_v * S_v) + i * S_v + j] =
10893+
state_data[seq * src4->nb[3] / sizeof(float) + (head * S_v + i) * src4->nb[1] / sizeof(float) + j * src4->nb[0] / sizeof(float)];
10894+
}
10895+
}
10896+
}
10897+
}
10898+
print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "init_state", -1);
10899+
10900+
1088710901
GGML_ASSERT(ggml_is_contiguous(src0));
1088810902
GGML_ASSERT(ggml_is_contiguous(src1));
1088910903
GGML_ASSERT(ggml_is_contiguous(src2));
@@ -10896,12 +10910,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1089610910

1089710911
// int64_t total_params = n_seqs * H_v * num_chunks;
1089810912
// int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
10899-
10900-
// Create helper lambda for state tensor access
10901-
const auto state_ptr = [state_data, src4] (int64_t seq, int64_t head, int64_t i, int64_t j) {
10902-
return state_data + (j * src4->nb[0] / sizeof(float)) + (i * src4->nb[1] / sizeof(float)) +
10903-
(head * src4->nb[2] / sizeof(float)) + (seq * src4->nb[3] / sizeof(float));
10904-
};
1090510913

1090610914
float * attn = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
1090710915
float * value = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
@@ -11048,15 +11056,15 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1104811056

1104911057
// v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
1105011058
// k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
11051-
delta_matmul_state_chunk_f32(k_cumdecay, state_data, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
11059+
delta_matmul_state_chunk_f32(k_cumdecay, new_state, v_prime, chunk_size, S_v, S_v, n_seqs, H_v);
1105211060
print_debug_info(v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);
1105311061

1105411062
// v_new = v_i - v_prime
1105511063
delta_tensor_subtract_chunk_f32(value, v_prime, v_new, chunk_size * S_v, n_seqs, H_v);
1105611064
print_debug_info(v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);
1105711065

1105811066
// attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
11059-
delta_matmul_state_chunk_f32(q_g_exp, state_data, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
11067+
delta_matmul_state_chunk_f32(q_g_exp, new_state, attn_inter, chunk_size, S_v, S_v, n_seqs, H_v);
1106011068
print_debug_info(attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);
1106111069

1106211070
// core_attn_out[:, :, i] = attn_inter + attn @ v_new
@@ -11203,19 +11211,16 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1120311211
}
1120411212
}
1120511213
print_debug_info(output, S_v * H_v * n_tokens * n_seqs, "output", chunk);
11206-
11207-
// Update state tensor (all sequences and heads)
11214+
GGML_LOG_INFO("\nFull output tensor: \n\n");
1120811215
for (int64_t seq = 0; seq < n_seqs; seq++) {
1120911216
for (int64_t head = 0; head < H_v; head++) {
11210-
float * temp_state_ptr = temp_state + seq * (S_v * S_v * H_v) + head * (S_v * S_v);
11211-
11212-
for (int64_t i = 0; i < S_v; i++) {
11213-
for (int64_t j = 0; j < S_v; j++) {
11214-
int64_t state_idx = seq * S_v * S_v * H_v + head * S_v * S_v + i * S_v + j;
11215-
new_state[state_idx] = temp_state_ptr[i * S_v + j];
11216-
*(state_ptr(seq, head, i, j)) = temp_state_ptr[i * S_v + j];
11217+
GGML_LOG_INFO("\n[ ");
11218+
for (int64_t i = 0; i < n_tokens; i++) {
11219+
for (int64_t d = 0; d < S_v; d++) {
11220+
GGML_LOG_INFO("%.4f ", output[seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d]);
1121711221
}
1121811222
}
11223+
GGML_LOG_INFO(" ]");
1121911224
}
1122011225
}
1122111226
print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);

0 commit comments

Comments
 (0)