Skip to content

Commit 477c161

Browse files
committed
Parallelize delta_net
1 parent 4ef6f33 commit 477c161

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10694,15 +10694,12 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1069410694
float * new_state = dst_data + (S_v * H_v * n_tokens); // [S_v * H_v, S_v * n_seqs, 1, 1]
1069510695

1069610696
const int ith = params->ith;
10697-
// const int nth = params->nth; // nth is unused
10698-
10699-
// TODO: parallelize across heads and sequences
10700-
if (ith != 0) {
10701-
return;
10702-
}
10697+
const int nth = params->nth; // nth is unused
1070310698

1070410699
// Clear output and new state section
10705-
memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
10700+
if (ith == 0) {
10701+
memset(output, 0, ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof(float));
10702+
}
1070610703

1070710704
// Calculate chunk size
1070810705
const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
@@ -10730,9 +10727,16 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1073010727
GGML_ASSERT(ggml_is_contiguous(src7));
1073110728
GGML_ASSERT(ggml_is_contiguous(src8));
1073210729

10730+
int64_t total_params = n_seqs * H_v * num_chunks;
10731+
int64_t per_thread = total_params / nth;
10732+
1073310733
for (int64_t seq = 0; seq < n_seqs; seq++) {
1073410734
for (int64_t head = 0; head < H_v; head++) {
1073510735
for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
10736+
int64_t tidx = seq * (H_v * num_chunks) + head * num_chunks + chunk;
10737+
if (tidx < ith * per_thread || tidx >= (ith + 1) * per_thread) {
10738+
continue; // not our thread;
10739+
}
1073610740
float * attn_data_for_chs = attn_data + (src8->nb[3] / sizeof(float)) * seq + (src8->nb[2] / sizeof(float)) * (chunk + head * num_chunks);
1073710741
float * value_chunk = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));
1073810742
float * k_cumdecay = (float *) malloc(S_v * chunk_size * H_v * n_seqs * sizeof(float));

0 commit comments

Comments
 (0)