
Commit 310de8a

Fix prompt is not string bug (#81)
* Fix prompt is not string bug
* update parameter type
1 parent 0cebcd3 commit 310de8a

File tree

7 files changed: +22 -15 lines changed

llmserve/backend/llm/engines/_base.py

Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@
 from llmserve.backend.logger import get_logger
 from llmserve.backend.server.models import Args, Prompt
 import asyncio
-from typing import AsyncGenerator, Generator
+from typing import Union, AsyncGenerator, Generator

 logger = get_logger(__name__)

@@ -67,5 +67,5 @@ async def check_health(self):
         pass

     @abstractmethod
-    def stream_generate_texts(self, prompt: str) -> Generator[str, None, None]:
+    def stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]]) -> Generator[str, None, None]:
         pass
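The engine interface now takes the server's Prompt model, or a batch of them, instead of a bare string. The Prompt class itself is not part of this diff; a minimal stand-in consistent with how the patch uses it (it reads prompt.prompt in default_pipeline.py below) could look like this sketch, assuming a pydantic model with at least a prompt field:

# Sketch only: a minimal stand-in for llmserve.backend.server.models.Prompt,
# inferred from how this patch reads `prompt.prompt`. The real model is not
# shown in this diff and may carry more fields and validation.
from typing import List, Union

from pydantic import BaseModel


class Prompt(BaseModel):
    prompt: str  # the raw text to generate from


# The updated signatures accept either a single Prompt or a batch of them:
PromptInput = Union[Prompt, List[Prompt]]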

llmserve/backend/llm/engines/generic.py

Lines changed: 3 additions & 3 deletions

@@ -32,7 +32,7 @@
 from llmserve.backend.server.utils import render_gradio_params
 from ._base import LLMEngine

-from typing import AsyncGenerator, Generator
+from typing import AsyncGenerator, Generator, Union
 from queue import Empty
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
@@ -284,7 +284,7 @@ def ping(self) -> bool:
         """Ping the worker."""
         return True

-    async def worker_stream_generate_texts(self, prompt: str, **kwargs) -> Generator[str, None, None]: # type: ignore
+    async def worker_stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]], **kwargs) -> Generator[str, None, None]: # type: ignore
         logger.info(f"Call PredictionWorker.worker_stream_generate_texts with kwargs: {kwargs}")
         for s in self.generator.streamGenerate(prompt, **kwargs):
             # logger.info(f"PredictionWorker.worker_stream_generate_texts -> yield ->{s}")
@@ -430,7 +430,7 @@ async def check_health(self):
                 "Reinitializing worker group."
             )

-    def stream_generate_texts(self, prompt: str) -> Generator[str, None, None]: # type: ignore
+    def stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]]) -> Generator[str, None, None]: # type: ignore
         logger.info(f"GenericEngine.stream_generate_texts -> worker.length: {len(self.base_worker_group)}")
         worker0 = self.base_worker_group[0]
         for strHandle in worker0.worker_stream_generate_texts.remote(
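Stripped of the Ray plumbing, the hand-off this file implements is a relay of one generator into another: the worker wraps the pipeline's streamGenerate, and the engine forwards the typed prompt to worker 0 and yields whatever comes back. A rough in-process sketch; the method names mirror the diff, everything else is simplified and hypothetical:

from typing import Generator, List


class FakeWorker:
    """Stand-in for PredictionWorker; the real one runs remotely via Ray."""

    def worker_stream_generate_texts(self, prompt, **kwargs) -> Generator[str, None, None]:
        # The real worker calls self.generator.streamGenerate(prompt, **kwargs)
        # on the pipeline and yields each text chunk it streams back.
        yield from ("chunk-1 ", "chunk-2 ", "chunk-3")


class FakeEngine:
    """Stand-in for GenericEngine: pick worker 0 and relay its stream."""

    def __init__(self, workers: List[FakeWorker]):
        self.base_worker_group = workers

    def stream_generate_texts(self, prompt) -> Generator[str, None, None]:
        worker0 = self.base_worker_group[0]
        # The real code goes through worker0.worker_stream_generate_texts.remote(...);
        # this sketch keeps everything in one process.
        yield from worker0.worker_stream_generate_texts(prompt)


# Usage: "".join(FakeEngine([FakeWorker()]).stream_generate_texts(prompt=None))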

llmserve/backend/llm/pipelines/_base.py

Lines changed: 1 addition & 1 deletion

@@ -341,7 +341,7 @@ def _sanitize_parameters(
         return preprocess_params, forward_params, postprocess_params

     @abstractmethod
-    def streamGenerate(self, prompt: str, **generate_kwargs) -> Generator[str, None, None]:
+    def streamGenerate(self, prompt: Union[Prompt, List[Prompt]], **generate_kwargs) -> Generator[str, None, None]:
         pass

 class StreamingPipeline(BasePipeline):

llmserve/backend/llm/pipelines/default_pipeline.py

Lines changed: 11 additions & 4 deletions

@@ -5,7 +5,7 @@
 from transformers import PreTrainedModel, PreTrainedTokenizer

 from llmserve.backend.logger import get_logger
-from llmserve.backend.server.models import Response
+from llmserve.backend.server.models import Prompt, Response
 import json

 from ._base import BasePipeline
@@ -167,15 +167,22 @@ def postprocess(self, model_outputs, **postprocess_kwargs) -> List[Response]:
             response.postprocessing_time = et
         return decoded

-    def streamGenerate(self, prompt: str, **generate_kwargs) -> Generator[str, None, None]:
+    def streamGenerate(self, prompt: Union[Prompt, List[Prompt]], **generate_kwargs) -> Generator[str, None, None]:
         logger.info(f"DefaultPipeline.streamGenerate with generate_kwargs: {generate_kwargs}")
         # timeout=0 will dramatic slow down the speed of generator, the root caused still unknow
         streamer = TextIteratorStreamer(self.tokenizer,
                                         # timeout=0,
                                         skip_prompt=True,
                                         skip_special_tokens=True)
-        input_ids = self.tokenizer([prompt], return_tensors="pt")
-        # generation_kwargs = dict(input_ids, streamer=streamer, max_new_tokens=20)
+        prompt_inputs = []
+        if isinstance(prompt, Prompt):
+            prompt_inputs = [prompt.prompt]
+        elif isinstance(prompt, list):
+            prompt_inputs = [p.prompt for p in prompt]
+
+        logger.info(f"DefaultPipeline.streamGenerate with prompt_inputs: {prompt_inputs}")
+        input_ids = self.tokenizer(prompt_inputs, return_tensors="pt")
+        # input_ids = self.tokenizer([prompt], return_tensors="pt")
         max_new_tokens = 256
         if generate_kwargs["max_new_tokens"]:
             max_new_tokens = generate_kwargs["max_new_tokens"]
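The core of the fix is the normalization above: a single Prompt or a list of Prompts is flattened into a list of raw strings before tokenization, instead of passing the object straight to the tokenizer. A standalone sketch of that branch (using the hypothetical Prompt stand-in sketched earlier):

from typing import List


def normalize_prompt_inputs(prompt) -> List[str]:
    # Mirrors the branch added to DefaultPipeline.streamGenerate: unwrap the
    # .prompt text from a single Prompt or from each Prompt in a list.
    if isinstance(prompt, list):
        return [p.prompt for p in prompt]
    return [prompt.prompt]


# With the stand-in model from earlier:
#   normalize_prompt_inputs(Prompt(prompt="Hello"))                   -> ["Hello"]
#   normalize_prompt_inputs([Prompt(prompt="a"), Prompt(prompt="b")]) -> ["a", "b"]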

llmserve/backend/llm/pipelines/llamacpp/llamacpp_pipeline.py

Lines changed: 2 additions & 2 deletions

@@ -4,7 +4,7 @@
 import torch

 from llmserve.backend.logger import get_logger
-from llmserve.backend.server.models import Response
+from llmserve.backend.server.models import Prompt, Response

 from ...initializers.llamacpp import LlamaCppInitializer, LlamaCppTokenizer
 from .._base import StreamingPipeline
@@ -225,7 +225,7 @@ def from_initializer(
             **kwargs,
         )

-    def streamGenerate(self, prompt: str, **generate_kwargs) -> Generator[str, None, None]:
+    def streamGenerate(self, prompt: Union[Prompt, List[Prompt]], **generate_kwargs) -> Generator[str, None, None]:
         logger.info(f"stream prompt: {prompt}")
         inputs = construct_prompts(prompt, prompt_format=self.prompt_format)
         logger.info(f"stream inputs: {inputs}")

llmserve/backend/llm/predictor.py

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@
 from llmserve.backend.logger import get_logger
 from llmserve.backend.server.models import Args, Prompt

-from typing import AsyncGenerator, Generator
+from typing import AsyncGenerator, Generator, Union

 initialize_node_remote = ray.remote(initialize_node)
 logger = get_logger(__name__)
@@ -181,7 +181,7 @@ async def _predict_async(
     async def check_health(self):
         self.engine.check_health()

-    async def stream_generate_texts(self, prompt: str) -> Generator[str, None, None]: # type: ignore
+    async def stream_generate_texts(self, prompt: Union[Prompt, List[Prompt]]) -> Generator[str, None, None]: # type: ignore
         logger.info(f"call LLMPredictor.stream_generate_texts")
         for s in self.engine.stream_generate_texts(prompt):
             logger.info(f"LLMPredictor.stream_generate_texts -> yield ->{s}")

llmserve/backend/server/app.py

Lines changed: 1 addition & 1 deletion

@@ -368,7 +368,7 @@ def streamer(self, model: str, prompt: Union[Prompt, List[Prompt]]) -> Streaming
         logger.info(f"search stream model key: {modelID}")
         return StreamingResponse(self.streamer_generate_text(modelID, prompt), media_type="text/plain")

-    async def streamer_generate_text(self, modelID: str, prompt: str) -> AsyncGenerator[str, None]:
+    async def streamer_generate_text(self, modelID: str, prompt: Union[Prompt, List[Prompt]]) -> AsyncGenerator[str, None]:
         logger.info(f'streamer_generate_text: {modelID}, prompt: "{prompt}"')
         r: DeploymentResponseGenerator = self._models[modelID].stream_generate_texts.remote(prompt)
         async for i in r:
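On the app side, the typed prompt is handed to the deployment's remote generator and relayed back as a text/plain StreamingResponse. A rough client sketch for exercising such a streaming route follows; the URL path and payload shape are assumptions, since neither appears in this diff:

# Hypothetical client; adjust the endpoint and payload to the actual deployment,
# neither is shown in this commit.
import requests

payload = {"prompt": "Tell me a short story."}  # assumed JSON shape of a single Prompt

with requests.post(
    "http://localhost:8000/api/stream/<model-id>",  # hypothetical route
    json=payload,
    stream=True,
) as resp:
    resp.raise_for_status()
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)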
