
Commit fce121b

Small change
1 parent 212c7a7 commit fce121b

2 files changed: +11 lines, -8 lines


litgpt/generate/base.py

Lines changed: 8 additions & 8 deletions

@@ -275,16 +275,16 @@ def batched_generate_fn(
         )
         # We may need the last time slice of `all_logits` below:
         all_logits = model(inputs, input_pos=start)
-        if start == 0:
-            max_tokens_forward = model.kv_cache_max_tokens_forward()
-            if prompt_chunksize > max_tokens_forward:
-                print(
-                    f"prompt_chunksize = {prompt_chunksize} > {max_tokens_forward} = max_tokens_forward. Lowering it to the latter.")
-                prompt_chunksize = max_tokens_forward
-        start = token_pos
         if token_pos == min_prompt_size:
             break
-        chunksize = min(prompt_chunksize, min_prompt_size - token_pos)
+        start = token_pos
+        # Note that `max_tokens_forward` can change during the course of
+        # prompt processing:
+        chunksize = min((
+            prompt_chunksize,
+            model.kv_cache_max_tokens_forward(),
+            min_prompt_size - token_pos
+        ))
         token_pos += chunksize

     # Generation loop: One token per iteration
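
For context, a runnable sketch of the chunking loop after this change, with `kv_cache_max_tokens_forward` stubbed out as a free function (in litgpt the real method lives on the model; the stub and `process_prompt` below are illustrative only):

# Illustrative sketch of the fixed prompt-processing loop. The stub stands
# in for `model.kv_cache_max_tokens_forward()`; the point of the fix is that
# this limit can change between calls, so it is re-queried for every chunk
# instead of being read once at `start == 0`.

def kv_cache_max_tokens_forward(token_pos: int) -> int:
    # Hypothetical limit that changes during prompt processing.
    return 16 if token_pos < 32 else 4

def process_prompt(prompt_chunksize: int, min_prompt_size: int) -> list[int]:
    chunks = []
    token_pos = 0
    while token_pos < min_prompt_size:
        # Same min(...) as in the diff: cap by the user setting, the cache's
        # current per-forward limit, and the tokens left in the prompt.
        chunksize = min(
            prompt_chunksize,
            kv_cache_max_tokens_forward(token_pos),
            min_prompt_size - token_pos,
        )
        chunks.append(chunksize)
        token_pos += chunksize
    return chunks

print(process_prompt(prompt_chunksize=16, min_prompt_size=40))
# -> [16, 16, 4, 4]: the limit drops mid-prompt, so later chunks shrink

Under the old code, a limit queried only once at the start would have allowed oversized chunks once the cache's limit dropped; re-querying per iteration avoids that.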

litgpt/kvcache/base.py

Lines changed: 3 additions & 0 deletions

@@ -151,6 +151,9 @@ def next_token_pos(self) -> Optional[int]:
     @property
     def max_tokens_forward(self) -> int:
         """
+        Note that this limit may change during the course of the generation
+        for certain caches.
+
         Returns:
            Maximum number of token positions which can be treated in
            :meth:`forward`. Depends on cache, but is `<= cache_length`
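
To illustrate the new docstring note, a toy cache whose `max_tokens_forward` shrinks as the cache fills (a hypothetical example, not one of litgpt's actual cache classes):

# Toy example of the documented behavior: `max_tokens_forward` is not a
# constant and may change as generation proceeds. Not an actual litgpt class.

class ToyCache:
    def __init__(self, cache_length: int):
        self.cache_length = cache_length
        self._filled = 0

    @property
    def max_tokens_forward(self) -> int:
        # While filling: as many positions as still fit. Once full, assume
        # (hypothetically) one token can be processed per forward.
        return max(self.cache_length - self._filled, 1)

    def forward(self, num_tokens: int) -> None:
        assert num_tokens <= self.max_tokens_forward
        self._filled = min(self._filled + num_tokens, self.cache_length)

cache = ToyCache(cache_length=8)
print(cache.max_tokens_forward)  # 8 (empty cache)
cache.forward(6)
print(cache.max_tokens_forward)  # 2 (limit changed during generation)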
