@@ -671,6 +671,14 @@ def tryparsefloat(value,fallback):
671671 except ValueError :
672672 return fallback
673673
674+ def replace_last_in_string (text : str , match : str , replacement : str ) -> str :
675+ if match == "" :
676+ return text
677+ head , sep , tail = text .rpartition (match )
678+ if sep == "" :
679+ return text # old not found
680+ return head + replacement + tail
681+
674682def is_incomplete_utf8_sequence (byte_seq ): #note, this will only flag INCOMPLETE sequences, corrupted ones will be ignored.
675683 try :
676684 byte_seq .decode ('utf-8' )
@@ -2608,6 +2616,11 @@ def transform_genparams(genparams, api_format):
26082616 assistant_message_start = adapter_obj .get ("assistant_start" , "\n ### Response:\n " )
26092617 assistant_message_end = adapter_obj .get ("assistant_end" , "" )
26102618 if isinstance (prompt , str ): #needed because comfy SD uses same field name
2619+ if assistant_message_gen and assistant_message_gen != assistant_message_start : #replace final output tag with unspaced (gen) version if exists
2620+ if prompt .rstrip ().endswith ("{{[OUTPUT]}}" ):
2621+ prompt = replace_last_in_string (prompt ,"{{[OUTPUT]}}" ,assistant_message_gen )
2622+ elif assistant_message_start and prompt .rstrip ().endswith (assistant_message_start ):
2623+ prompt = replace_last_in_string (prompt , assistant_message_start , assistant_message_gen )
26112624 if "{{[INPUT_END]}}" in prompt or "{{[OUTPUT_END]}}" in prompt :
26122625 prompt = prompt .replace ("{{[INPUT]}}" , user_message_start )
26132626 prompt = prompt .replace ("{{[OUTPUT]}}" , assistant_message_start )
0 commit comments