Commit 16b3f9c

Valgrind debugging session / multi-chunk support

1 parent 5417f32 commit 16b3f9c

1 file changed: ggml/src/ggml-cpu/ops.cpp (80 additions, 60 deletions)

@@ -10530,62 +10530,73 @@ static void delta_compute_k_beta_key_t_f32(const float * k_beta, const float * k
 }
 
 // Helper function to apply triangular updates to entire chunk (all sequences and heads)
-static void delta_apply_triangular_updates_chunk_f32(float * attn, const int64_t chunk_size,
-                                                     const int64_t n_seqs, const int64_t H_v) {
+static void delta_apply_triangular_updates_chunk_f32(float * attn,
+                                                     const int64_t chunk_size,
+                                                     const int64_t n_seqs,
+                                                     const int64_t H_v,
+                                                     int num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
-            float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
-
-            // Apply triangular updates following the Python reference exactly:
-            // for i in range(1, chunk_size):
-            //     row = attn[..., i, :i].clone()
-            //     sub = attn[..., :i, :i].clone()
-            //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
-            for (int64_t i = 1; i < chunk_size; i++) {
-                // Create temporary storage for row and sub to avoid modifying during computation
-                float * row = (float *) malloc(i * sizeof(float));
-                float * sub = (float *) malloc(i * i * sizeof(float));
-
-                // Copy row = attn[..., i, :i]
-                for (int64_t j = 0; j < i; j++) {
-                    row[j] = attn_ptr[i * chunk_size + j];
-                }
-
-                // Copy sub = attn[..., :i, :i]
-                for (int64_t k = 0; k < i; k++) {
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * attn_ptr = attn + seq * (chunk_size * chunk_size * H_v) + (head * num_chunks + i) * (chunk_size * chunk_size);
+
+                // Apply triangular updates following the Python reference exactly:
+                // for i in range(1, chunk_size):
+                //     row = attn[..., i, :i].clone()
+                //     sub = attn[..., :i, :i].clone()
+                //     attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
+                for (int64_t i = 1; i < chunk_size; i++) {
+                    // Create temporary storage for row and sub to avoid modifying during computation
+                    float * row = (float *) malloc(i * sizeof(float));
+                    float * sub = (float *) malloc(i * i * sizeof(float));
+
+                    // Copy row = attn[..., i, :i]
                     for (int64_t j = 0; j < i; j++) {
-                        sub[k * i + j] = attn_ptr[k * chunk_size + j];
+                        row[j] = attn_ptr[i * chunk_size + j];
                     }
-                }
-
-                // Compute updates for each j in :i
-                for (int64_t j = 0; j < i; j++) {
-                    // Compute (row.unsqueeze(-1) * sub).sum(-2)
-                    float sum_val = 0.0f;
+
+                    // Copy sub = attn[..., :i, :i]
                     for (int64_t k = 0; k < i; k++) {
-                        sum_val += row[k] * sub[k * i + j];
+                        for (int64_t j = 0; j < i; j++) {
+                            sub[k * i + j] = attn_ptr[k * chunk_size + j];
+                        }
+                    }
+
+                    // Compute updates for each j in :i
+                    for (int64_t j = 0; j < i; j++) {
+                        // Compute (row.unsqueeze(-1) * sub).sum(-2)
+                        float sum_val = 0.0f;
+                        for (int64_t k = 0; k < i; k++) {
+                            sum_val += row[k] * sub[k * i + j];
+                        }
+
+                        // Update: attn[..., i, j] = row[j] + sum_val
+                        attn_ptr[i * chunk_size + j] = row[j] + sum_val;
                     }
-
-                    // Update: attn[..., i, j] = row[j] + sum_val
-                    attn_ptr[i * chunk_size + j] = row[j] + sum_val;
+
+                    free(row);
+                    free(sub);
                 }
-
-                free(row);
-                free(sub);
             }
         }
     }
 }
 
 // Helper function to add identity matrix to entire chunk (all sequences and heads)
-static void delta_add_identity_matrix_chunk_f32(float * matrix, const int64_t chunk_size,
-                                                const int64_t n_seqs, const int64_t H_v) {
+static void delta_add_identity_matrix_chunk_f32(float * matrix,
+                                                const int64_t chunk_size,
+                                                const int64_t n_seqs,
+                                                const int64_t H_v,
+                                                int num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
-            float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) + head * (chunk_size * chunk_size);
-            // Add identity matrix directly
-            for (int64_t i = 0; i < chunk_size; i++) {
-                matrix_ptr[i * chunk_size + i] += 1.0f;
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
+                float * matrix_ptr = matrix + seq * (chunk_size * chunk_size * H_v) +
+                                     (head * num_chunks + i) * (chunk_size * chunk_size);
+                // Add identity matrix directly
+                for (int64_t i = 0; i < chunk_size; i++) {
+                    matrix_ptr[i * chunk_size + i] += 1.0f;
+                }
             }
         }
    }
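
Stripped of the chunk bookkeeping, the recurrence both helpers run per tile is short. The following is a minimal, self-contained sketch (illustrative name, not the committed code) of the same row update on a single row-major chunk_size x chunk_size tile; because iteration i writes only row i and reads only rows k < i, a snapshot of row i is sufficient and the sub copy can be omitted:

    #include <cstdint>
    #include <vector>

    // Illustrative sketch, not the committed code: the triangular update from the
    // Python reference on one row-major chunk_size x chunk_size tile, i.e.
    //   for i in range(1, chunk_size): attn[i, :i] += attn[i, :i] @ attn[:i, :i]
    static void triangular_update_tile(float * attn, int64_t chunk_size) {
        std::vector<float> row(chunk_size);
        for (int64_t i = 1; i < chunk_size; i++) {
            // snapshot row i so the update reads only pre-update values
            for (int64_t j = 0; j < i; j++) {
                row[j] = attn[i * chunk_size + j];
            }
            for (int64_t j = 0; j < i; j++) {
                float sum_val = 0.0f;
                for (int64_t k = 0; k < i; k++) {
                    sum_val += row[k] * attn[k * chunk_size + j]; // rows k < i are already final
                }
                attn[i * chunk_size + j] = row[j] + sum_val;
            }
        }
    }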
@@ -10617,15 +10628,19 @@ static void delta_compute_value_f32(const float * attn,
                                     const int64_t chunk_size,
                                     const int64_t v_head_dim,
                                     const int64_t n_heads,
-                                    const int64_t n_seqs) {
+                                    const int64_t n_seqs,
+                                    int num_chunks) {
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < n_heads; head++) {
-            delta_matmul_f32(
-                attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * head,
-                v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
-                value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
-                chunk_size, v_head_dim, chunk_size);
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < n_heads; head++) {
+                delta_matmul_f32(
+                    attn + (chunk_size * chunk_size * n_heads) * seq + (chunk_size * chunk_size) * (head * num_chunks + i),
+                    v_beta + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
+                    value + (chunk_size * v_head_dim * n_heads) * seq + (chunk_size * v_head_dim) * head,
+                    chunk_size, v_head_dim, chunk_size);
+            }
         }
+
     }
 
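The arguments here are consistent with a plain row-major product, value = attn @ v_beta, per (seq, chunk, head) tile. Below is a hypothetical stand-in with the assumed (A, B, C, M, N, K) argument order; the real delta_matmul_f32 is defined elsewhere in ops.cpp and may differ:

    #include <cstdint>

    // Hypothetical stand-in (assumed semantics): row-major C[M x N] = A[M x K] @ B[K x N].
    // Matches the call delta_matmul_f32(attn, v_beta, value, chunk_size, v_head_dim, chunk_size).
    static void delta_matmul_f32_sketch(const float * A, const float * B, float * C,
                                        int64_t M, int64_t N, int64_t K) {
        for (int64_t m = 0; m < M; m++) {
            for (int64_t n = 0; n < N; n++) {
                float acc = 0.0f;
                for (int64_t k = 0; k < K; k++) {
                    acc += A[m * K + k] * B[k * N + n];
                }
                C[m * N + n] = acc;
            }
        }
    }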
@@ -10913,11 +10928,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // int64_t total_params = n_seqs * H_v * num_chunks;
     // int64_t per_thread = (total_params % nth == 0) ? total_params / nth : (total_params / nth) + 1;
 
-    float * attn = (float *) malloc(chunk_size * chunk_size * H_v * n_seqs * sizeof(float));
-    float * value = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
-    float * k_cumdecay = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
+    float * attn = (float *) malloc(chunk_size * chunk_size * H_v * num_chunks * n_seqs * sizeof(float));
+    float * value = (float *) malloc(chunk_size * S_v * H_v * num_chunks * n_seqs * sizeof(float));
+    float * k_cumdecay = (float *) malloc(chunk_size * S_v * H_v * num_chunks * n_seqs * sizeof(float));
     bool * mask = (bool *) malloc(chunk_size * chunk_size * sizeof(bool));
-    float * g = (float *) malloc(chunk_size * H_v * n_seqs * sizeof(float));
+    float * g = (float *) malloc(chunk_size * H_v * num_chunks * n_seqs * sizeof(float));
 
     // Create upper triangular mask for causal attention (exclude diagonal)
     for (int64_t i = 0; i < chunk_size; i++) {
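
With these sizes, each buffer now holds one tile per (seq, head, chunk) rather than per (seq, head). A hypothetical offset helper (not in the commit) spelling out the (head * num_chunks + chunk) tile order the updated helpers use; it also folds num_chunks into the per-seq stride, which the enlarged allocations above imply once n_seqs > 1:

    #include <cstdint>

    // Hypothetical helper, not in the commit: element offset of the (seq, head, chunk)
    // tile in a buffer laid out as [n_seqs][H_v * num_chunks][tile_elems], matching the
    // enlarged mallocs above (tile_elems is chunk_size * chunk_size for attn,
    // chunk_size * S_v for value / k_cumdecay, and chunk_size for g).
    static inline int64_t delta_tile_offset(int64_t seq, int64_t head, int64_t chunk,
                                            int64_t H_v, int64_t num_chunks, int64_t tile_elems) {
        return ((seq * H_v + head) * num_chunks + chunk) * tile_elems;
    }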
@@ -10934,18 +10949,20 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // This corresponds to the reference implementation:
     // for i in range(1, chunk_size): attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
     // attn = attn + torch.eye(chunk_size)
-    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v);
-    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v);
+    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
+    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
 
     // Compute value = attn @ v_beta
-    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs);
+    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs, num_chunks);
     for (int64_t seq = 0; seq < n_seqs; seq++) {
-        for (int64_t head = 0; head < H_v; head++) {
+        for (int i = 0; i < num_chunks; i++) {
+            for (int64_t head = 0; head < H_v; head++) {
                 delta_compute_k_cumdecay_f32(attn + (chunk_size * chunk_size * H_v) * seq + (chunk_size * chunk_size) * head,
                                              (float *) src7->data + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
                                              g + (chunk_size * H_v) * seq + chunk_size * head,
                                              k_cumdecay + (chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head,
                                              chunk_size, S_v);
+            }
         }
     }
     print_debug_info(k_cumdecay, chunk_size * S_v * H_v * n_seqs, "k_cumdecay", -1);
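
Note that the pointer arithmetic inside this loop is unchanged context: it does not use the new chunk index i, so every iteration of the chunk loop addresses the same per-(seq, head) tile. A chunk-aware form of the call, assuming the (head * num_chunks + i) convention from the helpers above and the hypothetical delta_tile_offset sketch (whether src7 is laid out per chunk is also an assumption), might read:

    // Hypothetical chunk-aware offsets (assumption, not part of the commit):
    delta_compute_k_cumdecay_f32(
        attn + delta_tile_offset(seq, head, i, H_v, num_chunks, chunk_size * chunk_size),
        (float *) src7->data + delta_tile_offset(seq, head, i, H_v, num_chunks, chunk_size * S_v),
        g + delta_tile_offset(seq, head, i, H_v, num_chunks, chunk_size),
        k_cumdecay + delta_tile_offset(seq, head, i, H_v, num_chunks, chunk_size * S_v),
        chunk_size, S_v);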
@@ -10996,7 +11013,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
 
             // Compute q_g_exp = q * g.exp()
             for (int64_t i = 0; i < chunk_size; i++) {
-                 for (int64_t d = 0; d < S_v; d++) {
+                for (int64_t d = 0; d < S_v; d++) {
                     q_g_exp_ptr[i * S_v + d] = q_ptr[i * S_v + d] * expf(g_ptr[i]);
                 }
             }
@@ -11196,8 +11213,11 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
         for (int64_t head = 0; head < H_v; head++) {
             float * core_attn_out_ptr = core_attn_out + seq * (chunk_size * S_v * H_v) + head * (chunk_size * S_v);
 
+            // Compute number of tokens for this chunk (chunk_size unless this is the last chunk)
+            int64_t n_tokens_chunk = chunk == num_chunks - 1 ? n_tokens % chunk_size : chunk_size;
+
             // Store output for this chunk
-            for (int64_t i = 0; i < n_tokens; i++) {
+            for (int64_t i = 0; i < n_tokens_chunk; i++) {
                 for (int64_t d = 0; d < S_v; d++) {
                     int64_t output_idx =
                         seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d;
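
One edge case worth flagging: n_tokens % chunk_size is 0 whenever n_tokens is an exact multiple of chunk_size, which would make n_tokens_chunk zero for the last chunk, so nothing would be stored for it. A guarded variant (hypothetical, not part of the commit):

    // Hypothetical guard, not in the commit: keep the last chunk full when
    // n_tokens divides evenly by chunk_size (i.e. n_tokens % chunk_size == 0).
    int64_t rem = n_tokens % chunk_size;
    int64_t n_tokens_chunk = (chunk == num_chunks - 1 && rem != 0) ? rem : chunk_size;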
