@@ -2975,7 +2975,7 @@ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggm
29752975 const ggml_tensor * src0 = dst->src [0 ];
29762976
29772977 ggml_tri_type ttype = (ggml_tri_type) dst->op_params [0 ];
2978- float c = *(( float *) &( dst-> op_params [ 1 ]) );
2978+ float c = ggml_get_op_params_f32 ( dst, 1 );
29792979 bool keep_org_val = isnan (c);
29802980
29812981 GGML_ASSERT (ggml_is_contiguous (src0));
@@ -10902,7 +10902,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1090210902 // src6, src7, src8 are nullptr in recurrent version
1090310903
1090410904 const int64_t H_v = (int64_t ) dst->op_params [0 ];
10905- const int64_t S_k = (int64_t ) dst->op_params [1 ];
1090610905 const int64_t S_v = (int64_t ) dst->op_params [2 ];
1090710906 const int64_t original_n_tokens = (int64_t ) dst->op_params [3 ]; // Get original sequence length
1090810907 const int64_t n_tokens = original_n_tokens; // Use the original sequence length
@@ -10972,7 +10971,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1097210971 }
1097310972 }
1097410973 }
10975- print_debug_info (temp_state, n_seqs * H_v * S_v * S_v, " temp_state_copy" , token);
10974+ // print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
1097610975
1097710976 // 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
1097810977 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
@@ -10986,7 +10985,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1098610985 }
1098710986 }
1098810987 }
10989- print_debug_info (temp_state, n_seqs * H_v * S_v * S_v, " temp_state_times_g_t" , token);
10988+ // print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
1099010989
1099110990 // 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
1099210991 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
@@ -11001,7 +11000,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1100111000 }
1100211001 }
1100311002 }
11004- print_debug_info (kv_mem, n_seqs * H_v * S_v, " kv_mem" , token);
11003+ // print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
1100511004
1100611005 // 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
1100711006 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
@@ -11013,7 +11012,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1101311012 }
1101411013 }
1101511014 }
11016- print_debug_info (delta, n_seqs * H_v * S_v, " delta" , token);
11015+ // print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
1101711016
1101811017 // 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
1101911018 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
@@ -11027,7 +11026,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1102711026 }
1102811027 }
1102911028 }
11030- print_debug_info (temp_state, n_seqs * H_v * S_v * S_v, " temp_state" , token);
11029+ // print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
1103111030
1103211031 // 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
1103311032 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
@@ -11041,7 +11040,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1104111040 }
1104211041 }
1104311042 }
11044- print_debug_info (attn_out_t , n_seqs * S_v * H_v, " attn_out_t" , token);
11043+ // print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
1104511044
1104611045 // Store the output for this token (for all seqs and heads)
1104711046 for (int64_t seq = 0 ; seq < n_seqs; seq++) {
0 commit comments