Skip to content

Commit 875de2b

Browse files
committed
e steps forward, pi steps back
1 parent a60458e commit 875de2b

File tree

1 file changed

+2
-7
lines changed

1 file changed

+2
-7
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -10964,17 +10964,13 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1096410964
};
1096510965

1096610966
// Allocate per-chunk arrays containing all sequences and heads
10967-
float * temp_state = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));
1096810967
float * core_attn_out = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
1096910968
float * attn_inter = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
1097010969
float * v_new = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
1097110970
float * v_prime = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
1097210971
float * g_diff_exp = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
1097310972
float * g_last = (float *) malloc(H_v * n_seqs * sizeof(float));
1097410973

10975-
// Initialize temp_state with zeros for all sequences and heads (state should be empty initially)
10976-
memset(temp_state, 0, S_v * S_v * H_v * n_seqs * sizeof(float));
10977-
1097810974
// Create temporary arrays for entire chunk
1097910975
float * q_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
1098010976
float * k_chunk_data = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
@@ -11177,14 +11173,14 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1117711173
for (int64_t head = 0; head < H_v; head++) {
1117811174
for (int i = 0; i < S_v; i++) {
1117911175
for (int j = 0; j < S_v; j++) {
11180-
temp_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] =
11176+
new_state[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] =
1118111177
state_data[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j] * g_last[seq * H_v + head] +
1118211178
kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + S_v * i + j];
1118311179
}
1118411180
}
1118511181
}
1118611182
}
11187-
print_debug_info(temp_state, S_v * S_v * H_v * n_seqs, "temp_state", chunk);
11183+
print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "state_end_chunk", chunk);
1118811184

1118911185
// Free temporary memory
1119011186
free(q_chunk_data);
@@ -11225,7 +11221,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1122511221
// }
1122611222
print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
1122711223

11228-
free(temp_state);
1122911224
free(core_attn_out);
1123011225
free(attn_inter);
1123111226
free(v_new);

0 commit comments

Comments (0)