Commit dce6632

enable prompt template for gguf format inference (#57)
Parent: d59b5d8

File tree: 5 files changed (80 additions, 29 deletions)


llmserve/backend/llm/initializers/llamacpp.py

Lines changed: 7 additions & 11 deletions
@@ -66,9 +66,9 @@ def __init__(
 
     def _get_model_init_kwargs(self) -> Dict[str, Any]:
         return {
-            # We use a large integer to put all of the layers on GPU by default.
-            "n_gpu_layers": 0 if self.device.type == "cpu" else 10**6,
-            "seed": 0,
+            # -1 means all layers are offloaded to GPU
+            "n_gpu_layers": 0 if self.device.type == "cpu" else -1,
+            "seed": -1,
             "verbose": False,
             "n_threads": int(os.environ["OMP_NUM_THREADS"]),
             **self.model_init_kwargs,
@@ -82,15 +82,11 @@ def load_model(self, model_id: str) -> "Llama":
         # Lazy import to avoid issues on CPU head node
         from llama_cpp import Llama
 
-        return Llama(
+        self.model = Llama(
             model_path=os.path.abspath(model_path),
             **self._get_model_init_kwargs(),
         )
-
+        return self.model
+
     def load_tokenizer(self, tokenizer_name: str) -> None:
-        return None
-
-    def postprocess(
-        self, model: "Llama", tokenizer: None
-    ) -> Tuple["Llama", LlamaCppTokenizer]:
-        return super().postprocess(model, LlamaCppTokenizer(model))
+        return LlamaCppTokenizer(self.model)
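For reference, a minimal standalone sketch of what the new initializer defaults mean in llama-cpp-python terms. The helper function and model path below are illustrative placeholders, not part of this repository:

```python
# Illustrative sketch only: maps the initializer's new defaults onto llama-cpp-python.
import os
from llama_cpp import Llama

def llama_init_kwargs(device_type: str) -> dict:
    return {
        # -1 offloads every layer to the GPU; 0 keeps everything on the CPU
        "n_gpu_layers": 0 if device_type == "cpu" else -1,
        # -1 asks llama.cpp to pick a random seed
        "seed": -1,
        "verbose": False,
        "n_threads": int(os.environ.get("OMP_NUM_THREADS", "4")),
    }

# Placeholder path; any local GGUF file works.
llm = Llama(model_path="/models/qwen1_5-7b-chat-q3_k_m.gguf", **llama_init_kwargs("cpu"))
```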

llmserve/backend/llm/pipelines/llamacpp/llamacpp_pipeline.py

Lines changed: 30 additions & 3 deletions
@@ -9,6 +9,7 @@
 from ...initializers.llamacpp import LlamaCppInitializer, LlamaCppTokenizer
 from .._base import StreamingPipeline
 from ..utils import decode_stopping_sequences_where_needed, construct_prompts
+import json
 
 if TYPE_CHECKING:
     from llama_cpp import Llama, LogitsProcessorList, StoppingCriteriaList
@@ -104,20 +105,44 @@ def __call__(self, inputs: List[str], **kwargs) -> List[Response]:
             inputs, prompt_format=self.prompt_format)
 
         logger.info(inputs)
-        tokenized_inputs = self.tokenizer.encode(inputs[0])
+
+        tokenized_inputs = self.tokenizer.encode(inputs)
         kwargs = self._add_default_generate_kwargs(
             kwargs,
             model_inputs={"inputs": inputs,
                           "tokenized_inputs": tokenized_inputs},
         )
 
+        chat_completion = False
+        try:
+            inputs_bak = inputs
+            inputs = [json.loads(prompt) for prompt in inputs]
+            chat_completion = True
+        except:
+            logger.info("Seems no chat template from user")
+            inputs = inputs_bak
+
         logger.info(f"Forward params: {kwargs}, model_inputs {inputs}")
         responses = []
         for input in inputs:
             st = time.monotonic()
-            output = self.model(input, **kwargs)
+            if chat_completion:
+                kwargs.pop('stopping_criteria', None)
+                kwargs.pop('echo', None)
+                logger.info(f"Forward params: {kwargs}, model_inputs {inputs}")
+                output = self.model.create_chat_completion(
+                    messages=input,
+                    **kwargs
+                )
+                text = output["choices"][0]["message"]["content"].replace("\u200b", "").strip()
+            else:
+                output = self.model(input, **kwargs)
+                text = output["choices"][0]["text"].replace("\u200b", "").strip()
+
+
+            logger.info(f"llm's raw response is: {output}")
             gen_time = time.monotonic() - st
-            text = output["choices"][0]["text"].replace("\u200b", "").strip()
+
             responses.append(
                 Response(
                     generated_text=text,
@@ -178,6 +203,7 @@ def from_initializer(
         cls,
         initializer: "LlamaCppInitializer",
         model_id: str,
+        prompt_format: Optional[str] = None,
         device: Optional[Union[str, int, torch.device]] = None,
         **kwargs,
     ) -> "LlamaCppPipeline":
@@ -188,6 +214,7 @@
         return cls(
             model,
             tokenizer,
+            prompt_format,
             device=device,
             **kwargs,
         )
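The chat/text routing above is the core of this change. As a rough standalone sketch of the same idea (not the pipeline class itself; the model path is a placeholder), assuming a rendered prompt is either a JSON message list or plain text:

```python
# Rough sketch of the routing logic, outside the pipeline class.
import json
from llama_cpp import Llama

llm = Llama(model_path="/models/qwen1_5-7b-chat-q3_k_m.gguf", n_gpu_layers=0, verbose=False)

def generate(prompt: str, **kwargs) -> str:
    try:
        messages = json.loads(prompt)
        if not isinstance(messages, list):
            raise ValueError("not a chat message list")
    except ValueError:
        # Plain prompt: ordinary text completion
        out = llm(prompt, **kwargs)
        return out["choices"][0]["text"].strip()
    # JSON message list: use the chat-completion API so the model's chat template is applied
    out = llm.create_chat_completion(messages=messages, **kwargs)
    return out["choices"][0]["message"]["content"].strip()

print(generate("Hello!", max_tokens=32))
```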
Lines changed: 37 additions & 0 deletions (new file)
@@ -0,0 +1,37 @@
+deployment_config:
+  autoscaling_config:
+    min_replicas: 0
+    initial_replicas: 1
+    max_replicas: 8
+    target_num_ongoing_requests_per_replica: 1.0
+    metrics_interval_s: 10.0
+    look_back_period_s: 30.0
+    smoothing_factor: 1.0
+    downscale_delay_s: 300.0
+    upscale_delay_s: 90.0
+  ray_actor_options:
+    num_cpus: 0.1 # a model deployment creates 3 actors; the first two each cost 0.1 CPU, and model inference uses the CPUs set at the end of this file
+model_config:
+  warmup: True
+  model_task: text-generation
+  model_id: Qwen/Qwen1.5-7B-Chat-GGUF
+  max_input_words: 128
+  initialization:
+    initializer:
+      type: LlamaCpp
+      model_filename: qwen1_5-7b-chat-q3_k_m.gguf
+      model_init_kwargs:
+        test: true
+    pipeline: llamacpp
+  generation:
+    max_batch_size: 2
+    batch_wait_timeout_s: 0
+    generate_kwargs:
+      max_tokens: 32
+      echo: true
+    prompt_format: '[{{"role": "system", "content": "You are a friendly chatbot who always responds in the style of a pirate"}},{{"role": "user", "content": "{instruction}"}}]'
+    stopping_sequences: ["\n"]
+scaling_config:
+  num_workers: 1
+  num_gpus_per_worker: 0
+  num_cpus_per_worker: 8 # for inference
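Assuming the `prompt_format` string is rendered with Python `str.format`-style substitution (as the `{instruction}` placeholder and the doubled braces suggest), the Qwen chat template in this new config expands to valid JSON that the pipeline can parse into chat messages:

```python
# Sketch: rendering the chat-style prompt_format from the config above.
import json

prompt_format = (
    '[{{"role": "system", "content": "You are a friendly chatbot who always responds '
    'in the style of a pirate"}},'
    '{{"role": "user", "content": "{instruction}"}}]'
)

rendered = prompt_format.format(instruction="What is the GGUF format?")
messages = json.loads(rendered)
print(messages[1])  # {'role': 'user', 'content': 'What is the GGUF format?'}
```

Note that this round-trips through JSON, so an instruction containing double quotes would produce invalid JSON and the request would fall back to the plain text-completion path.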

models/text-generation--llama-7b-GGUF.yaml

Lines changed: 5 additions & 14 deletions
@@ -1,6 +1,6 @@
 deployment_config:
   autoscaling_config:
-    min_replicas: 0
+    min_replicas: 1
     initial_replicas: 1
     max_replicas: 8
     target_num_ongoing_requests_per_replica: 1.0
@@ -15,34 +15,25 @@ model_config:
   warmup: True
   model_task: text-generation
   model_id: TheBloke/Llama-2-7B-GGUF
-  max_input_words: 800
+  max_input_words: 128
   initialization:
-    # s3_mirror_config:
-    #   endpoint_url: http://39.107.108.170:9000
-    #   bucket_uri: /Users/hub/models/llama-2-7b-gguf/
    initializer:
       type: LlamaCpp
-      model_filename: llama-2-7b.Q5_K_S.gguf
+      model_filename: llama-2-7b.Q2_K.gguf
       model_init_kwargs:
         test: true
-
-      # use_kernel: true # for deepspped type only
-      # max_tokens: 1536 # for deepspped type only
-    # pipeline: defaulttransformers
-    # pipeline: default
     pipeline: llamacpp
   generation:
     max_batch_size: 2
     batch_wait_timeout_s: 0
     generate_kwargs:
-      # do_sample: true
       max_tokens: 128
       temperature: 0.7
       top_p: 0.8
       top_k: 50
       echo: false
-    # prompt_format: "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\n{instruction}\n### Response:\n"
-    stopping_sequences: ["### Response:", "### End"]
+    prompt_format: "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\n{instruction}\n### Response:\n"
+    #stopping_sequences: ["\n"]
 scaling_config:
   num_workers: 1
   num_gpus_per_worker: 0
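With the plain-text `prompt_format` now enabled for this model, the rendered prompt is not valid JSON, so it takes the ordinary text-completion path rather than `create_chat_completion`. A quick sketch of that fallback, under the same `str.format` assumption as above:

```python
# Sketch: the llama-7b prompt_format renders to plain text, so json.loads() fails
# and the pipeline falls back to text completion.
import json

prompt_format = (
    "Below is an instruction that describes a task. Write a response that appropriately "
    "completes the request.\n### Instruction:\n{instruction}\n### Response:\n"
)
rendered = prompt_format.format(instruction="Summarize GGUF in one sentence.")

try:
    json.loads(rendered)
    chat_completion = True
except ValueError:
    chat_completion = False

print(chat_completion)  # False -> the pipeline would call self.model(prompt, **kwargs)
```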

setup.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@
         "accelerate==0.25.0",
         "deepspeed==0.14.0",
         "torchmetrics==1.2.1",
-        "llama_cpp_python==0.2.20",
+        "llama_cpp_python==0.2.57",
         "transformers==4.39.1",
     ],
     "vllm": [
