Commit b261ccf

Adding commandr
1 parent b738734 commit b261ccf

File tree

2 files changed: +47 -18 lines

llm_exl2_dynamic_gen.py

Lines changed: 19 additions & 17 deletions
@@ -34,14 +34,14 @@
 import queue
 import uvicorn
 from io import StringIO
-from util import format_prompt_llama3, format_prompt, format_prompt_tess
+from util import format_prompt_llama3, format_prompt, format_prompt_tess, format_prompt_commandr
 from util_merge import ExLlamaV2MergePassthrough
 
 def generate_unique_id():
     return uuid.uuid4()
 
 # This is a demo and small stress to showcase some of the features of the dynamic batching generator.
-repo_str = 'tess-xl-exl2-speculative'
+repo_str = 'commandr-exl2'
 
 class CompletionRequest(BaseModel):
     model: str
@@ -205,7 +205,7 @@ def display(self):
 total_context = 32768
 
 # Max individual context
-max_context = 12288
+max_context = 8192
 
 # N-gram or draft model speculative decoding. Largely detrimental to performance at higher batch sizes.
 use_ngram = False
@@ -215,7 +215,7 @@ def display(self):
 draft_model_dir = specrepo_id
 
 # Max number of batches to run at once, assuming the sequences will fit within total_context.
-max_batch_size = 6 if paged else 1
+max_batch_size = 4 if paged else 1
 
 # Max chunk size. Determines the size of prefill operations. Can be reduced to reduce pauses whenever a
 # new job is started, but at the expense of overall prompt ingestion speed.
@@ -267,22 +267,22 @@ def display(self):
 config.max_input_len = max_chunk_size
 config.max_attention_size = max_chunk_size ** 2
 
-ropescale = 2.5
-config.scale_alpha_value = ropescale
+#ropescale = 2.5
+#config.scale_alpha_value = ropescale
 config.max_seq_len = max_context
 model = ExLlamaV2(config)
 
 # Configure the cache. The dynamic generator expects a batch size of 1 and a max_seq_len equal to
 # the total number of cached tokens. The flat cache will be split dynamically
 
-#cache = ExLlamaV2Cache(
-# model,
-# max_seq_len = total_context,
-#lazy = True
-#)
+cache = ExLlamaV2Cache_Q4(
+    model,
+    max_seq_len = total_context,
+    lazy = True
+)
 
-#model.load_autosplit(cache, progress = True)
-model.load([16,18,18,20])
+model.load_autosplit(cache, progress = True)
+#model.load([16,18,18,20])
 # Also, tokenizer
 
 print("Loading tokenizer...")
@@ -296,11 +296,11 @@ def display(self):
 #lora = ExLlamaV2Lora.from_directory(model, lora_directory)
 lora = None
 
-cache = ExLlamaV2Cache_Q4(
-    model,
-    max_seq_len = total_context,
+#cache = ExLlamaV2Cache_Q4(
+# model,
+# max_seq_len = total_context,
 #lazy = True
-)
+#)
 
 # Initialize the generator
 
@@ -574,6 +574,8 @@ async def mainchat(request: ChatCompletionRequest):
         prompt = await format_prompt_tess(request.messages)
     elif repo_str == 'tinyllama-exl2-speculative':
         prompt = await format_prompt_zephyr(request.messages)
+    elif repo_str == 'commandr-exl2' or repo_str == 'commandr-exl2-speculative':
+        prompt = await format_prompt_commandr(request.messages)
     else:
         prompt = await format_prompt(request.messages)
     status_area.update(f"Prompt: {prompt}")
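
For context, the cache and loading changes above amount to the following standalone exllamav2 setup pattern: a minimal sketch, assuming only the context sizes taken from this diff; the model directory path is a placeholder and not part of the repository.

# Sketch of the post-commit loading path: Q4 KV cache + autosplit across GPUs.
# The model_dir value is a placeholder, not taken from the commit.
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache_Q4, ExLlamaV2Tokenizer

total_context = 32768   # total tokens shared by all concurrent jobs
max_context = 8192      # per-request context cap, as set in this commit

config = ExLlamaV2Config()
config.model_dir = "/path/to/commandr-exl2"   # placeholder
config.prepare()
config.max_seq_len = max_context

model = ExLlamaV2(config)

# Quantized (Q4) cache sized for the whole batch pool; lazy so load_autosplit can place it per device.
cache = ExLlamaV2Cache_Q4(model, max_seq_len = total_context, lazy = True)

# Autosplit replaces the manual per-GPU split (model.load([16,18,18,20])) that the commit comments out.
model.load_autosplit(cache, progress = True)

tokenizer = ExLlamaV2Tokenizer(config)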

util.py

Lines changed: 28 additions & 1 deletion
@@ -148,4 +148,31 @@ async def format_prompt_mixtral(messages):
             formatted_prompt += f"[INST] {message.content} [/INST] "
         elif message.role == "assistant":
             formatted_prompt += f" {message.content}</s> " # Prep for user follow-up
-    return formatted_prompt
+    return formatted_prompt
+
+async def format_prompt_commandr(messages):
+    formatted_prompt = ""
+    system_message_found = False
+
+    # Check for a system message first
+    for message in messages:
+        if message.role == "system":
+            system_message_found = True
+            break
+
+    # If no system message was found, prepend a default one
+    if not system_message_found:
+        formatted_prompt += f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+
+    for message in messages:
+        if message.role == "system":
+            formatted_prompt += f"<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+        elif message.role == "user":
+            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+        elif message.role == "assistant":
+            formatted_prompt += f"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{message.content}<|END_OF_TURN_TOKEN|>"
+    # Add the final "### Assistant:\n" to prompt for the next response
+    formatted_prompt += "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+    return formatted_prompt
+
+
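
To illustrate the Command R turn format this helper emits, here is a small hypothetical driver; the Message stand-in and the message contents are examples, not part of the commit. Note that when no system message is present, the fallback branch above interpolates message.content from whichever message the scan loop saw last rather than a fixed default string.

# Hypothetical usage of format_prompt_commandr; Message stands in for the server's Pydantic message model.
import asyncio
from collections import namedtuple

from util import format_prompt_commandr

Message = namedtuple("Message", ["role", "content"])

messages = [
    Message(role = "system", content = "You are a helpful assistant."),
    Message(role = "user", content = "Hello!"),
]

prompt = asyncio.run(format_prompt_commandr(messages))
print(prompt)
# Produces (as one unbroken string, wrapped here for readability):
# <BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a helpful assistant.<|END_OF_TURN_TOKEN|>
# <|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hello!<|END_OF_TURN_TOKEN|>
# <|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>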
