@@ -10530,62 +10530,73 @@ static void delta_compute_k_beta_key_t_f32(const float * k_beta, const float * k
1053010530}
1053110531
1053210532// Helper function to apply triangular updates to entire chunk (all sequences and heads)
10533- static void delta_apply_triangular_updates_chunk_f32 (float * attn, const int64_t chunk_size,
10534- const int64_t n_seqs, const int64_t H_v) {
10533+ static void delta_apply_triangular_updates_chunk_f32 (float * attn,
10534+ const int64_t chunk_size,
10535+ const int64_t n_seqs,
10536+ const int64_t H_v,
10537+ int num_chunks) {
1053510538 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
10536- for (int64_t head = 0 ; head < H_v; head++) {
10537- float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
10538-
10539- // Apply triangular updates following the Python reference exactly:
10540- // for i in range(1, chunk_size):
10541- // row = attn[..., i, :i].clone()
10542- // sub = attn[..., :i, :i].clone()
10543- // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
10544- for (int64_t i = 1 ; i < chunk_size; i++) {
10545- // Create temporary storage for row and sub to avoid modifying during computation
10546- float * row = (float *) malloc (i * sizeof (float ));
10547- float * sub = (float *) malloc (i * i * sizeof (float ));
10548-
10549- // Copy row = attn[..., i, :i]
10550- for (int64_t j = 0 ; j < i; j++) {
10551- row[j] = attn_ptr[i * chunk_size + j];
10552- }
10553-
10554- // Copy sub = attn[..., :i, :i]
10555- for (int64_t k = 0 ; k < i; k++) {
10539+ for (int i = 0 ; i < num_chunks; i++) {
10540+ for (int64_t head = 0 ; head < H_v; head++) {
10541+ float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + (head * num_chunks + i) * (chunk_size * chunk_size);
10542+
10543+ // Apply triangular updates following the Python reference exactly:
10544+ // for i in range(1, chunk_size):
10545+ // row = attn[..., i, :i].clone()
10546+ // sub = attn[..., :i, :i].clone()
10547+ // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
10548+ for (int64_t i = 1 ; i < chunk_size; i++) {
10549+ // Create temporary storage for row and sub to avoid modifying during computation
10550+ float * row = (float *) malloc (i * sizeof (float ));
10551+ float * sub = (float *) malloc (i * i * sizeof (float ));
10552+
10553+ // Copy row = attn[..., i, :i]
1055610554 for (int64_t j = 0 ; j < i; j++) {
10557- sub[k * i + j] = attn_ptr[k * chunk_size + j];
10555+ row[ j] = attn_ptr[i * chunk_size + j];
1055810556 }
10559- }
10560-
10561- // Compute updates for each j in :i
10562- for (int64_t j = 0 ; j < i; j++) {
10563- // Compute (row.unsqueeze(-1) * sub).sum(-2)
10564- float sum_val = 0 .0f ;
10557+
10558+ // Copy sub = attn[..., :i, :i]
1056510559 for (int64_t k = 0 ; k < i; k++) {
10566- sum_val += row[k] * sub[k * i + j];
10560+ for (int64_t j = 0 ; j < i; j++) {
10561+ sub[k * i + j] = attn_ptr[k * chunk_size + j];
10562+ }
10563+ }
10564+
10565+ // Compute updates for each j in :i
10566+ for (int64_t j = 0 ; j < i; j++) {
10567+ // Compute (row.unsqueeze(-1) * sub).sum(-2)
10568+ float sum_val = 0 .0f ;
10569+ for (int64_t k = 0 ; k < i; k++) {
10570+ sum_val += row[k] * sub[k * i + j];
10571+ }
10572+
10573+ // Update: attn[..., i, j] = row[j] + sum_val
10574+ attn_ptr[i * chunk_size + j] = row[j] + sum_val;
1056710575 }
10568-
10569- // Update: attn[..., i, j] = row[j] + sum_val
10570- attn_ptr[i * chunk_size + j] = row[j] + sum_val ;
10576+
10577+ free ( row);
10578+ free (sub) ;
1057110579 }
10572-
10573- free (row);
10574- free (sub);
1057510580 }
1057610581 }
1057710582 }
1057810583}
1057910584
1058010585// Helper function to add identity matrix to entire chunk (all sequences and heads)
10581- static void delta_add_identity_matrix_chunk_f32 (float * matrix, const int64_t chunk_size,
10582- const int64_t n_seqs, const int64_t H_v) {
10586+ static void delta_add_identity_matrix_chunk_f32 (float * matrix,
10587+ const int64_t chunk_size,
10588+ const int64_t n_seqs,
10589+ const int64_t H_v,
10590+ int num_chunks) {
1058310591 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
10584- for (int64_t head = 0 ; head < H_v; head++) {
10585- float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
10586- // Add identity matrix directly
10587- for (int64_t i = 0 ; i < chunk_size; i++) {
10588- matrix_ptr[i * chunk_size + i] += 1 .0f ;
10592+ for (int i = 0 ; i < num_chunks; i++) {
10593+ for (int64_t head = 0 ; head < H_v; head++) {
10594+ float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) +
10595+ (head * num_chunks + i) * (chunk_size * chunk_size);
10596+ // Add identity matrix directly
10597+ for (int64_t i = 0 ; i < chunk_size; i++) {
10598+ matrix_ptr[i * chunk_size + i] += 1 .0f ;
10599+ }
1058910600 }
1059010601 }
1059110602 }
@@ -10617,15 +10628,19 @@ static void delta_compute_value_f32(const float * attn,
1061710628 const int64_t chunk_size,
1061810629 const int64_t v_head_dim,
1061910630 const int64_t n_heads,
10620- const int64_t n_seqs) {
10631+ const int64_t n_seqs,
10632+ int num_chunks) {
1062110633 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
10622- for (int64_t head = 0 ; head < n_heads; head++) {
10623- delta_matmul_f32 (
10624- attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * head,
10625- v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
10626- value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
10627- chunk_size, v_head_dim, chunk_size);
10634+ for (int i = 0 ; i < num_chunks; i++) {
10635+ for (int64_t head = 0 ; head < n_heads; head++) {
10636+ delta_matmul_f32 (
10637+ attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * (head * num_chunks + i),
10638+ v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
10639+ value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
10640+ chunk_size, v_head_dim, chunk_size);
10641+ }
1062810642 }
10643+
1062910644 }
1063010645}
1063110646
@@ -10913,11 +10928,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1091310928 // int64_t total_params = n_seqs * H_v * num_chunks;
1091410929 // int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
1091510930
10916- float * attn = (float *) malloc (chunk_size * chunk_size * H_v * n_seqs * sizeof (float ));
10917- float * value = (float *) malloc (chunk_size * S_v * H_v * n_seqs * sizeof (float ));
10918- float * k_cumdecay = (float *) malloc (chunk_size * S_v * H_v * n_seqs * sizeof (float ));
10931+ float * attn = (float *) malloc (chunk_size * chunk_size * H_v * num_chunks * n_seqs * sizeof (float ));
10932+ float * value = (float *) malloc (chunk_size * S_v * H_v * num_chunks * n_seqs * sizeof (float ));
10933+ float * k_cumdecay = (float *) malloc (chunk_size * S_v * H_v * num_chunks * n_seqs * sizeof (float ));
1091910934 bool * mask = (bool *) malloc (chunk_size * chunk_size * sizeof (bool ));
10920- float * g = (float *) malloc (chunk_size * H_v * n_seqs * sizeof (float ));
10935+ float * g = (float *) malloc (chunk_size * H_v * num_chunks * n_seqs * sizeof (float ));
1092110936
1092210937 // Create upper triangular mask for causal attention (exclude diagonal)
1092310938 for (int64_t i = 0 ; i < chunk_size; i++) {
@@ -10934,18 +10949,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1093410949 // This corresponds to the reference implementation:
1093510950 // for i in range(1, chunk_size): attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
1093610951 // attn = attn + torch.eye(chunk_size)
10937- delta_apply_triangular_updates_chunk_f32 (attn, chunk_size, n_seqs, H_v);
10938- delta_add_identity_matrix_chunk_f32 (attn, chunk_size, n_seqs, H_v);
10952+ delta_apply_triangular_updates_chunk_f32 (attn, chunk_size, n_seqs, H_v, num_chunks );
10953+ delta_add_identity_matrix_chunk_f32 (attn, chunk_size, n_seqs, H_v, num_chunks );
1093910954
1094010955 // Compute value = attn @ v_beta
10941- delta_compute_value_f32 (attn, (const float *) src6->data , value, chunk_size, S_v, H_v, n_seqs);
10956+ delta_compute_value_f32 (attn, (const float *) src6->data , value, chunk_size, S_v, H_v, n_seqs, num_chunks );
1094210957 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
10943- for (int64_t head = 0 ; head < H_v; head++) {
10958+ for (int i = 0 ; i < num_chunks; i++) {
10959+ for (int64_t head = 0 ; head < H_v; head++) {
1094410960 delta_compute_k_cumdecay_f32 (attn + (chunk_size * chunk_size * H_v) * seq + (chunk_size * chunk_size) * head,
1094510961 (float *) src7->data + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
1094610962 g + (chunk_size * H_v) * seq + chunk_size * head,
1094710963 k_cumdecay + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
1094810964 chunk_size, S_v);
10965+ }
1094910966 }
1095010967 }
1095110968 print_debug_info (k_cumdecay, chunk_size * S_v * H_v * n_seqs, " k_cumdecay" , -1 );
@@ -10996,7 +11013,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1099611013
1099711014 // Compute q_g_exp = q * g.exp()
1099811015 for (int64_t i = 0 ; i < chunk_size; i++) {
10999- for (int64_t d = 0 ; d < S_v; d++) {
11016+ for (int64_t d = 0 ; d < S_v; d++) {
1100011017 q_g_exp_ptr[i * S_v + d] = q_ptr[i * S_v + d] * expf (g_ptr[i]);
1100111018 }
1100211019 }
@@ -11196,8 +11213,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
1119611213 for (int64_t head = 0 ; head < H_v; head++) {
1119711214 float * core_attn_out_ptr = core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
1119811215
11216+ // Compute number of tokens for this chunk (chunk_size unless this is the last chunk)
11217+ int64_t n_tokens_chunk = chunk == num_chunks - 1 ? n_tokens % chunk_size : chunk_size;
11218+
1119911219 // Store output for this chunk
11200- for (int64_t i = 0 ; i < n_tokens ; i++) {
11220+ for (int64_t i = 0 ; i < n_tokens_chunk ; i++) {
1120111221 for (int64_t d = 0 ; d < S_v; d++) {
1120211222 int64_t output_idx =
1120311223 seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d;
0 commit comments