
Commit 94037d2

Nothing

1 parent fce121b

File tree

1 file changed: +3 -3 lines


litgpt/generate/base.py

Lines changed: 3 additions & 3 deletions
@@ -232,7 +232,7 @@ def batched_generate_fn(
     """
     batch_size = len(prompts)
     assert batch_size > 0, "No prompts are given"
-    assert prompt_chunksize > 0, "prompt_chunksize must be positive"
+    assert prompt_chunksize >= 1, "prompt_chunksize must be positive"
     prompt_size = []
     device = prompts[0].device
     prompt_dtype = prompts[0].dtype
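
For an integer prompt_chunksize the two assertions are equivalent; ">= 1" simply states the intent ("at least one token per chunk") directly, and only behaves differently if a non-integer value slips through. A quick check, with a made-up bad input:

    # `>= 1` rejects fractional chunk sizes that `> 0` would let through:
    prompt_chunksize = 0.5                   # hypothetical bad input
    assert (prompt_chunksize > 0) is True    # old check passes, arguably wrongly
    assert (prompt_chunksize >= 1) is False  # new check would raise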
@@ -266,7 +266,7 @@
     max_prefill_length = model.kv_cache_max_prefill_length()
     if max_prefill_length is None:
         max_prefill_length = min_prompt_size
-    token_pos = min([min_prompt_size, max_prefill_length])
+    token_pos = min(min_prompt_size, max_prefill_length)
     start = 0
     while True:
         inputs = torch.cat(
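
The min() change in this hunk is purely stylistic: min accepts multiple positional arguments, so wrapping the two values in a list only builds a throwaway object. A minimal equivalence check, with made-up values:

    # min() takes either a single iterable or several arguments; for two
    # known values the direct form avoids the intermediate list.
    assert min(3, 7) == min([3, 7]) == 3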
@@ -275,7 +275,7 @@
         )
         # We may need the last time slice of `all_logits` below:
         all_logits = model(inputs, input_pos=start)
-        if token_pos == min_prompt_size:
+        if token_pos >= min_prompt_size:
             break
         start = token_pos
         # Note that `max_tokens_forward` can change during the course of
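
The "==" to ">=" change is the substantive fix: token_pos advances through the prompt in chunks, so it is not guaranteed to ever equal min_prompt_size exactly. If a chunked step jumped past the target, an "==" test would never fire and the while True: loop would not terminate. The sketch below uses hypothetical numbers and a simplified loop body (not litgpt's actual chunking logic) to show the hazard:

    min_prompt_size = 10  # hypothetical shortest prompt in the batch
    chunk = 4             # hypothetical prompt_chunksize
    token_pos = chunk
    while True:
        # ... prefill the KV cache up to token_pos ...
        if token_pos >= min_prompt_size:  # `==` never fires: 4, 8, 12 skips 10
            break
        token_pos += chunk
    print(token_pos)  # 12: exits at the first position at or past the target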
