Skip to content

Commit 2cab86a

Browse files
committed
Let the debug out.
1 parent 7eef0bd commit 2cab86a

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2975,7 +2975,7 @@ static void ggml_compute_forward_tri_f32(const ggml_compute_params * params, ggm
29752975
const ggml_tensor * src0 = dst->src[0];
29762976

29772977
ggml_tri_type ttype = (ggml_tri_type) dst->op_params[0];
2978-
float c = *((float *) &(dst->op_params[1]));
2978+
float c = ggml_get_op_params_f32(dst, 1);
29792979
bool keep_org_val = isnan(c);
29802980

29812981
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -10902,7 +10902,6 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1090210902
// src6, src7, src8 are nullptr in recurrent version
1090310903

1090410904
const int64_t H_v = (int64_t) dst->op_params[0];
10905-
const int64_t S_k = (int64_t) dst->op_params[1];
1090610905
const int64_t S_v = (int64_t) dst->op_params[2];
1090710906
const int64_t original_n_tokens = (int64_t) dst->op_params[3]; // Get original sequence length
1090810907
const int64_t n_tokens = original_n_tokens; // Use the original sequence length
@@ -10972,7 +10971,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1097210971
}
1097310972
}
1097410973
}
10975-
print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
10974+
//print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_copy", token);
1097610975

1097710976
// 1. last_recurrent_state = last_recurrent_state * g_t (for all seqs and heads)
1097810977
for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -10986,7 +10985,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1098610985
}
1098710986
}
1098810987
}
10989-
print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
10988+
//print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state_times_g_t", token);
1099010989

1099110990
// 2. kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
1099210991
for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11001,7 +11000,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1100111000
}
1100211001
}
1100311002
}
11004-
print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
11003+
//print_debug_info(kv_mem, n_seqs * H_v * S_v, "kv_mem", token);
1100511004

1100611005
// 3. delta = (v_t - kv_mem) * beta_t (for all seqs and heads)
1100711006
for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11013,7 +11012,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1101311012
}
1101411013
}
1101511014
}
11016-
print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
11015+
//print_debug_info(delta, n_seqs * H_v * S_v, "delta", token);
1101711016

1101811017
// 4. last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) (for all seqs and heads)
1101911018
for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11027,7 +11026,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1102711026
}
1102811027
}
1102911028
}
11030-
print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
11029+
//print_debug_info(temp_state, n_seqs * H_v * S_v * S_v, "temp_state", token);
1103111030

1103211031
// 5. core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) (for all seqs and heads)
1103311032
for (int64_t seq = 0; seq < n_seqs; seq++) {
@@ -11041,7 +11040,7 @@ void ggml_compute_forward_delta_net_recurrent_f32(const ggml_compute_params * pa
1104111040
}
1104211041
}
1104311042
}
11044-
print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
11043+
//print_debug_info(attn_out_t, n_seqs * S_v * H_v, "attn_out_t", token);
1104511044

1104611045
// Store the output for this token (for all seqs and heads)
1104711046
for (int64_t seq = 0; seq < n_seqs; seq++) {

0 commit comments

Comments
 (0)