@@ -1723,40 +1723,41 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
17231723 logits = has_logits ? output_base : nullptr ;
17241724 embd = has_embd ? output_base + logits_size : nullptr ;
17251725 } else {
1726+ // Allocate worst case (full vocabulary size) for backend sampled
1727+ // data in the pinned memory buffer.
17261728 size_t offset = 0 ;
17271729 uint8_t * base = (uint8_t *) output_base;
17281730
1729- if (sampling.logits_size > 0 ) {
1730- sampling.logits = (float *) (base + offset);
1731- offset += sampling.logits_size * sizeof (float );
1732- }
1733- if (sampling.probs_size > 0 ) {
1734- sampling.probs = (float *) (base + offset);
1735- offset += sampling.probs_size * sizeof (float );
1736- }
1737- if (sampling.sampled_size > 0 ) {
1738- sampling.sampled = (llama_token *) (base + offset);
1739- offset += sampling.sampled_size * sizeof (llama_token);
1740- }
1741- if (sampling.candidates_size > 0 ) {
1742- sampling.candidates = (llama_token *) (base + offset);
1743- offset += sampling.candidates_size * sizeof (llama_token);
1744- }
1731+ sampling.logits = (float *) (base + offset);
1732+ offset += sampling.logits_size * sizeof (float );
1733+
1734+ sampling.probs = (float *) (base + offset);
1735+ offset += sampling.probs_size * sizeof (float );
17451736
1737+ sampling.sampled = (llama_token *) (base + offset);
1738+ offset += sampling.sampled_size * sizeof (llama_token);
1739+
1740+ sampling.candidates = (llama_token *) (base + offset);
1741+ offset += sampling.candidates_size * sizeof (llama_token);
1742+
1743+ // The count vectors keep track of the actual number of logits/probs/candidates
1744+ // copied from the backend for each output row.
17461745 const size_t n_rows = (size_t ) n_outputs_max;
17471746 if (sampling.outputs_capacity < n_rows) {
1747+ // The output size has increased, so resize and reset the count vectors.
17481748 sampling.outputs_capacity = n_rows;
17491749
17501750 sampling.logits_count .assign (n_rows, 0 );
17511751 sampling.probs_count .assign (n_rows, 0 );
17521752 sampling.candidates_count .assign (n_rows, 0 );
17531753 } else {
1754+ // The output size has not increased so just reset the counts to zero.
17541755 std::fill (sampling.logits_count .begin (), sampling.logits_count .end (), 0 );
17551756 std::fill (sampling.probs_count .begin (), sampling.probs_count .end (), 0 );
17561757 std::fill (sampling.candidates_count .begin (), sampling.candidates_count .end (), 0 );
17571758 }
17581759
1759- if (sampling.sampled && sampling. sampled_size > 0 ) {
1760+ if (sampling.sampled ) {
17601761 std::fill_n (sampling.sampled , sampling.sampled_size , LLAMA_TOKEN_NULL);
17611762 }
17621763 }
@@ -1814,9 +1815,11 @@ void llama_context::output_reorder() {
18141815 if (!sampling.logits_count .empty ()) {
18151816 std::swap (sampling.logits_count [i0], sampling.logits_count [i1]);
18161817 }
1818+
18171819 if (!sampling.probs_count .empty ()) {
18181820 std::swap (sampling.probs_count [i0], sampling.probs_count [i1]);
18191821 }
1822+
18201823 if (!sampling.candidates_count .empty ()) {
18211824 std::swap (sampling.candidates_count [i0], sampling.candidates_count [i1]);
18221825 }
(end of diff — GitHub page footer residue removed)