@@ -10694,15 +10694,12 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1069410694 float * new_state = dst_data + (S_v * H_v * n_tokens); // [S_v * H_v, S_v * n_seqs, 1, 1]
1069510695
1069610696 const int ith = params->ith ;
10697- // const int nth = params->nth; // nth is unused
10698-
10699- // TODO: parallelize across heads and sequences
10700- if (ith != 0 ) {
10701- return ;
10702- }
10697+ const int nth = params->nth ; // used below to partition work items across threads
1070310698
1070410699 // Clear output and new state section
10705- memset (output, 0 , ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof (float ));
10700+ // NOTE(review): only thread 0 clears the buffers, but the other threads fall straight
10700+ // through to the compute loop — confirm a barrier synchronizes all threads after this
10700+ // memset before any thread accumulates into `output`, otherwise this is a data race.
10700+ if (ith == 0 ) {
10701+ memset (output, 0 , ((S_v * H_v * n_tokens * n_seqs) + (S_v * S_v * H_v * n_seqs)) * sizeof (float ));
10702+ }
1070610703
1070710704 // Calculate chunk size
1070810705 const int64_t chunk_size = GGML_DELTA_NET_CHUNK;
@@ -10730,9 +10727,16 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1073010727 GGML_ASSERT (ggml_is_contiguous (src7));
1073110728 GGML_ASSERT (ggml_is_contiguous (src8));
1073210729
10730+ int64_t total_params = n_seqs * H_v * num_chunks;
10731+ // NOTE(review): integer division truncates — when nth does not evenly divide
10731+ // total_params, the trailing (total_params % nth) work items fall outside every
10731+ // thread's window and are silently skipped. Give the remainder to the last thread
10731+ // (or use tidx % nth == ith round-robin assignment) to cover all items.
10731+ int64_t per_thread = total_params / nth;
10732+
1073310733 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
1073410734 for (int64_t head = 0 ; head < H_v; head++) {
1073510735 for (int64_t chunk = 0 ; chunk < num_chunks; chunk++) {
10736+ int64_t tidx = seq * (H_v * num_chunks) + head * num_chunks + chunk;
10737+ if (tidx < ith * per_thread || tidx >= (ith + 1 ) * per_thread) {
10738+ continue ; // work item belongs to another thread
10739+ }
1073610740 float * attn_data_for_chs = attn_data + (src8->nb [3 ] / sizeof (float )) * seq + (src8->nb [2 ] / sizeof (float )) * (chunk + head * num_chunks);
1073710741 float * value_chunk = (float *) malloc (S_v * chunk_size * H_v * n_seqs * sizeof (float ));
1073810742 float * k_cumdecay = (float *) malloc (S_v * chunk_size * H_v * n_seqs * sizeof (float ));
0 commit comments