
Commit 29de86b

Cleanup & remove debugging stuff
1 parent 3049040 commit 29de86b

181 files changed (+745, -14520 lines)

ggml/include/ggml.h

Lines changed: 22 additions & 0 deletions
@@ -2417,6 +2417,28 @@ extern "C" {
            struct ggml_tensor  * b,
            struct ggml_tensor  * state);

+    GGML_API struct ggml_tensor * ggml_delta_net(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state,
+            bool                  use_qk_l2norm,
+            float                 eps_norm);
+
+    GGML_API struct ggml_tensor * ggml_delta_net_recurrent(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * g,
+            struct ggml_tensor  * beta,
+            struct ggml_tensor  * state,
+            bool                  use_qk_l2norm,
+            float                 eps_norm);
+
    // custom operators

    typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
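For orientation, here is a minimal sketch of how these two declarations might be wired into graph-building code. The helper name, the shapes of the inputs, and the chunked-vs-recurrent split are illustrative assumptions; only the two ggml_delta_net* signatures come from this commit.

    // Hypothetical helper: q/k/v/g/beta/state are assumed to already have the
    // shapes the ops expect; this commit does not document them here.
    static struct ggml_tensor * build_delta_net_block(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,
            struct ggml_tensor  * k,
            struct ggml_tensor  * v,
            struct ggml_tensor  * g,
            struct ggml_tensor  * beta,
            struct ggml_tensor  * state,
            int                   n_tokens) {
        const bool  use_qk_l2norm = true;   // presumably toggles L2-normalizing q/k inside the op
        const float eps_norm      = 1e-6f;  // epsilon for that normalization (placeholder value)

        // Assumed split: chunked kernel for multi-token batches, recurrent kernel for single-token decode.
        if (n_tokens > 1) {
            return ggml_delta_net(ctx, q, k, v, g, beta, state, use_qk_l2norm, eps_norm);
        }
        return ggml_delta_net_recurrent(ctx, q, k, v, g, beta, state, use_qk_l2norm, eps_norm);
    }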

ggml/src/ggml-cpu/ops.cpp

Lines changed: 2 additions & 123 deletions
@@ -8728,7 +8728,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                    // n_head
                    for (int h = ih0; h < ih1; ++h) {
                        // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                        const float dt_soft_plus = ggml_softplus(dt[h]);
+                        const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
                        const float dA = expf(dt_soft_plus * A[h]);
                        const int g = h / (nh / ng); // repeat_interleave

@@ -8825,7 +8825,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                // n_head
                for (int h = ih0; h < ih1; ++h) {
                    // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                    const float dt_soft_plus = ggml_softplus(dt[h]);
+                    const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
                    const int g = h / (nh / ng); // repeat_interleave

                    // dim
@@ -9712,22 +9712,6 @@ void ggml_compute_forward_gla(
    }
}

-static double debug_sum(float * data, size_t size) {
-    double sum = 0.0;
-    for (unsigned int i = 0; i < size; i++) {
-        sum += data[i];
-    }
-    return sum;
-}
-
-static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
-#ifdef MR_CHUNKY_TALKS
-    GGML_LOG_INFO("\nggml-debug: %s (%ld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n",
-        name, token, data[0], data[1], data[2], data[3], data[4]);
-    GGML_LOG_INFO("total elements: %ld, sum = %.10f\n", size, debug_sum(data, size));
-#endif MR_CHUNKY_TALKS
-}
-
// Helper function to compute cumulative sum
static void delta_cumsum_f32(const float * x, float * dst, const int64_t n) {
    float cumsum = 0.0f;
@@ -9837,34 +9821,9 @@ static void delta_apply_triangular_updates_chunk_f32(float * attn,
                    attn_ptr[i * chunk_size + j] = row[j] + sum_val;
                }

-                if (i % 10 == 0) {
-                    if (seq == 1 && head == 0 && chunk == 0) {
-                        print_debug_info(row, i, "row[1, 0, 0]", i);
-                        print_debug_info(sub, i * i, "sub[1, 0, 0]", i);
-                    }
-                    if (seq == 0 && head == 1 && chunk == 0) {
-                        print_debug_info(row, i, "row[0, 1, 0]", i);
-                        print_debug_info(sub, i * i, "sub[0, 1, 0]", i);
-                    }
-                    if (seq == 0 && head == 0 && chunk == 1) {
-                        print_debug_info(row, i, "row[0, 0, 1]", i);
-                        print_debug_info(sub, i * i, "sub[0, 0, 1]", i);
-                    }
-                }
-
                free(row);
                free(sub);
            }
-
-            if (seq == 1 && head == 0 && chunk == 0) {
-                print_debug_info(attn_ptr, chunk_size * chunk_size, "attn[1, 0, 0]", 0);
-            }
-            if (seq == 0 && head == 1 && chunk == 0) {
-                print_debug_info(attn_ptr, chunk_size * chunk_size, "attn[0, 1, 0]", 0);
-            }
-            if (seq == 0 && head == 0 && chunk == 1) {
-                print_debug_info(attn_ptr, chunk_size * chunk_size, "attn[0, 0, 1]", 0);
-            }
        }
    }
}
@@ -10191,8 +10150,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
            }
        }
    }
-    print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "init_state", -1);
-

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
@@ -10229,13 +10186,10 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
    // for i in range(1, chunk_size): attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    // attn = attn + torch.eye(chunk_size)
    delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
-    print_debug_info(attn, chunk_size * chunk_size * H_v * num_chunks * n_seqs, "attn_chunk", -1);
    delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
-    print_debug_info(attn, chunk_size * chunk_size * H_v * num_chunks * n_seqs, "attn_eye", -1);

    // Compute value = attn @ v_beta
    delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs, num_chunks);
-    print_debug_info(value, chunk_size * S_v * H_v * num_chunks * n_seqs, "value", -1);

    for (int64_t seq = 0; seq < n_seqs; seq++) {
        for (int i = 0; i < num_chunks; i++) {
@@ -10248,7 +10202,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
            }
        }
    }
-    print_debug_info(k_cumdecay, chunk_size * S_v * H_v * num_chunks * n_seqs, "k_cumdecay", -1);

    // Process each chunk with all sequences and heads together
    for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
@@ -10304,9 +10257,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
            }
        }

-        print_debug_info(pc_q_chunk_data, chunk_size * S_v * H_v * n_seqs, "q_i_chunk", chunk);
-        print_debug_info(pc_k_chunk_data, chunk_size * S_v * H_v * n_seqs, "k_i_chunk", chunk);
-
        // Step 4: Compute NEW attention matrix for this chunk: attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
        // Note: decay_mask[:, :, i] means we need to use the decay_mask for this specific chunk
        // The mask applied is the simple causal attention mask: torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1)
@@ -10328,7 +10278,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                delta_matmul_f32(q_ptr, k_trans, attn_ptr, chunk_size, chunk_size, S_v);
            }
        }
-        print_debug_info(attn, chunk_size * chunk_size * num_chunks * H_v * n_seqs, "attn_q_k_trans", chunk);

        for (int64_t seq = 0; seq < n_seqs; seq++) {
            for (int64_t head = 0; head < H_v; head++) {
@@ -10348,20 +10297,15 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
            }
        }

-        print_debug_info(attn, chunk_size * chunk_size * num_chunks * H_v * n_seqs, "attn_step4_new_chunk", chunk);
-
        // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
        // k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
        delta_matmul_state_chunk_f32(k_cumdecay, new_state, pc_v_prime, chunk_size, S_v, S_v, n_seqs, H_v, chunk, num_chunks);
-        print_debug_info(pc_v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);

        // v_new = v_i - v_prime
        delta_tensor_subtract_chunk_f32(value, pc_v_prime, pc_v_new, chunk_size * S_v, n_seqs, H_v, num_chunks, chunk);
-        print_debug_info(pc_v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);

        // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
        delta_matmul_state_chunk_f32(pc_q_g_exp, new_state, pc_attn_inter, chunk_size, S_v, S_v, n_seqs, H_v, -1, -1);
-        print_debug_info(pc_attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);

        // core_attn_out[:, :, i] = attn_inter + attn @ v_new
        // Use regular matrix multiplication for attn @ v_new
@@ -10375,9 +10319,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                delta_matmul_f32(attn_ptr, v_new_ptr, attn_v_new_ptr, chunk_size, S_v, chunk_size);
            }
        }
-        print_debug_info(pc_attn_v_new, chunk_size * S_v * H_v * n_seqs, "attn_v_new_chunk", chunk);
        delta_tensor_add_chunk_f32(pc_attn_inter, pc_attn_v_new, pc_core_attn_out, chunk_size * S_v, n_seqs, H_v);
-        print_debug_info(pc_core_attn_out, chunk_size * S_v * H_v * n_seqs, "core_attn_out_chunk", chunk);

        // Prepare g_last and g_diff_exp for state update
        for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10394,9 +10336,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
            }
        }

-        print_debug_info(pc_g_last, H_v * n_seqs, "g_last_chunk", chunk);
-        print_debug_info(pc_g_diff_exp, chunk_size * H_v * n_seqs, "g_diff_exp", chunk);
-
        float * k_g_diffexp = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
        for (int64_t seq = 0; seq < n_seqs; seq++) {
            for (int64_t head = 0; head < H_v; head++) {
@@ -10408,7 +10347,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                }
            }
        }
-        print_debug_info(k_g_diffexp, chunk_size * S_v * H_v * n_seqs, "k_g_diffexp", chunk);
        float * k_g_diffexp_T = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
        for (int64_t seq = 0; seq < n_seqs; seq++) {
            for (int64_t head = 0; head < H_v; head++) {
@@ -10421,25 +10359,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
            }
        }

-        // for (int64_t seq = 0; seq < n_seqs; seq++) {
-        //     for (int64_t head = 0; head < H_v; head++) {
-        //         GGML_LOG_INFO("Sequence %ld, head %ld: \n[ ", seq, head);
-        //         for (int i = 0; i < chunk_size; i++) {
-        //             GGML_LOG_INFO("[ ");
-        //             for (int j = 0; j < S_v; j++) {
-        //                 GGML_LOG_INFO("%.6f", k_g_diffexp[(chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head + i * S_v + j]);
-        //                 if (j < chunk_size - 1) {
-        //                     GGML_LOG_INFO(", ");
-        //                 }
-        //             }
-        //             GGML_LOG_INFO("], \n");
-        //         }
-        //         GGML_LOG_INFO("]\n");
-        //     }
-        // }
-
-        print_debug_info(k_g_diffexp_T, chunk_size * S_v * H_v * n_seqs, "k_g_diffexp_T", chunk);
-
        float * kgd_mul_vnew = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));

        for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10450,24 +10369,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                    S_v, S_v, chunk_size);
            }
        }
-        print_debug_info(kgd_mul_vnew, S_v * S_v * H_v * n_seqs, "kgd_mul_vnew", chunk);
-
-        // for (int64_t seq = 0; seq < n_seqs; seq++) {
-        //     for (int64_t head = 0; head < H_v; head++) {
-        //         GGML_LOG_INFO("Sequence %ld, head %ld: \n[ ", seq, head);
-        //         for (int i = 0; i < S_v; i++) {
-        //             GGML_LOG_INFO("[ ");
-        //             for (int j = 0; j < S_v; j++) {
-        //                 GGML_LOG_INFO("%.6f", kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + i * S_v + j]);
-        //                 if (j < S_v - 1) {
-        //                     GGML_LOG_INFO(", ");
-        //                 }
-        //             }
-        //             GGML_LOG_INFO("], \n");
-        //         }
-        //         GGML_LOG_INFO("]\n");
-        //     }
-        // }

        for (int64_t seq = 0; seq < n_seqs; seq++) {
            for (int64_t head = 0; head < H_v; head++) {
@@ -10480,7 +10381,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                }
            }
        }
-        print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "state_end_chunk", chunk);

        // Free temporary memory
        free(pc_q_chunk_data);
@@ -10511,21 +10411,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
            }
        }
    }
-    print_debug_info(output, S_v * H_v * n_tokens * n_seqs, "output", chunk);
-    // GGML_LOG_INFO("\nFull output tensor: \n\n");
-    // for (int64_t seq = 0; seq < n_seqs; seq++) {
-    //     for (int64_t head = 0; head < H_v; head++) {
-    //         GGML_LOG_INFO("\n[ ");
-    //         for (int64_t i = 0; i < n_tokens; i++) {
-    //             for (int64_t d = 0; d < S_v; d++) {
-    //                 GGML_LOG_INFO("%.4f ", output[seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d]);
-    //             }
-    //         }
-    //         GGML_LOG_INFO(" ]");
-    //     }
-    // }
-    print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
-
    free(pc_core_attn_out);
    free(pc_attn_inter);
    free(pc_v_new);
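Read together, the comments that remain in ggml_compute_forward_delta_net_f32 describe the chunkwise update applied for each chunk i. Paraphrasing them in the same pseudo-PyTorch notation, with S standing for last_recurrent_state (the final state-update line is inferred from the g_last / k_g_diffexp_T bookkeeping above and should be treated as an assumption rather than a spec):

    v_prime                = k_cumdecay[:, :, i] @ S
    v_new                  = v[:, :, i] - v_prime
    attn_inter             = (q[:, :, i] * g[:, :, i, :, None].exp()) @ S
    core_attn_out[:, :, i] = attn_inter + attn @ v_new
    S                      = S * g_last.exp() + (k[:, :, i] * (g_last - g[:, :, i]).exp()).transpose(-1, -2) @ v_new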
@@ -10622,7 +10507,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
            }
        }
    }
-    print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);

    // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
    for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10636,7 +10520,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
            }
        }
    }
-    print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);

    // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
    for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10651,7 +10534,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
            }
        }
    }
-    print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);

    // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
    for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10663,7 +10545,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
            }
        }
    }
-    print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);

    // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
    for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10677,7 +10558,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
            }
        }
    }
-    print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);

    // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
    for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10691,7 +10571,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
            }
        }
    }
-    print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);

    // Store the output for this token (for all seqs and heads)
    for (int64_t seq = 0; seq < n_seqs; seq++) {
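The numbered comments in ggml_compute_forward_delta_net_recurrent_f32 spell out the per-token delta rule. As a compact reference, here is a single-sequence, single-head sketch of those five steps in plain C; the function name, the flat row-major [S_v x S_v] state layout (rows = key dim, cols = value dim), and the assumption that g_t arrives already exponentiated are all illustrative, not code from this commit.

    // Hypothetical sketch of one recurrent delta-net step for a single head.
    // state: S_v x S_v, q/k/v/out: length S_v, g_t/beta_t: per-head scalars.
    static void delta_net_step_sketch(float * state, const float * q, const float * k,
                                      const float * v, float g_t, float beta_t,
                                      float * out, int S_v) {
        float kv_mem[256];  // sketch only: assumes S_v <= 256
        float delta [256];

        // 1. last_recurrent_state = last_recurrent_state * g_t
        for (int r = 0; r < S_v; r++) {
            for (int c = 0; c < S_v; c++) {
                state[r * S_v + c] *= g_t;
            }
        }
        // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
        for (int c = 0; c < S_v; c++) {
            float acc = 0.0f;
            for (int r = 0; r < S_v; r++) {
                acc += state[r * S_v + c] * k[r];
            }
            kv_mem[c] = acc;
        }
        // 3. delta = (v_t - kv_mem) * beta_t
        for (int c = 0; c < S_v; c++) {
            delta[c] = (v[c] - kv_mem[c]) * beta_t;
        }
        // 4. last_recurrent_state += k_t.unsqueeze(-1) * delta.unsqueeze(-2)
        for (int r = 0; r < S_v; r++) {
            for (int c = 0; c < S_v; c++) {
                state[r * S_v + c] += k[r] * delta[c];
            }
        }
        // 5. core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2)
        for (int c = 0; c < S_v; c++) {
            float acc = 0.0f;
            for (int r = 0; r < S_v; r++) {
                acc += state[r * S_v + c] * q[r];
            }
            out[c] = acc;
        }
    }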

ggml/src/ggml-impl.h

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) {
    }
}

-static inline float ggml_softplus(float input) {
+static inline float ggml_compute_softplus_f32(float input) {
    return (input > 20.0f) ? input : logf(1 + expf(input));
}
//
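For context on the renamed helper: softplus(x) = log(1 + e^x) converges to x for large x, so returning the input above 20.0f skips the expf/logf calls and, for very large inputs, avoids expf overflowing to inf. A tiny standalone check (hypothetical test code, not part of the commit):

    #include <math.h>
    #include <stdio.h>

    // Same formula as ggml_compute_softplus_f32: exact below the threshold,
    // the identity above it, where log(1 + e^x) equals x to float precision anyway.
    static float softplus_f32(float x) {
        return (x > 20.0f) ? x : logf(1 + expf(x));
    }

    int main(void) {
        printf("%f\n", softplus_f32(0.0f));    // 0.693147 (= ln 2)
        printf("%f\n", softplus_f32(20.0f));   // 20.000000 (the correction term is ~2e-9)
        printf("%f\n", softplus_f32(100.0f));  // 100.000000; expf(100.0f) alone would overflow to inf
        return 0;
    }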
