@@ -8728,7 +8728,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus = ggml_softplus(dt[h]);
+                const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
                 const float dA = expf(dt_soft_plus * A[h]);
                 const int g = h / (nh / ng); // repeat_interleave

@@ -8825,7 +8825,7 @@ static void ggml_compute_forward_ssm_scan_f32(
             // n_head
             for (int h = ih0; h < ih1; ++h) {
                 // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                const float dt_soft_plus = ggml_softplus(dt[h]);
+                const float dt_soft_plus = ggml_compute_softplus_f32(dt[h]);
                 const int g = h / (nh / ng); // repeat_interleave

                 // dim
@@ -9712,22 +9712,6 @@ void ggml_compute_forward_gla(
     }
 }

-static double debug_sum(float * data, size_t size) {
-    double sum = 0.0;
-    for (unsigned int i = 0; i < size; i++) {
-        sum += data[i];
-    }
-    return sum;
-}
-
-static void print_debug_info(float * data, size_t size, const char * name, int64_t token) {
-#ifdef MR_CHUNKY_TALKS
-    GGML_LOG_INFO("\nggml-debug: %s (%ld) first 5 values: [%.6f, %.6f, %.6f, %.6f, %.6f, ...]\n",
-        name, token, data[0], data[1], data[2], data[3], data[4]);
-    GGML_LOG_INFO("total elements: %ld, sum = %.10f\n", size, debug_sum(data, size));
-#endif MR_CHUNKY_TALKS
-}
-
 // Helper function to compute cumulative sum
 static void delta_cumsum_f32(const float * x, float * dst, const int64_t n) {
     float cumsum = 0.0f;
@@ -9837,34 +9821,9 @@ static void delta_apply_triangular_updates_chunk_f32(float * attn,
                         attn_ptr[i * chunk_size + j] = row[j] + sum_val;
                     }

-                    if (i % 10 == 0) {
-                        if (seq == 1 && head == 0 && chunk == 0) {
-                            print_debug_info(row, i, "row[1, 0, 0]", i);
-                            print_debug_info(sub, i * i, "sub[1, 0, 0]", i);
-                        }
-                        if (seq == 0 && head == 1 && chunk == 0) {
-                            print_debug_info(row, i, "row[0, 1, 0]", i);
-                            print_debug_info(sub, i * i, "sub[0, 1, 0]", i);
-                        }
-                        if (seq == 0 && head == 0 && chunk == 1) {
-                            print_debug_info(row, i, "row[0, 0, 1]", i);
-                            print_debug_info(sub, i * i, "sub[0, 0, 1]", i);
-                        }
-                    }
-
                     free(row);
                     free(sub);
                 }
-
-                if (seq == 1 && head == 0 && chunk == 0) {
-                    print_debug_info(attn_ptr, chunk_size * chunk_size, "attn[1, 0, 0]", 0);
-                }
-                if (seq == 0 && head == 1 && chunk == 0) {
-                    print_debug_info(attn_ptr, chunk_size * chunk_size, "attn[0, 1, 0]", 0);
-                }
-                if (seq == 0 && head == 0 && chunk == 1) {
-                    print_debug_info(attn_ptr, chunk_size * chunk_size, "attn[0, 0, 1]", 0);
-                }
             }
         }
     }
@@ -10191,8 +10150,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }
     }
-    print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "init_state", -1);
-

     GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(src1));
@@ -10229,13 +10186,10 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
     // for i in range(1, chunk_size): attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
     // attn = attn + torch.eye(chunk_size)
     delta_apply_triangular_updates_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
-    print_debug_info(attn, chunk_size * chunk_size * H_v * num_chunks * n_seqs, "attn_chunk", -1);
     delta_add_identity_matrix_chunk_f32(attn, chunk_size, n_seqs, H_v, num_chunks);
-    print_debug_info(attn, chunk_size * chunk_size * H_v * num_chunks * n_seqs, "attn_eye", -1);

     // Compute value = attn @ v_beta
     delta_compute_value_f32(attn, (const float *) src6->data, value, chunk_size, S_v, H_v, n_seqs, num_chunks);
-    print_debug_info(value, chunk_size * S_v * H_v * num_chunks * n_seqs, "value", -1);

     for (int64_t seq = 0; seq < n_seqs; seq++) {
         for (int i = 0; i < num_chunks; i++) {
@@ -10248,7 +10202,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }
     }
-    print_debug_info(k_cumdecay, chunk_size * S_v * H_v * num_chunks * n_seqs, "k_cumdecay", -1);

     // Process each chunk with all sequences and heads together
     for (int64_t chunk = 0; chunk < num_chunks; chunk++) {
@@ -10304,9 +10257,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }

-        print_debug_info(pc_q_chunk_data, chunk_size * S_v * H_v * n_seqs, "q_i_chunk", chunk);
-        print_debug_info(pc_k_chunk_data, chunk_size * S_v * H_v * n_seqs, "k_i_chunk", chunk);
-
         // Step 4: Compute NEW attention matrix for this chunk: attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0)
         // Note: decay_mask[:, :, i] means we need to use the decay_mask for this specific chunk
         // The mask applied is the simple causal attention mask: torch.triu(torch.ones(chunk_size, chunk_size), diagonal=1)
@@ -10328,7 +10278,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                 delta_matmul_f32(q_ptr, k_trans, attn_ptr, chunk_size, chunk_size, S_v);
             }
         }
-        print_debug_info(attn, chunk_size * chunk_size * num_chunks * H_v * n_seqs, "attn_q_k_trans", chunk);

         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
@@ -10348,20 +10297,15 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }

-        print_debug_info(attn, chunk_size * chunk_size * num_chunks * H_v * n_seqs, "attn_step4_new_chunk", chunk);
-
         // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
         // k_cumdecay has shape [chunk_size, v_head_dim], state has shape [v_head_dim, v_head_dim]
         delta_matmul_state_chunk_f32(k_cumdecay, new_state, pc_v_prime, chunk_size, S_v, S_v, n_seqs, H_v, chunk, num_chunks);
-        print_debug_info(pc_v_prime, chunk_size * S_v * H_v * n_seqs, "v_prime_chunk", chunk);

         // v_new = v_i - v_prime
         delta_tensor_subtract_chunk_f32(value, pc_v_prime, pc_v_new, chunk_size * S_v, n_seqs, H_v, num_chunks, chunk);
-        print_debug_info(pc_v_new, chunk_size * S_v * H_v * n_seqs, "v_new_chunk", chunk);

         // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
         delta_matmul_state_chunk_f32(pc_q_g_exp, new_state, pc_attn_inter, chunk_size, S_v, S_v, n_seqs, H_v, -1, -1);
-        print_debug_info(pc_attn_inter, chunk_size * S_v * H_v * n_seqs, "attn_inter_chunk", chunk);

         // core_attn_out[:, :, i] = attn_inter + attn @ v_new
         // Use regular matrix multiplication for attn @ v_new
@@ -10375,9 +10319,7 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                 delta_matmul_f32(attn_ptr, v_new_ptr, attn_v_new_ptr, chunk_size, S_v, chunk_size);
             }
         }
-        print_debug_info(pc_attn_v_new, chunk_size * S_v * H_v * n_seqs, "attn_v_new_chunk", chunk);
         delta_tensor_add_chunk_f32(pc_attn_inter, pc_attn_v_new, pc_core_attn_out, chunk_size * S_v, n_seqs, H_v);
-        print_debug_info(pc_core_attn_out, chunk_size * S_v * H_v * n_seqs, "core_attn_out_chunk", chunk);

         // Prepare g_last and g_diff_exp for state update
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10394,9 +10336,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }

-        print_debug_info(pc_g_last, H_v * n_seqs, "g_last_chunk", chunk);
-        print_debug_info(pc_g_diff_exp, chunk_size * H_v * n_seqs, "g_diff_exp", chunk);
-
         float * k_g_diffexp = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
@@ -10408,7 +10347,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                 }
             }
         }
-        print_debug_info(k_g_diffexp, chunk_size * S_v * H_v * n_seqs, "k_g_diffexp", chunk);
         float * k_g_diffexp_T = (float *) malloc(chunk_size * S_v * H_v * n_seqs * sizeof(float));
         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
@@ -10421,25 +10359,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
             }
         }

-        // for (int64_t seq = 0; seq < n_seqs; seq++) {
-        //     for (int64_t head = 0; head < H_v; head++) {
-        //         GGML_LOG_INFO("Sequence %ld, head %ld: \n[ ", seq, head);
-        //         for (int i = 0; i < chunk_size; i++) {
-        //             GGML_LOG_INFO("[ ");
-        //             for (int j = 0; j < S_v; j++) {
-        //                 GGML_LOG_INFO("%.6f", k_g_diffexp[(chunk_size * S_v * H_v) * seq + (chunk_size * S_v) * head + i * S_v + j]);
-        //                 if (j < chunk_size - 1) {
-        //                     GGML_LOG_INFO(", ");
-        //                 }
-        //             }
-        //             GGML_LOG_INFO("], \n");
-        //         }
-        //         GGML_LOG_INFO("]\n");
-        //     }
-        // }
-
-        print_debug_info(k_g_diffexp_T, chunk_size * S_v * H_v * n_seqs, "k_g_diffexp_T", chunk);
-
         float * kgd_mul_vnew = (float *) malloc(S_v * S_v * H_v * n_seqs * sizeof(float));

         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10450,24 +10369,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                                  S_v, S_v, chunk_size);
             }
         }
-        print_debug_info(kgd_mul_vnew, S_v * S_v * H_v * n_seqs, "kgd_mul_vnew", chunk);
-
-        // for (int64_t seq = 0; seq < n_seqs; seq++) {
-        //     for (int64_t head = 0; head < H_v; head++) {
-        //         GGML_LOG_INFO("Sequence %ld, head %ld: \n[ ", seq, head);
-        //         for (int i = 0; i < S_v; i++) {
-        //             GGML_LOG_INFO("[ ");
-        //             for (int j = 0; j < S_v; j++) {
-        //                 GGML_LOG_INFO("%.6f", kgd_mul_vnew[(S_v * S_v * H_v) * seq + (S_v * S_v) * head + i * S_v + j]);
-        //                 if (j < S_v - 1) {
-        //                     GGML_LOG_INFO(", ");
-        //                 }
-        //             }
-        //             GGML_LOG_INFO("], \n");
-        //         }
-        //         GGML_LOG_INFO("]\n");
-        //     }
-        // }

         for (int64_t seq = 0; seq < n_seqs; seq++) {
             for (int64_t head = 0; head < H_v; head++) {
@@ -10480,7 +10381,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                 }
             }
         }
-        print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "state_end_chunk", chunk);

         // Free temporary memory
         free(pc_q_chunk_data);
@@ -10511,21 +10411,6 @@ void ggml_compute_forward_delta_net_f32(const ggml_compute_params * params, ggml
                 }
             }
         }
-        print_debug_info(output, S_v * H_v * n_tokens * n_seqs, "output", chunk);
-        // GGML_LOG_INFO("\nFull output tensor: \n\n");
-        // for (int64_t seq = 0; seq < n_seqs; seq++) {
-        //     for (int64_t head = 0; head < H_v; head++) {
-        //         GGML_LOG_INFO("\n[ ");
-        //         for (int64_t i = 0; i < n_tokens; i++) {
-        //             for (int64_t d = 0; d < S_v; d++) {
-        //                 GGML_LOG_INFO("%.4f ", output[seq * (n_tokens * S_v * H_v) + head * (n_tokens * S_v) + (chunk * chunk_size + i) * S_v + d]);
-        //             }
-        //         }
-        //         GGML_LOG_INFO(" ]");
-        //     }
-        // }
-        print_debug_info(new_state, S_v * S_v * H_v * n_seqs, "new_state", chunk);
-
         free(pc_core_attn_out);
         free(pc_attn_inter);
         free(pc_v_new);
@@ -10622,7 +10507,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);

         // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10636,7 +10520,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);

         // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10651,7 +10534,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);

         // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10663,7 +10545,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);

         // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10677,7 +10558,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);

         // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10691,7 +10571,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
                 }
             }
         }
-        print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);

         // Store the output for this token (for all seqs and heads)
         for (int64_t seq = 0; seq < n_seqs; seq++) {
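
Note: besides dropping the debug helpers, the first two hunks switch the SSM scan from ggml_softplus to ggml_compute_softplus_f32, per the Mamba softplus kernel referenced in the comments. As a rough illustration only (not the actual ggml helper), a numerically safe softplus of this kind could look like the sketch below; softplus_f32_sketch is a hypothetical name and the threshold value is an assumption.

#include <math.h>

// Hypothetical sketch: softplus(x) = log(1 + exp(x)), with a pass-through for
// large x so that expf() does not overflow (log1p(exp(x)) ~= x there).
// The real helper in ggml may differ in name, threshold, and implementation.
static inline float softplus_f32_sketch(float x) {
    return (x > 20.0f) ? x : log1pf(expf(x));
}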