Skip to content

Commit 327b77f

Browse files
authored
Fix processing of sampling_params within a batched request (#1273)
Introduced a bug in #1214 where when sending a batch request with `sampling_params` defined, we were improperly attempting to update the `DecodeConfig` with the `List[SamplingParams]`, instead of updating with the `SamplingParams` object that corresponded to the specific request within the batch. This implements a fix for it.
1 parent a5e5e91 commit 327b77f

File tree

1 file changed

+9
-2
lines changed
  • shortfin/python/shortfin_apps/llm/components

1 file changed

+9
-2
lines changed

shortfin/python/shortfin_apps/llm/components/generate.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import io
1010
import json
1111
import logging
12+
13+
from copy import copy
1214
from typing import List
1315

1416
import shortfin as sf
@@ -146,7 +148,6 @@ def __init__(
146148
self.complete_infeed = self.system.create_queue()
147149

148150
self.decode_config = service.server_params.decode_config
149-
self.decode_config.update_from_sampling_params(gen_req.sampling_params)
150151

151152
async def run(self):
152153
logger.debug("Started ClientBatchGenerateProcess: %r", self)
@@ -166,6 +167,12 @@ async def run(self):
166167
else:
167168
input_batch = self.tokenize()
168169
for index, input_tokens in enumerate(input_batch):
170+
decode_config = copy(self.decode_config)
171+
decode_config.update_from_sampling_params(
172+
self.gen_req.sampling_params
173+
if self.gen_req.is_single
174+
else self.gen_req.sampling_params[index]
175+
)
169176
gen_process = GenerateItemProcess(
170177
self,
171178
self.gen_req,
@@ -175,7 +182,7 @@ async def run(self):
175182
else self.gen_req.text[index],
176183
input_tokens if is_pretokenized else input_tokens.ids,
177184
eos_token_id=self.tokenizer.eos_token_id,
178-
decode_config=self.decode_config,
185+
decode_config=decode_config,
179186
)
180187
gen_processes.append(gen_process)
181188
gen_process.launch()

0 commit comments

Comments
 (0)