Skip to content

Commit 4a90583

Browse files
committed
sampling : cleanup and clarify output_reserve
1 parent d88ba18 commit 4a90583

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

src/llama-context.cpp

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,40 +1723,41 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
17231723
logits = has_logits ? output_base : nullptr;
17241724
embd = has_embd ? output_base + logits_size : nullptr;
17251725
} else {
1726+
// Allocate worst case (full vocabulary size) for backend sampled
1727+
// data in the pinned memory buffer.
17261728
size_t offset = 0;
17271729
uint8_t * base = (uint8_t *) output_base;
17281730

1729-
if (sampling.logits_size > 0) {
1730-
sampling.logits = (float *) (base + offset);
1731-
offset += sampling.logits_size * sizeof(float);
1732-
}
1733-
if (sampling.probs_size > 0) {
1734-
sampling.probs = (float *) (base + offset);
1735-
offset += sampling.probs_size * sizeof(float);
1736-
}
1737-
if (sampling.sampled_size > 0) {
1738-
sampling.sampled = (llama_token *) (base + offset);
1739-
offset += sampling.sampled_size * sizeof(llama_token);
1740-
}
1741-
if (sampling.candidates_size > 0) {
1742-
sampling.candidates = (llama_token *) (base + offset);
1743-
offset += sampling.candidates_size * sizeof(llama_token);
1744-
}
1731+
sampling.logits = (float *) (base + offset);
1732+
offset += sampling.logits_size * sizeof(float);
1733+
1734+
sampling.probs = (float *) (base + offset);
1735+
offset += sampling.probs_size * sizeof(float);
17451736

1737+
sampling.sampled = (llama_token *) (base + offset);
1738+
offset += sampling.sampled_size * sizeof(llama_token);
1739+
1740+
sampling.candidates = (llama_token *) (base + offset);
1741+
offset += sampling.candidates_size * sizeof(llama_token);
1742+
1743+
// The count vectors keep track of the actual number of logits/probs/candidates
1744+
// copied from the backend for each output row.
17461745
const size_t n_rows = (size_t) n_outputs_max;
17471746
if (sampling.outputs_capacity < n_rows) {
1747+
// The output size has increased, so resize and reset the count vectors.
17481748
sampling.outputs_capacity = n_rows;
17491749

17501750
sampling.logits_count.assign(n_rows, 0);
17511751
sampling.probs_count.assign(n_rows, 0);
17521752
sampling.candidates_count.assign(n_rows, 0);
17531753
} else {
1754+
// The output size has not increased so just reset the counts to zero.
17541755
std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0);
17551756
std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0);
17561757
std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0);
17571758
}
17581759

1759-
if (sampling.sampled && sampling.sampled_size > 0) {
1760+
if (sampling.sampled) {
17601761
std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL);
17611762
}
17621763
}
@@ -1814,9 +1815,11 @@ void llama_context::output_reorder() {
18141815
if (!sampling.logits_count.empty()) {
18151816
std::swap(sampling.logits_count[i0], sampling.logits_count[i1]);
18161817
}
1818+
18171819
if (!sampling.probs_count.empty()) {
18181820
std::swap(sampling.probs_count[i0], sampling.probs_count[i1]);
18191821
}
1822+
18201823
if (!sampling.candidates_count.empty()) {
18211824
std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]);
18221825
}

0 commit comments

Comments
 (0)