Commit 16bc60a (parent 77160c6)

Sync sampling: optimize samplers by reusing bucket sort
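Context: llama.cpp deprecated and dropped its standalone softmax sampler while optimizing the remaining samplers to reuse an internal bucket sort, and this commit syncs the Python bindings by removing the corresponding wrappers. A minimal before/after sketch of the temp < 0 path, where `sampler` stands in for the internal sampler-chain object used by `_init_sampler` in llama.py (illustrative only, not the library's public API):

# Before: an explicit softmax stage preceded the dist sampler.
def init_dist_chain_old(sampler, temp: float, seed: int) -> None:
    if temp < 0.0:
        sampler.add_softmax()   # explicit softmax over all candidates (removed)
        sampler.add_dist(seed)  # then draw a token from the distribution

# After: the dist sampler alone suffices, since it normalizes
# probabilities itself before drawing.
def init_dist_chain_new(sampler, temp: float, seed: int) -> None:
    if temp < 0.0:
        sampler.add_dist(seed)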

File tree: 3 files changed, +2 -23 lines

  llama_cpp/_internals.py  (+1 -12)
  llama_cpp/llama.py       (+0 -1)
  llama_cpp/llama_cpp.py   (+1 -10)

llama_cpp/_internals.py (1 addition, 12 deletions)
@@ -405,13 +405,6 @@ def sample_repetition_penalties(
         # )
         raise NotImplementedError("sample_repetition_penalties is not implemented in llama.cpp")

-    def sample_softmax(self, candidates: "_LlamaTokenDataArray"):
-        # llama_cpp.llama_sample_softmax(
-        #     self.ctx,
-        #     llama_cpp.byref(candidates.candidates),
-        # )
-        raise NotImplementedError("sample_softmax is not implemented in llama.cpp")
-
     def sample_top_k(self, candidates: "_LlamaTokenDataArray", k: int, min_keep: int):
         # llama_cpp.llama_sample_top_k(
         #     self.ctx, llama_cpp.byref(candidates.candidates), k, min_keep
@@ -592,6 +585,7 @@ def __init__(self, *, n_vocab: int):
         self.candidates = llama_cpp.llama_token_data_array(
             data=self.candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=self.n_vocab,
+            selected=-1,
             sorted=False,
         )
         self.default_candidates_data_id = np.arange(self.n_vocab, dtype=np.intc)  # type: ignore
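The new `selected=-1` matches the `int64_t selected` field added to the C struct (an index into `data`, not a token id; -1 means no candidate chosen yet). A hedged standalone sketch of the same construction, where `n_vocab` is a stand-in value and the recarray dtype mirrors the one `_LlamaTokenDataArray` uses:

import numpy as np
import llama_cpp

n_vocab = 32000  # stand-in vocabulary size
candidates_data = np.recarray(
    (n_vocab,),
    dtype=np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    ),
)
candidates = llama_cpp.llama_token_data_array(
    data=candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
    size=n_vocab,
    selected=-1,   # index into `data` of the chosen candidate; -1 = none yet
    sorted=False,  # entries are in vocabulary order, not sorted by logit
)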
@@ -729,7 +723,6 @@ def sample(
             ctx_main.sample_grammar(token_data_array, self.grammar)

         if self.params.temp < 0:
-            ctx_main.sample_softmax(token_data_array)
             id = token_data_array.candidates_data.id[0]
         elif self.params.temp == 0:
             id = ctx_main.sample_token_greedy(token_data_array)
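For reference, `id[0]` was meaningful right after the removed call because the softmax sampler sorted candidates by logit in descending order before computing probabilities (per the doc comment removed from llama_cpp.py below), so the first entry was the most likely token. A NumPy sketch of that semantics, illustrative only:

import numpy as np

def softmax_candidates(ids: np.ndarray, logits: np.ndarray):
    order = np.argsort(-logits)           # sort descending by logit
    ids, logits = ids[order], logits[order]
    exps = np.exp(logits - logits.max())  # shift by max for numerical stability
    return ids, exps / exps.sum()         # probabilities sum to 1; ids[0] is the argmax

ids, probs = softmax_candidates(np.array([7, 11, 3]), np.array([0.5, 2.0, -1.0]))
assert ids[0] == 11 and abs(probs.sum() - 1.0) < 1e-6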
@@ -827,10 +820,6 @@ def add_dist(self, seed: int):
         sampler = llama_cpp.llama_sampler_init_dist(seed)
         self._add_sampler(sampler)

-    def add_softmax(self):
-        sampler = llama_cpp.llama_sampler_init_softmax()
-        self._add_sampler(sampler)
-
     def add_top_k(self, k: int):
         sampler = llama_cpp.llama_sampler_init_top_k(k)
         self._add_sampler(sampler)

llama_cpp/llama.py (0 additions, 1 deletion)
@@ -735,7 +735,6 @@ def _init_sampler(
         sampler.add_grammar(self._model, grammar)

         if temp < 0.0:
-            sampler.add_softmax()
             sampler.add_dist(self._seed)
         elif temp == 0.0:
             sampler.add_greedy()

llama_cpp/llama_cpp.py (1 addition, 10 deletions)
@@ -552,7 +552,7 @@ class llama_token_data(ctypes.Structure):
 # llama_token_data * data;
 # size_t size;
 # int64_t selected; // this is the index in the data array (i.e. not the token id)
-# bool sorted;
+# bool sorted; // note: do not assume the data is sorted - always check this flag
 # } llama_token_data_array;
 class llama_token_data_array(ctypes.Structure):
     """Used to sample tokens given logits
@@ -3742,15 +3742,6 @@ def llama_sampler_init_dist(seed: int) -> llama_sampler_p:
     ...


-# /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-# /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-# DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void),
-#            "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
-@ctypes_function("llama_sampler_init_softmax", [], llama_sampler_p_ctypes)
-def llama_sampler_init_softmax() -> llama_sampler_p:
-    ...
-
-
 # /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
 # /// Setting k <= 0 makes this a noop
 # LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k);
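Migration note, hedged: per the deprecation message above, callers that built a standalone softmax sampler can simply drop it, since samplers such as the dist sampler compute probabilities internally. A sketch using the low-level chain bindings (assuming the llama_sampler_chain_* functions defined elsewhere in llama_cpp.py; parameter values are placeholders):

import llama_cpp

# Build a sampler chain without the removed softmax stage.
chain = llama_cpp.llama_sampler_chain_init(
    llama_cpp.llama_sampler_chain_default_params()
)
llama_cpp.llama_sampler_chain_add(chain, llama_cpp.llama_sampler_init_top_k(40))
llama_cpp.llama_sampler_chain_add(chain, llama_cpp.llama_sampler_init_dist(1234))
# previously: llama_sampler_init_softmax() would have been added before dist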
