Skip to content

Commit cd3069d

Browse files
authored
kv-cache : log (debug) all streams in find_slot (#15176)
This commit updates `llama_kv_cache_unified::find_slot` to log information for all streams when debug is enabled. The motivation for this change is that currently, if a non-unified kv-cache is used, only one stream is logged because the code uses `seq_to_stream[1]`.
1 parent 50e81bd commit cd3069d

File tree

1 file changed

+53
-49
lines changed

1 file changed

+53
-49
lines changed

src/llama-kv-cache-unified.cpp

Lines changed: 53 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -738,66 +738,70 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
738738
}
739739

740740
llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
741-
if (debug > 0) {
742-
const auto & cells = v_cells[seq_to_stream[1]];
743-
744-
const uint32_t head_cur = v_heads[1];
745741

746-
LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
747-
__func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
742+
if (debug > 0) {
743+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
744+
const auto seq_id = ubatch.seq_id_unq[s];
745+
const auto stream_id = seq_to_stream[seq_id];
746+
const auto & cells = v_cells[stream_id];
747+
const uint32_t head_cur = v_heads[stream_id];
748+
749+
LLAMA_LOG_DEBUG("%s: stream[%d], n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n",
750+
__func__, stream_id, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa);
751+
752+
if ((debug == 2 && n_swa > 0) || debug > 2) {
753+
std::string ss;
754+
for (uint32_t i = 0; i < cells.size(); ++i) {
755+
if (cells.is_empty(i)) {
756+
ss += '.';
757+
} else {
758+
assert(cells.seq_count(i) >= 1);
748759

749-
if ((debug == 2 && n_swa > 0) || debug > 2) {
750-
std::string ss;
751-
for (uint32_t i = 0; i < cells.size(); ++i) {
752-
if (cells.is_empty(i)) {
753-
ss += '.';
754-
} else {
755-
assert(cells.seq_count(i) >= 1);
760+
if (cells.seq_count(i) == 1) {
761+
ss += std::to_string(cells.seq_get(i));
762+
} else {
763+
ss += 'M';
764+
}
765+
}
766+
if (i%256 == 255) {
767+
ss += " *";
768+
ss += '\n';
769+
}
770+
}
771+
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
772+
}
756773

757-
if (cells.seq_count(i) == 1) {
758-
ss += std::to_string(cells.seq_get(i));
774+
if ((debug == 2 && n_swa > 0) || debug > 2) {
775+
std::string ss;
776+
for (uint32_t i = 0; i < cells.size(); ++i) {
777+
std::string cur;
778+
if (cells.is_empty(i)) {
779+
cur = '.';
759780
} else {
760-
ss += 'M';
781+
cur = std::to_string(cells.pos_get(i));
782+
}
783+
const int n = cur.size();
784+
for (int j = 0; j < 5 - n; ++j) {
785+
cur += ' ';
786+
}
787+
ss += cur;
788+
if (i%256 == 255) {
789+
ss += " *";
790+
}
791+
if (i%64 == 63) {
792+
ss += '\n';
761793
}
762794
}
763-
if (i%256 == 255) {
764-
ss += " *";
765-
ss += '\n';
766-
}
795+
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
767796
}
768-
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
769-
}
770797

771-
if ((debug == 2 && n_swa > 0) || debug > 2) {
772-
std::string ss;
773-
for (uint32_t i = 0; i < cells.size(); ++i) {
774-
std::string cur;
775-
if (cells.is_empty(i)) {
776-
cur = '.';
777-
} else {
778-
cur = std::to_string(cells.pos_get(i));
779-
}
780-
const int n = cur.size();
781-
for (int j = 0; j < 5 - n; ++j) {
782-
cur += ' ';
783-
}
784-
ss += cur;
785-
if (i%256 == 255) {
786-
ss += " *";
787-
}
788-
if (i%64 == 63) {
789-
ss += '\n';
798+
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
799+
if (cells.seq_pos_min(s) < 0) {
800+
continue;
790801
}
791-
}
792-
LLAMA_LOG_DEBUG("\n%s\n", ss.c_str());
793-
}
794802

795-
for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
796-
if (cells.seq_pos_min(s) < 0) {
797-
continue;
803+
LLAMA_LOG_DEBUG("%s: stream[%d] min[%d] = %5d, max[%d] = %5d\n", __func__, stream_id, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
798804
}
799-
800-
LLAMA_LOG_DEBUG("%s: min[%d] = %5d, max[%d] = %5d\n", __func__, s, cells.seq_pos_min(s), s, cells.seq_pos_max(s));
801805
}
802806
}
803807

0 commit comments

Comments
 (0)