Commit 1a1d920

Sync context: remove logits_all flag and update API
1 parent b510bb0 commit 1a1d920

6 files changed: +58 / -80 lines

llama_cpp/_internals.py

Lines changed: 25 additions & 17 deletions
@@ -279,23 +279,35 @@ def n_ctx(self) -> int:
     def pooling_type(self) -> int:
         return llama_cpp.llama_pooling_type(self.ctx)
 
-    def kv_cache_clear(self):
-        llama_cpp.llama_kv_cache_clear(self.ctx)
+    def kv_self_clear(self):
+        llama_cpp.llama_kv_self_clear(self.ctx)
 
-    def kv_cache_seq_rm(self, seq_id: int, p0: int, p1: int):
-        llama_cpp.llama_kv_cache_seq_rm(self.ctx, seq_id, p0, p1)
+    def kv_self_seq_rm(self, seq_id: int, p0: int, p1: int):
+        llama_cpp.llama_kv_self_seq_rm(self.ctx, seq_id, p0, p1)
 
-    def kv_cache_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
-        llama_cpp.llama_kv_cache_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
+    def kv_self_seq_cp(self, seq_id_src: int, seq_id_dst: int, p0: int, p1: int):
+        llama_cpp.llama_kv_self_seq_cp(self.ctx, seq_id_src, seq_id_dst, p0, p1)
 
-    def kv_cache_seq_keep(self, seq_id: int):
-        llama_cpp.llama_kv_cache_seq_keep(self.ctx, seq_id)
+    def kv_self_seq_keep(self, seq_id: int):
+        llama_cpp.llama_kv_self_seq_keep(self.ctx, seq_id)
 
-    def kv_cache_seq_shift(self, seq_id: int, p0: int, p1: int, shift: int):
-        llama_cpp.llama_kv_cache_seq_add(self.ctx, seq_id, p0, p1, shift)
+    def kv_self_seq_add(self, seq_id: int, p0: int, p1: int, delta: int):
+        llama_cpp.llama_kv_self_seq_add(self.ctx, seq_id, p0, p1, delta)
+
+    def kv_self_seq_div(self, seq_id: int, p0: int, p1: int, d: int):
+        llama_cpp.llama_kv_self_seq_div(self.ctx, seq_id, p0, p1, d)
+
+    def kv_self_seq_pos_max(self, seq_id: int):
+        llama_cpp.llama_kv_self_seq_pos_max(self.ctx, seq_id)
+
+    def kv_self_defrag(self):
+        llama_cpp.llama_kv_self_defrag(self.ctx)
+
+    def kv_self_can_shift(self) -> bool:
+        llama_cpp.llama_kv_self_can_shift(self.ctx)
 
     def get_state_size(self) -> int:
-        return llama_cpp.llama_get_state_size(self.ctx)
+        return llama_cpp.llama_state_get_size(self.ctx)
 
     # TODO: copy_state_data

@@ -502,18 +514,16 @@ def n_tokens(self) -> int:
     def reset(self):
         self.batch.n_tokens = 0
 
-    def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
+    def set_batch(self, batch: Sequence[int], n_past: int):
         n_tokens = len(batch)
         self.batch.n_tokens = n_tokens
         for i in range(n_tokens):
             self.batch.token[i] = batch[i]
             self.batch.pos[i] = n_past + i
             self.batch.seq_id[i][0] = 0
             self.batch.n_seq_id[i] = 1
-            self.batch.logits[i] = logits_all
-        self.batch.logits[n_tokens - 1] = True
 
-    def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
+    def add_sequence(self, batch: Sequence[int], seq_id: int):
         n_tokens = len(batch)
         n_tokens0 = self.batch.n_tokens
         self.batch.n_tokens += n_tokens

@@ -523,8 +533,6 @@ def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
             self.batch.pos[j] = i
             self.batch.seq_id[j][0] = seq_id
             self.batch.n_seq_id[j] = 1
-            self.batch.logits[j] = logits_all
-        self.batch.logits[n_tokens - 1] = True
 
 
 class LlamaTokenDataArray:
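
For orientation, a minimal usage sketch of the renamed KV-cache wrappers, assuming the keyword-only LlamaModel/LlamaContext constructors from _internals; the model path, context size, and positions are placeholders for illustration, not part of this commit:

import llama_cpp
from llama_cpp._internals import LlamaContext, LlamaModel

# Placeholder GGUF path; any local model file would do for this sketch.
model = LlamaModel(
    path_model="./model.gguf", params=llama_cpp.llama_model_default_params()
)
cparams = llama_cpp.llama_context_default_params()
cparams.n_ctx = 512
ctx = LlamaContext(model=model, params=cparams)

ctx.kv_self_clear()                # was kv_cache_clear(): drop all cached tokens
ctx.kv_self_seq_rm(-1, 16, -1)     # was kv_cache_seq_rm(): forget positions >= 16 in every sequence
ctx.kv_self_seq_cp(0, 1, 0, 16)    # was kv_cache_seq_cp(): share the first 16 positions of seq 0 with seq 1
ctx.kv_self_seq_add(1, 16, -1, 4)  # replaces kv_cache_seq_shift(): shift positions >= 16 of seq 1 by +4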

llama_cpp/llama.py

Lines changed: 9 additions & 34 deletions
@@ -89,7 +89,6 @@ def __init__(
         yarn_beta_fast: float = 32.0,
         yarn_beta_slow: float = 1.0,
         yarn_orig_ctx: int = 0,
-        logits_all: bool = False,
         embedding: bool = False,
         offload_kqv: bool = True,
         flash_attn: bool = False,

@@ -170,7 +169,6 @@ def __init__(
             yarn_beta_fast: YaRN low correction dim
             yarn_beta_slow: YaRN high correction dim
             yarn_orig_ctx: YaRN original context size
-            logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs.
             embedding: Embedding mode only.
             offload_kqv: Offload K, Q, V to GPU.
             flash_attn: Use flash attention.

@@ -341,9 +339,6 @@ def __init__(
             yarn_beta_slow if yarn_beta_slow != 0.0 else 0
         )
         self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0
-        self.context_params.logits_all = (
-            logits_all if draft_model is None else True
-        )  # Must be set to True for speculative decoding
         self.context_params.embeddings = embedding  # TODO: Rename to embeddings
         self.context_params.offload_kqv = offload_kqv
         self.context_params.flash_attn = flash_attn

@@ -457,9 +452,7 @@ def free_lora_adapter():
 
         self.n_tokens = 0
         self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
-        self.scores: npt.NDArray[np.single] = np.ndarray(
-            (n_ctx if logits_all == True else n_batch, self._n_vocab), dtype=np.single
-        )
+        self.scores: npt.NDArray[np.single] = np.ndarray((n_batch, self._n_vocab), dtype=np.single)
 
         self._mirostat_mu = ctypes.c_float(
             2.0 * 5.0

@@ -568,7 +561,7 @@ def eval_tokens(self) -> Deque[int]:
     def eval_logits(self) -> Deque[List[float]]:
         return deque(
             self.scores[: self.n_tokens, :].tolist(),
-            maxlen=self._n_ctx if self.context_params.logits_all else 1,
+            maxlen=self._n_ctx
         )
 
     def tokenize(

@@ -635,34 +628,18 @@ def eval(self, tokens: Sequence[int]):
         Args:
             tokens: The list of tokens to evaluate.
         """
-        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+        self._ctx.kv_self_seq_rm(-1, self.n_tokens, -1)
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
             n_past = self.n_tokens
             n_tokens = len(batch)
             self._batch.set_batch(
-                batch=batch, n_past=n_past, logits_all=self.context_params.logits_all
+                batch=batch, n_past=n_past
             )
             self._ctx.decode(self._batch)
             # Save tokens
             self.input_ids[n_past : n_past + n_tokens] = batch
-            # Save logits
-            if self.context_params.logits_all:
-                rows = n_tokens
-                cols = self._n_vocab
-                logits = np.ctypeslib.as_array(
-                    self._ctx.get_logits(), shape=(rows * cols,)
-                )
-                self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
-            else:
-                # rows = 1
-                # cols = self._n_vocab
-                # logits = np.ctypeslib.as_array(
-                #     self._ctx.get_logits(), shape=(rows * cols,)
-                # )
-                # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
-                # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
-                pass
+
             # Update n_tokens
             self.n_tokens += n_tokens

@@ -988,7 +965,7 @@ def generate(
 
             if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
                 self.n_tokens = sample_idx
-                self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+                self._ctx.kv_self_seq_rm(-1, self.n_tokens, -1)
                 break
 
             if self.draft_model is not None:

@@ -1062,7 +1039,6 @@ def embed(
 
         # get pooling information
         pooling_type = self.pooling_type()
-        logits_all = pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE
 
         if self.context_params.embeddings is False:
             raise RuntimeError(

@@ -1140,7 +1116,7 @@ def decode_batch(seq_sizes: List[int]):
                 p_batch = 0
 
             # add to batch
-            self._batch.add_sequence(tokens, p_batch, logits_all)
+            self._batch.add_sequence(tokens, p_batch)
 
             # update batch stats
             s_batch.append(n_tokens)

@@ -1340,9 +1316,9 @@ def logit_bias_processor(
         else:
             stop_sequences = []
 
-        if logprobs is not None and self.context_params.logits_all is False:
+        if logprobs is not None:
             raise ValueError(
-                "logprobs is not supported for models created with logits_all=False"
+                "logprobs is not supported for models"
             )
 
         if self.cache:

@@ -2213,7 +2189,6 @@ def __getstate__(self):
             yarn_beta_fast=self.context_params.yarn_beta_fast,
             yarn_beta_slow=self.context_params.yarn_beta_slow,
             yarn_orig_ctx=self.context_params.yarn_orig_ctx,
-            logits_all=self.context_params.logits_all,
            embedding=self.context_params.embeddings,
            offload_kqv=self.context_params.offload_kqv,
            flash_attn=self.context_params.flash_attn,
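
On the high-level API, callers that previously passed logits_all=True to the Llama constructor now simply drop the argument, and completion calls that request logprobs raise ValueError after this change. A minimal sketch, with a placeholder model path:

from llama_cpp import Llama

# Before this commit: Llama(model_path="./model.gguf", logits_all=True)
# After: the flag is gone; per-token logits are controlled by llama_batch.logits upstream.
llm = Llama(model_path="./model.gguf", n_ctx=512)

out = llm("Q: Name the planets in the solar system. A:", max_tokens=32)
print(out["choices"][0]["text"])

# llm("...", logprobs=5)  # now raises ValueError("logprobs is not supported for models")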

llama_cpp/llama_cpp.py

Lines changed: 23 additions & 24 deletions
@@ -751,7 +751,7 @@ class llama_model_params(ctypes.Structure):
 
 
 # // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
-# // https://github.com/ggerganov/llama.cpp/pull/7544
+# // https://github.com/ggml-org/llama.cpp/pull/7544
 # struct llama_context_params {
 #     uint32_t n_ctx;   // text context, 0 = from model
 #     uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode

@@ -764,7 +764,7 @@ class llama_model_params(ctypes.Structure):
 #     enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
 #     enum llama_attention_type attention_type; // attention type to use for embeddings
 
-#     // ref: https://github.com/ggerganov/llama.cpp/pull/2054
+#     // ref: https://github.com/ggml-org/llama.cpp/pull/2054
 #     float rope_freq_base;  // RoPE base frequency, 0 = from model
 #     float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
 #     float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model

@@ -779,20 +779,17 @@ class llama_model_params(ctypes.Structure):
 
 #     enum ggml_type type_k; // data type for K cache [EXPERIMENTAL]
 #     enum ggml_type type_v; // data type for V cache [EXPERIMENTAL]
-
-#     // Keep the booleans together to avoid misalignment during copy-by-value.
-#     bool logits_all;  // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
-#     bool embeddings;  // if true, extract embeddings (together with logits)
-#     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
-#     bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
-#     bool no_perf;     // whether to measure performance timings
-
-
 #     // Abort callback
 #     // if it returns true, execution of llama_decode() will be aborted
 #     // currently works only with CPU execution
 #     ggml_abort_callback abort_callback;
 #     void * abort_callback_data;
+
+#     // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
+#     bool embeddings;  // if true, extract embeddings (together with logits)
+#     bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
+#     bool flash_attn;  // whether to use flash attention [EXPERIMENTAL]
+#     bool no_perf;     // whether to measure performance timings
 # };
 class llama_context_params(ctypes.Structure):
     """Parameters for llama_context

@@ -819,13 +816,12 @@ class llama_context_params(ctypes.Structure):
         cb_eval_user_data (ctypes.ctypes.c_void_p): user data for cb_eval
         type_k (int): data type for K cache
         type_v (int): data type for V cache
-        logits_all (bool): the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
+        abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
+        abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
         embeddings (bool): if true, extract embeddings (together with logits)
         offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU
         flash_attn (bool): whether to use flash attention
         no_perf (bool): whether to measure performance timings
-        abort_callback (ggml_abort_callback): abort callback if it returns true, execution of llama_decode() will be aborted
-        abort_callback_data (ctypes.ctypes.c_void_p): data for abort_callback
     """
 
     if TYPE_CHECKING:

@@ -850,13 +846,12 @@ class llama_context_params(ctypes.Structure):
         cb_eval_user_data: ctypes.c_void_p
         type_k: int
         type_v: int
-        logits_all: bool
+        abort_callback: Callable[[ctypes.c_void_p], bool]
+        abort_callback_data: ctypes.c_void_p
         embeddings: bool
         offload_kqv: bool
         flash_attn: bool
         no_perf: bool
-        abort_callback: Callable[[ctypes.c_void_p], bool]
-        abort_callback_data: ctypes.c_void_p
 
     _fields_ = [
         ("n_ctx", ctypes.c_uint32),

@@ -880,13 +875,12 @@ class llama_context_params(ctypes.Structure):
         ("cb_eval_user_data", ctypes.c_void_p),
         ("type_k", ctypes.c_int),
         ("type_v", ctypes.c_int),
-        ("logits_all", ctypes.c_bool),
+        ("abort_callback", ggml_abort_callback),
+        ("abort_callback_data", ctypes.c_void_p),
         ("embeddings", ctypes.c_bool),
         ("offload_kqv", ctypes.c_bool),
         ("flash_attn", ctypes.c_bool),
         ("no_perf", ctypes.c_bool),
-        ("abort_callback", ggml_abort_callback),
-        ("abort_callback_data", ctypes.c_void_p),
     ]
 
 

@@ -2683,10 +2677,12 @@ def llama_batch_free(batch: llama_batch, /):
     ...
 
 
-# // Processes a batch of tokens with the ecoder part of the encoder-decoder model.
-# // Stores the encoder output internally for later use by the decoder cross-attention layers.
+# // Process a batch of tokens.
+# // In contrast to llama_decode() - this call does not use KV cache.
+# // For encode-decoder contexts, processes the batch using the encoder.
+# // Can store the encoder output internally for later use by the decoder's cross-attention layers.
 # // 0 - success
-# // < 0 - error
+# // < 0 - error. the KV cache state is restored to the state before this call
 # LLAMA_API int32_t llama_encode(
 #         struct llama_context * ctx,
 #           struct llama_batch   batch);

@@ -2699,10 +2695,13 @@ def llama_encode(ctx: llama_context_p, batch: llama_batch, /) -> int:
     ...
 
 
+# // Process a batch of tokens.
+# // Requires KV cache.
+# // For encode-decoder contexts, processes the batch using the decoder.
 # // Positive return values does not mean a fatal error, but rather a warning.
 # //   0 - success
 # //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-# // < 0 - error
+# // < 0 - error. the KV cache state is restored to the state before this call
 # LLAMA_API int32_t llama_decode(
 #         struct llama_context * ctx,
 #           struct llama_batch   batch);
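
Because logits_all was removed from llama_context_params and the abort-callback fields now sit ahead of the boolean block, ctypes code should obtain defaults from llama_context_default_params() and set fields by name rather than assuming the old layout; a minimal sketch:

import llama_cpp

cparams = llama_cpp.llama_context_default_params()
cparams.n_ctx = 1024
cparams.embeddings = False
cparams.offload_kqv = True
cparams.flash_attn = False
# cparams.logits_all is gone from _fields_; per-token logits are requested
# through llama_batch.logits when the batch is built instead.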

llama_cpp/server/model.py

Lines changed: 0 additions & 1 deletion
@@ -261,7 +261,6 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         yarn_beta_slow=settings.yarn_beta_slow,
         yarn_orig_ctx=settings.yarn_orig_ctx,
         mul_mat_q=settings.mul_mat_q,
-        logits_all=settings.logits_all,
         embedding=settings.embedding,
         offload_kqv=settings.offload_kqv,
         flash_attn=settings.flash_attn,

llama_cpp/server/settings.py

Lines changed: 0 additions & 1 deletion
@@ -98,7 +98,6 @@ class ModelSettings(BaseSettings):
     mul_mat_q: bool = Field(
         default=True, description="if true, use experimental mul_mat_q kernels"
     )
-    logits_all: bool = Field(default=True, description="Whether to return logits.")
     embedding: bool = Field(default=False, description="Whether to use embeddings.")
     offload_kqv: bool = Field(
         default=True, description="Whether to offload kqv to the GPU."

tests/test_llama.py

Lines changed: 1 addition & 3 deletions
@@ -81,7 +81,6 @@ def test_real_model(llama_cpp_model_path):
     cparams.n_ubatch = 16
     cparams.n_threads = multiprocessing.cpu_count()
     cparams.n_threads_batch = multiprocessing.cpu_count()
-    cparams.logits_all = False
     cparams.flash_attn = True
 
     context = internals.LlamaContext(model=model, params=cparams)

@@ -103,7 +102,7 @@ def test_real_model(llama_cpp_model_path):
     result = tokens
     n_eval = 0
     for _ in range(4):
-        batch.set_batch(tokens, n_past=n_eval, logits_all=False)
+        batch.set_batch(tokens, n_past=n_eval)
         context.decode(batch)
         n_eval += len(tokens)
         token_id = sampler.sample(context, -1)

@@ -122,7 +121,6 @@ def test_real_llama(llama_cpp_model_path):
         n_ubatch=32,
         n_threads=multiprocessing.cpu_count(),
         n_threads_batch=multiprocessing.cpu_count(),
-        logits_all=False,
         flash_attn=True,
     )
