@@ -2,14 +2,15 @@
 
 import torch
 import time
+import json
 from transformers import Pipeline as TransformersPipeline
 from transformers import PreTrainedModel, PreTrainedTokenizer, pipeline
 
 from llmserve.backend.logger import get_logger
 from llmserve.backend.server.models import Prompt, Response
 
 from ._base import BasePipeline
-from .utils import construct_prompts_experimental, truncate_to_first_stop_token
+from .utils import construct_prompts
 from llmserve.backend.server.utils import render_gradio_params
 from .default_pipeline import DefaultPipeline
 
@@ -135,12 +136,20 @@ def preprocess(self, prompts: List[str], **generate_kwargs):
         st = time.monotonic()
         inputs = None
         logger.info(f"input from pipeline: ****** {prompts}")
-        prompt_text = construct_prompts_experimental(
+        prompt_text = construct_prompts(
             prompts, prompt_format=self.prompt_format)
-        instruction_text = construct_prompts_experimental(prompts, prompt_format="")
+        instruction_text = construct_prompts(prompts, prompt_format="")
         logger.info(f"input from pipeline: ****** {prompt_text}")
 
         if isinstance(self.pipeline, transformers.pipelines.text_generation.TextGenerationPipeline):
+            try:
+                prompt_text_bak = prompt_text
+                prompt_text = [json.loads(prompt) for prompt in prompt_text]
+                prompt_text = [self.tokenizer.apply_chat_template(prompt_obj, tokenize=False, add_generation_prompt=True) for prompt_obj in prompt_text]
+            except Exception:
+                logger.info("Prompts are not JSON chat messages, or the model does not have a chat template; using the raw prompt text")
+                prompt_text = prompt_text_bak
+
             inputs = self.tokenizer(
                 prompt_text, return_tensors="pt", add_special_tokens = generate_kwargs.get("add_special_tokens", True), padding=True
             )
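
The added branch parses each prompt as JSON and, when that succeeds, renders the message list with the tokenizer's chat template before the usual tokenization; on any failure it restores the plain prompt text. A minimal standalone sketch of that flow, assuming a tokenizer that ships a chat template (the checkpoint name is only an example):

import json
from transformers import AutoTokenizer

# Any instruct-tuned checkpoint whose tokenizer defines a chat template will do.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# A prompt arriving as a JSON-encoded list of chat messages.
raw_prompt = json.dumps([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does this pipeline do?"},
])

try:
    messages = json.loads(raw_prompt)
    # Render the messages into the model's expected chat format; tokenization
    # itself is left to the pipeline's normal tokenizer call.
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
except (json.JSONDecodeError, ValueError):
    # Plain-text prompt, or the tokenizer has no chat template: keep the raw string.
    prompt_text = raw_prompt

print(prompt_text)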
@@ -224,7 +233,7 @@ def postprocess(self, model_outputs, **postprocess_kwargs) -> List[Response]:
                 output).input_ids)
             num_input_tokens = len(self.tokenizer(inputs[index]))
             response = Response(
-                generated_text=output,
+                generated_text=output[len(inputs[index]):],
                 num_generated_tokens=num_generated_tokens,
                 num_input_tokens=num_input_tokens,
             )
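
The postprocess change returns only the completion: the decoded output of a text-generation pipeline begins with the prompt it was given, so slicing off the first len(inputs[index]) characters leaves just the newly generated text. A tiny illustration of the slicing idea with made-up strings (variable names here are illustrative, not the pipeline's own):

prompt = "Question: What is 2 + 2?\nAnswer:"
full_output = prompt + " 4"                 # decoded output echoes the prompt, then the completion
generated_only = full_output[len(prompt):]  # drop the echoed prompt
assert generated_only == " 4"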