1 file changed, +3 -3 lines changed

@@ -232,7 +232,7 @@ def batched_generate_fn(
     """
     batch_size = len(prompts)
     assert batch_size > 0, "No prompts are given"
-    assert prompt_chunksize > 0, "prompt_chunksize must be positive"
+    assert prompt_chunksize >= 1, "prompt_chunksize must be positive"
     prompt_size = []
     device = prompts[0].device
     prompt_dtype = prompts[0].dtype
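A side note on this first hunk: for integer chunk sizes the two asserts behave identically; they only diverge if a fractional value slips through, which `>= 1` rejects. A tiny hedged check (the variable name mirrors the diff, the test values are made up):

    for prompt_chunksize in (1, 16, 0, 0.5):
        # `> 0` lets 0.5 through, `>= 1` does not; for ints the two checks agree.
        print(prompt_chunksize, prompt_chunksize > 0, prompt_chunksize >= 1)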
@@ -266,7 +266,7 @@ def batched_generate_fn(
     max_prefill_length = model.kv_cache_max_prefill_length()
     if max_prefill_length is None:
         max_prefill_length = min_prompt_size
-    token_pos = min([min_prompt_size, max_prefill_length])
+    token_pos = min(min_prompt_size, max_prefill_length)
     start = 0
     while True:
         inputs = torch.cat(
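For the second hunk, both spellings return the same value; passing two positional arguments simply avoids building a throwaway list. A quick illustration with made-up sizes:

    min_prompt_size, max_prefill_length = 12, 8
    assert min([min_prompt_size, max_prefill_length]) == min(min_prompt_size, max_prefill_length) == 8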
@@ -275,7 +275,7 @@ def batched_generate_fn(
         )
         # We may need the last time slice of `all_logits` below:
         all_logits = model(inputs, input_pos=start)
-        if token_pos == min_prompt_size:
+        if token_pos >= min_prompt_size:
             break
         start = token_pos
         # Note that `max_tokens_forward` can change during the course of
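The third hunk hardens the prefill loop's exit condition. If `token_pos` can ever step past `min_prompt_size` (the surrounding comment notes that `max_tokens_forward` may change while the loop runs), an `==` check would never fire and the loop could spin forever; `>=` terminates either way. A minimal standalone sketch, assuming a fixed step in place of the real chunking logic; the function and parameter names here are made up for illustration:

    def prefill_positions(min_prompt_size, start_pos, step):
        """Advance token_pos in fixed chunks until it reaches or passes min_prompt_size."""
        token_pos = start_pos
        positions = [token_pos]
        while True:
            # With `==` the break would be skipped whenever a chunk jumps over the target.
            if token_pos >= min_prompt_size:
                break
            token_pos += step
            positions.append(token_pos)
        return positions

    print(prefill_positions(min_prompt_size=9, start_pos=6, step=4))  # [6, 10]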