Commit b4cbe03

mtp-batch(chore): Fix logit flags for speculative sampling and remove debug logs
1 parent: a99709d (commit b4cbe03)

6 files changed: +8 additions, -130 deletions

common/speculative.cpp

Lines changed: 3 additions & 20 deletions
@@ -380,11 +380,6 @@ llama_token mtp_speculative_gen_draft(
 
     mtp_batch.mtp_params.op_type = MTP_OP_DRAFT_GEN;
 
-    // LOG_INF("[DEBUG-DRAFT-CALL] Calling llama_decode for draft. update_mtp_kv=%s, use_mtp_head=%s\n",
-    //     mtp_batch.update_mtp_kv ? "true" : "false",
-    //     mtp_batch.use_mtp_head ? "true" : "false"
-    // );
-
     // Perform the MTP draft generation decode. This writes the MTP layer's
     // KV state for the draft token into the cache.
     llama_decode(ctx, mtp_batch);
@@ -416,7 +411,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
         return;
     }
 
-    LOG_INF("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
+    LOG_DBG("[MTP-UPDATE|%s] Updating %d tokens...\n", is_prompt_warmup ? "PROMPT_WARMUP" : "GEN_ACCEPTED", batch.n_tokens);
 
     llama_batch mtp_batch = batch;
     if (is_prompt_warmup) {
@@ -426,7 +421,7 @@ void mtp_update_kv_cache(struct llama_context * ctx, const llama_batch& batch, b
     }
 
     for (int i = 0; i < mtp_batch.n_tokens; ++i) {
-        mtp_batch.logits[i] = false;
+        mtp_batch.logits[i] = true;
     }
     llama_decode(ctx, mtp_batch);
 }
@@ -447,7 +442,7 @@ void mtp_accept_tokens(
 
     llama_batch accepted_batch = llama_batch_init(ids.size(), 0, 1);
     for (size_t i = 0; i < ids.size(); ++i) {
-        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, false);
+        common_batch_add(accepted_batch, ids[i], n_past_base + i, { seq_id }, true);
     }
 
     mtp_update_kv_cache(ctx, accepted_batch, false);
@@ -456,15 +451,3 @@ void mtp_accept_tokens(
 
     llama_batch_free(accepted_batch);
 }
-
-// Debug function - It will be removed later
-double calculate_vector_sum_double(const float* vec, size_t size) {
-    if (!vec) {
-        return 0.0;
-    }
-    double sum = 0.0;
-    for (size_t i = 0; i < size; ++i) {
-        sum += vec[i];
-    }
-    return sum;
-}
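
Context for the flag changes above: in llama.cpp, the per-token logits flag on a llama_batch (set via the last argument of common_batch_add, or via the batch's logits array) tells llama_decode whether to compute output logits for that position. Speculative sampling has to read logits at every accepted position, which is why the flags flip from false to true here. A minimal sketch, assuming the helpers shown in this diff (build_accepted_batch is a hypothetical wrapper used only for illustration):

    // Hypothetical helper mirroring the diff: every accepted token requests logits
    // so the sampler can validate it after llama_decode.
    static llama_batch build_accepted_batch(const std::vector<llama_token> & ids,
                                            int32_t n_past_base, llama_seq_id seq_id) {
        llama_batch batch = llama_batch_init(ids.size(), /*embd=*/0, /*n_seq_max=*/1);
        for (size_t i = 0; i < ids.size(); ++i) {
            // logits = true: ask the decode call to compute logits at this position.
            common_batch_add(batch, ids[i], n_past_base + (llama_pos) i, { seq_id }, /*logits=*/true);
        }
        return batch; // caller releases it with llama_batch_free()
    }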

common/speculative.h

Lines changed: 0 additions & 2 deletions
@@ -57,5 +57,3 @@ void mtp_accept_tokens(
     int32_t n_past_base,
     llama_seq_id seq_id
 );
-
-double calculate_vector_sum_double(const float* vec, size_t size);

src/llama-context.cpp

Lines changed: 1 addition & 41 deletions
@@ -809,15 +809,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
         //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0);
     }
 
-    const int64_t t_exec_start_us = ggml_time_us();
     const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1);
-    const int64_t t_exec_end_us = ggml_time_us();
-    // LLAMA_LOG_INFO(
-    //     "[PERF] Graph compute time: %.2f ms (ubatch_size: %u, MTP path: %s)\n",
-    //     (t_exec_end_us - t_exec_start_us) / 1000.0,
-    //     ubatch.n_tokens,
-    //     do_mtp_kv_update ? "yes" : "no"
-    // );
     if (status != GGML_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: %d\n", __func__, status);
         ret = status;
@@ -827,9 +819,6 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
     ret = GGML_STATUS_SUCCESS;
     if (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
         ggml_tensor * sum_tensor = ggml_get_tensor(res->get_ctx(), "mtp_input_sum");
-        if (sum_tensor) {
-            LLAMA_LOG_WARN("[DEBUG-SUM] MTP input sum node successfully created.\n");
-        }
     }
     return res;
 }
@@ -1123,20 +1112,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     do {
         const auto & ubatch = mctx->get_ubatch();
-        if (ubatch.n_tokens > 0) {
-            std::string pos_str;
-            for (uint32_t i = 0; i < std::min((uint32_t)5, ubatch.n_tokens); ++i) {
-                pos_str += std::to_string(ubatch.pos[i]) + " ";
-            }
-            // LLAMA_LOG_WARN(
-            //     "[DEBUG-POS] ubatch_size=%u, update_mtp_kv=%s, use_mtp_head=%s. Positions: %s...\n",
-            //     ubatch.n_tokens,
-            //     batch_inp.update_mtp_kv ? "true" : "false",
-            //     batch_inp.use_mtp_head ? "true" : "false",
-            //     pos_str.c_str()
-            // );
-        }
-
         // count the outputs in this ubatch
         {
             int32_t n_outputs_new = 0;
@@ -1281,8 +1256,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                         GGML_ABORT("unknown pooling type");
                     }
                 }
-            } else {
-                LLAMA_LOG_WARN("[DEBUG-EMBD-COPY] Skipping embedding buffer copy for MTP operation (use_mtp_head=true).\n");
             }
         }
 
@@ -1347,13 +1320,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         // overlap with device computation.
         ggml_backend_sched_reset(sched.get());
     }
-
-    if (batch_inp.mtp_params.op_type == MTP_OP_NONE) {
-        synchronize();
-        const size_t n_embd = this->model.hparams.n_embd;
-        double full_buffer_sum = calculate_vector_sum(this->embd, n_outputs_all * n_embd);
-        LLAMA_LOG_WARN("[INTEGRITY-CHECK|A] After main decode. ubatch_size=%d. Checksum: %e\n", n_outputs_all, full_buffer_sum);
-    }
     return 0;
 }
 
@@ -3124,7 +3090,7 @@ std::unique_ptr<llama_memory_context_i> llama_context::initialize_decode_context
     if (cparams.warmup) {
         mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
     } else if (kvd->forced_sinfos && !kvd->forced_sinfos->empty()) {
-        LLAMA_LOG_WARN("[DEBUG-CACHE-REUSE] Forcing sinfos, bypassing find_slot.\n");
+        LLAMA_LOG_DEBUG("%s: Forcing sinfos, bypassing find_slot.\n", __func__);
         mctx = static_cast<llama_kv_cache_unified *>(memory.get())->init_batch_with_sinfos(
             *balloc, cparams.n_ubatch, *kvd->forced_sinfos, true
         );
@@ -3160,19 +3126,13 @@ bool llama_context::prepare_mtp_graph_inputs(
     }
 
     if (source_hidden_state != nullptr && hidden_states_input != nullptr) {
-        const size_t n_embd = this->model.hparams.n_embd;
-        const size_t n_tokens_for_sum = (mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED && ubatch.n_tokens > 2) ? ubatch.n_tokens : 1;
-        double input_sum = calculate_vector_sum(source_hidden_state, n_tokens_for_sum * n_embd);
-
         const char * op_type;
         if (mtp_params.op_type == MTP_OP_WARMUP || mtp_params.op_type == MTP_OP_UPDATE_ACCEPTED) {
            op_type = "MTP_UPDATE";
        } else { // MTP_OP_DRAFT_GEN
            op_type = "DRAFT_GEN";
        }
 
-        LLAMA_LOG_WARN("[MTP-INPUT-CHECK] Operation: %s | Input Checksum: %e\n", op_type, input_sum);
-
         ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
     } else {
         LLAMA_LOG_ERROR("%s: MTP hidden state input tensor ('%s') not found or main embd buffer is null\n",
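
For context on prepare_mtp_graph_inputs above: the MTP head reuses the main model's hidden state, which is copied from a host buffer into the graph's input tensor before the compute runs; the removed lines only computed a debug checksum of that buffer. A minimal sketch of the copy that remains, assuming the tensor and pointer names from the diff (copy_hidden_state_to_input is a hypothetical wrapper used only for illustration):

    // Hypothetical wrapper: upload the main model's hidden state into the
    // MTP graph's input tensor; mirrors the null checks in the diff.
    static bool copy_hidden_state_to_input(ggml_tensor * hidden_states_input,
                                           const float * source_hidden_state) {
        if (source_hidden_state == nullptr || hidden_states_input == nullptr) {
            return false;
        }
        // Copy ggml_nbytes(tensor) bytes from host memory into the backend tensor.
        ggml_backend_tensor_set(hidden_states_input, source_hidden_state, 0, ggml_nbytes(hidden_states_input));
        return true;
    }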

src/llama-kv-cache-unified.cpp

Lines changed: 0 additions & 4 deletions
@@ -766,7 +766,6 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
 }
 
 llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
-    LLAMA_LOG_WARN("%s: Entering find_slot for ubatch of %d tokens.\n", __func__, ubatch.n_tokens);
     if (debug > 0) {
         const auto & cells = v_cells[seq_to_stream[1]];
 
@@ -972,9 +971,6 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_
                 }
             }
         }
-        LLAMA_LOG_WARN("%s: find_slot SUCCEEDED for ubatch of %d tokens. Idxs:%s\n", __func__, ubatch.n_tokens, idxs_str.c_str());
-    } else {
-        LLAMA_LOG_ERROR("%s: find_slot FAILED to allocate cells for ubatch of %d tokens.\n", __func__, ubatch.n_tokens);
     }
 
     return res;

src/llama-model.cpp

Lines changed: 2 additions & 36 deletions
@@ -13788,35 +13788,11 @@ struct llm_build_glm4 : public llm_graph_context {
 
 struct llm_build_glm4_moe : public llm_graph_context {
     llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
-        // LLAMA_LOG_WARN(
-        //     "[GRAPH_BUILD] Building graph. Path: %s, MTP_Update: %s, UBatch_Tokens: %d, First_Pos: %d\n",
-        //     params.use_mtp_head ? "MTP" : "MAIN",
-        //     params.update_mtp_kv ? "true" : "false",
-        //     n_tokens,
-        //     n_tokens > 0 ? ubatch.pos[0] : -1
-        // );
         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
 
         ggml_tensor * cur;
 
-        // LLAMA_LOG_WARN(
-        //     "[DEBUG-GRAPH-STATE] Building graph. MTP Head=%s, MTP KV Update=%s, n_tokens=%d\n",
-        //     params.use_mtp_head ? "true" : "false",
-        //     params.update_mtp_kv ? "true" : "false",
-        //     n_tokens
-        // );
-        // for (int i = 0; i < n_tokens; ++i) {
-        //     LLAMA_LOG_WARN(" - ubatch token[%d]: ID=%d, Pos=%d\n", i, ubatch.token[i], ubatch.pos[i]);
-        // }
-        if (n_tokens > 0) {
-            LLAMA_LOG_WARN(
-                " - ubatch tokens: [ID=%d, Pos=%d] ... [ID=%d, Pos=%d]\n",
-                ubatch.token[0], ubatch.pos[0],
-                ubatch.token[n_tokens-1], ubatch.pos[n_tokens-1]
-            );
-        }
-
         if (params.mtp_params.op_type != MTP_OP_NONE) {
             ggml_tensor* hidden_states_from_main_model;
 
@@ -13913,10 +13889,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
             cb(Qcur, "Qcur", il);
             cb(Kcur, "Kcur", il);
             cb(Vcur, "Vcur", il);
-            if (ubatch.n_tokens > 0) {
-                LLAMA_LOG_WARN("[KV_WRITE] path=MAIN, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n",
-                    il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]);
-            }
+
            cur = build_attn(inp_attn,
                model.layers[il].wo, NULL,
                Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
@@ -14066,14 +14039,7 @@ struct llm_build_glm4_moe : public llm_graph_context {
            cb(Qcur, "Qcur", il);
            cb(Kcur, "Kcur", il);
            cb(Vcur, "Vcur", il);
-            // LLAMA_LOG_WARN("[DEBUG-MTP-ATTN] Inputs for build_attn in the layer %d:\n", il);
-            // LLAMA_LOG_WARN(" - Qcur shape: [%d, %d, %d]\n", Qcur->ne[0], Qcur->ne[1], Qcur->ne[2]);
-            // LLAMA_LOG_WARN(" - Kcur shape: [%d, %d, %d]\n", Kcur->ne[0], Kcur->ne[1], Kcur->ne[2]);
-            // LLAMA_LOG_WARN(" - Vcur shape: [%d, %d, %d]\n", Vcur->ne[0], Vcur->ne[1], Vcur->ne[2]);
-            if (ubatch.n_tokens > 0) {
-                LLAMA_LOG_WARN("[KV_WRITE] path=MTP, layer=%d, n_tokens=%d, pos_start=%d, pos_end=%d\n",
-                    il, ubatch.n_tokens, ubatch.pos[0], ubatch.pos[ubatch.n_tokens-1]);
-            }
+
            cur = build_attn(inp_attn,
                mtp_layer.wo, NULL,
                Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);

tools/server/server.cpp

Lines changed: 2 additions & 27 deletions
@@ -1738,7 +1738,7 @@ struct server_queue {
 
         while (true) {
             QUE_DBG("%s", "processing new tasks\n");
-            const int64_t t_turn_start_us = ggml_time_us();
+
             while (true) {
                 std::unique_lock<std::mutex> lock(mutex_tasks);
                 if (!running) {
@@ -1761,11 +1761,7 @@ struct server_queue {
             QUE_DBG("%s", "update slots\n");
 
             callback_update_slots();
-            const int64_t t_turn_end_us = ggml_time_us();
-            SRV_DBG(
-                "[PERF] Server turn time: %.2f ms\n",
-                (t_turn_end_us - t_turn_start_us) / 1000.0
-            );
+
            QUE_DBG("%s", "waiting for new tasks\n");
            {
                std::unique_lock<std::mutex> lock(mutex_tasks);
@@ -3471,7 +3467,6 @@ struct server_context {
                batch.seq_id + i,
                batch.logits + i,
            };
-            LOG_INF("\n[DEBUG-CHUNK] Processing main model chunk. Batch size: %d\n", n_tokens);
 
            const int ret = llama_decode(ctx, batch_view);
 
@@ -3569,10 +3564,8 @@ struct server_context {
            }
            llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
            slot.last_tok_idx = tok_idx;
-            //SRV_INF("main loop sampled token: '%s'\n", common_token_to_piece(ctx, id, true).c_str());
 
            slot.i_batch = -1;
-            SLT_INF(slot, "[SAMPLER-ACCEPT] Accepting token ID %d at index %zu\n", id, i);
            common_sampler_accept(slot.smpl, id, true);
 
            slot.n_decoded += 1;
@@ -3644,7 +3637,6 @@ struct server_context {
 
            llama_tokens draft;
            if (slot.has_mtp) {
-                SLT_INF(slot, "[POS-SYNC] Before draft gen. n_past = %d\n", slot.n_past);
                llama_token draft_id = mtp_speculative_gen_draft(slot.smpl, ctx, id, slot.n_past, slot.last_tok_idx);
                draft.reserve(1);
                draft.push_back(draft_id);
@@ -3680,41 +3672,24 @@ struct server_context {
            }
 
            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
-            SLT_INF(slot, "[POS-SYNC] Before validation decode. n_past = %d, spec_batch_size = %d\n", slot.n_past, slot.batch_spec.n_tokens);
            llama_decode(ctx, slot.batch_spec);
 
-            const size_t n_embd = llama_n_embd(llama_get_model(ctx));
-            const size_t golden_buffer_size_in_floats = slot.batch_spec.n_tokens * n_embd;
-            const float* golden_embd_ptr = llama_get_embeddings(ctx);
-            double golden_checksum = calculate_vector_sum_double(golden_embd_ptr, golden_buffer_size_in_floats);
-            SLT_INF(slot, "[VERIFY] Golden checksum after validation: %e (size: %zu tokens)\n", golden_checksum, slot.batch_spec.n_tokens);
-
            // the accepted tokens from the speculation
            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
-            SLT_INF(slot, "[POS-SYNC] Tokens accepted: %zu\n", ids.size());
 
            if (slot.has_mtp) {
                llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
 
-                const float* embd_after_draft_ptr = llama_get_embeddings(ctx);
-                double checksum_after_draft = calculate_vector_sum_double(embd_after_draft_ptr, golden_buffer_size_in_floats);
-                SLT_INF(slot, "[VERIFY] Checksum after draft gen (should be unchanged): %e\n", checksum_after_draft);
-
                if (!ids.empty()) {
                    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, ids.size() - 1));
                } else {
                    llama_set_draft_input_hidden_state(ctx, llama_get_embeddings_ith(ctx, 0));
                }
 
                mtp_accept_tokens(ctx, ids, slot.n_past, slot.id);
-
-                const float* embd_after_update_ptr = llama_get_embeddings(ctx);
-                double checksum_after_update = calculate_vector_sum_double(embd_after_update_ptr, golden_buffer_size_in_floats);
-                SLT_INF(slot, "[VERIFY] Checksum after MTP update (should be unchanged): %e\n", checksum_after_update);
            }
 
            slot.n_past += ids.size();
-            SLT_INF(slot, "[POS-SYNC] After n_past update. New n_past = %d\n", slot.n_past);
            slot.n_decoded += ids.size();
 
            // update how many tokens out of those tested were accepted