Skip to content

Commit 8e6d27f

Browse files
committed
handle if assistant_message_gen and assistant_message_gen!=assistant_message_start, replace final output tag with unspaced (gen) version if exists
1 parent 204739e commit 8e6d27f

File tree

4 files changed

+26
-2
lines changed

4 files changed

+26
-2
lines changed

kcpp_adapters/Jamba.json

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
{
2+
"system_start": "<|bom|><|system|> ",
3+
"system_end": "<|eom|>",
4+
"user_start": "<|bom|><|user|> ",
5+
"user_end": "<|eom|>",
6+
"assistant_start": "<|bom|><|assistant|> ",
7+
"assistant_gen": "<|bom|><|assistant|>",
8+
"assistant_end": "<|eom|>"
9+
}

kcpp_adapters/Mistral-NonTekken.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"system_end": "",
44
"user_start": "[INST] ",
55
"user_end": "",
6-
"assistant_start": "[/INST]",
6+
"assistant_start": "[/INST] ",
7+
"assistant_gen": "[/INST]",
78
"assistant_end": "</s>"
89
}

kcpp_adapters/RWKV-World.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"system_end": "\n\n",
44
"user_start": "User: ",
55
"user_end": "\n\n",
6-
"assistant_start": "Assistant:",
6+
"assistant_start": "Assistant: ",
7+
"assistant_gen": "Assistant:",
78
"assistant_end": "\n\n"
89
}

koboldcpp.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,14 @@ def tryparsefloat(value,fallback):
671671
except ValueError:
672672
return fallback
673673

674+
def replace_last_in_string(text: str, match: str, replacement: str) -> str:
675+
if match == "":
676+
return text
677+
head, sep, tail = text.rpartition(match)
678+
if sep == "":
679+
return text # old not found
680+
return head + replacement + tail
681+
674682
def is_incomplete_utf8_sequence(byte_seq): #note, this will only flag INCOMPLETE sequences, corrupted ones will be ignored.
675683
try:
676684
byte_seq.decode('utf-8')
@@ -2608,6 +2616,11 @@ def transform_genparams(genparams, api_format):
26082616
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
26092617
assistant_message_end = adapter_obj.get("assistant_end", "")
26102618
if isinstance(prompt, str): #needed because comfy SD uses same field name
2619+
if assistant_message_gen and assistant_message_gen!=assistant_message_start: #replace final output tag with unspaced (gen) version if exists
2620+
if prompt.rstrip().endswith("{{[OUTPUT]}}"):
2621+
prompt = replace_last_in_string(prompt,"{{[OUTPUT]}}",assistant_message_gen)
2622+
elif assistant_message_start and prompt.rstrip().endswith(assistant_message_start):
2623+
prompt = replace_last_in_string(prompt, assistant_message_start, assistant_message_gen)
26112624
if "{{[INPUT_END]}}" in prompt or "{{[OUTPUT_END]}}" in prompt:
26122625
prompt = prompt.replace("{{[INPUT]}}", user_message_start)
26132626
prompt = prompt.replace("{{[OUTPUT]}}", assistant_message_start)

0 commit comments

Comments
 (0)