 from transformers import PreTrainedModel, PreTrainedTokenizer
 
 from llmserve.backend.logger import get_logger
-from llmserve.backend.server.models import Response
+from llmserve.backend.server.models import Prompt, Response
 import json
 
 from ._base import BasePipeline
@@ -167,15 +167,22 @@ def postprocess(self, model_outputs, **postprocess_kwargs) -> List[Response]:
             response.postprocessing_time = et
         return decoded
 
-    def streamGenerate(self, prompt: str, **generate_kwargs) -> Generator[str, None, None]:
+    def streamGenerate(self, prompt: Union[Prompt, List[Prompt]], **generate_kwargs) -> Generator[str, None, None]:
         logger.info(f"DefaultPipeline.streamGenerate with generate_kwargs: {generate_kwargs}")
         # timeout=0 dramatically slows down the generator; the root cause is still unknown
         streamer = TextIteratorStreamer(self.tokenizer,
                                         # timeout=0,
                                         skip_prompt=True,
                                         skip_special_tokens=True)
-        input_ids = self.tokenizer([prompt], return_tensors="pt")
-        # generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=20)
+        prompt_inputs = []
+        if isinstance(prompt, Prompt):
+            prompt_inputs = [prompt.prompt]
+        elif isinstance(prompt, list):
+            prompt_inputs = [p.prompt for p in prompt]
+
+        logger.info(f"DefaultPipeline.streamGenerate with prompt_inputs: {prompt_inputs}")
+        input_ids = self.tokenizer(prompt_inputs, return_tensors="pt")
+        # input_ids = self.tokenizer([prompt], return_tensors="pt")
         max_new_tokens = 256
         if generate_kwargs["max_new_tokens"]:
             max_new_tokens = generate_kwargs["max_new_tokens"]
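
For context on the streaming path this diff touches: `TextIteratorStreamer` is designed to be consumed while `model.generate` runs in a background thread, with the consumer yielding text chunks as they arrive. Below is a minimal sketch of that standard pattern, not this pipeline's actual code; the model name, prompt text, and token count are illustrative placeholders.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Illustrative standalone model/tokenizer; the pipeline itself already holds
# self.model and self.tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
inputs = tokenizer(["What is streaming generation?"], return_tensors="pt")

# generate() blocks until completion, so it runs in a worker thread while the
# caller iterates over the streamer and receives decoded text incrementally.
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=64))
thread.start()
for text_chunk in streamer:
    print(text_chunk, end="", flush=True)
thread.join()
```

The changed `streamGenerate` follows the same idea, with the added step visible above: the incoming `Prompt` / `List[Prompt]` is first normalized into plain strings (`prompt_inputs`) before tokenization.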