@@ -104,6 +104,15 @@ def _has_system(prefix: Prompt) -> bool:
104104 return False
105105
106106
107+ def _replace_system (prefix : Prompt ) -> Prompt :
108+ res = []
109+ for p in prefix :
110+ if '{{SYSTEM}}' in p :
111+ p = p .replace ('{{SYSTEM}}' , '' )
112+ res .append (p )
113+ return res
114+
115+
107116class Template :
108117
109118 def __init__ (self ,
@@ -113,11 +122,13 @@ def __init__(self,
113122 suffix : Prompt ,
114123 default_system : Optional [str ] = None ,
115124 prefix_has_system : Optional [Prompt ] = None ) -> None :
116- self .prefix = prefix
125+ if default_system == '' :
126+ default_system = None
117127 if _has_system (prefix ):
118128 assert prefix_has_system is None , 'The prefix already contains {{SYSTEM}}.'
119- assert default_system is not None , 'You need to provide the `default_system`.'
120129 prefix_has_system = prefix
130+ prefix = _replace_system (prefix )
131+ self .prefix = prefix
121132 self .prefix_has_system = prefix_has_system
122133 if self .prefix_has_system is None :
123134 assert default_system is None , 'The template does not support `system`.'
@@ -157,7 +168,10 @@ def _init_template(self,
157168 assert self ._is_init is False , 'The template has been initialized.'
158169 self ._is_init = True
159170 self .tokenizer = tokenizer
160- if default_system is not None :
171+ # if default_system is None. not change self.default_system
172+ if default_system == '' :
173+ self .default_system = None
174+ elif default_system is not None :
161175 assert self .prefix_has_system is not None , 'The template does not support `system`.'
162176 self .default_system = default_system
163177 self .max_length = max_length
@@ -189,6 +203,8 @@ def encode(
189203 if system is None :
190204 if self .use_default_system :
191205 system = self .default_system
206+ elif system == '' :
207+ system = None
192208 else :
193209 assert self .prefix_has_system is not None , 'The template does not support `system`.'
194210 inputs , tokenizer_kwargs = self ._encode (query , response , history ,
@@ -299,7 +315,6 @@ def _encode(
299315 res_context_list : List [Context ] = []
300316 compute_loss_idx : List [float ] = []
301317 if system is None :
302- assert self .prefix != self .prefix_has_system , f'template.prefix: { self .prefix } '
303318 prefix = self .prefix
304319 else :
305320 prefix = self .prefix_has_system
@@ -586,22 +601,21 @@ def data_collator(self,
586601
587602register_template (
588603 TemplateType .yi_vl ,
589- YiVLTemplate (['{{SYSTEM}}\n \n ' ],
590- ['### Human: ' , [- 200 ], '\n {{QUERY}}\n ### Assistant:\n ' ],
591- ['\n ' ], ['\n ###' ], yi_vl_default_system ),
604+ YiVLTemplate ([], ['### Human: ' , [- 200 ], '\n {{QUERY}}\n ### Assistant:\n ' ],
605+ ['\n ' ], ['\n ###' ], yi_vl_default_system , ['{{SYSTEM}}\n \n ' ]),
592606 use_model = True ,
593607 infer_media_type = 'round' ,
594608 lazy_tokenize = True )
595609
596610register_template (
597611 TemplateType .baichuan ,
598612 Template (['{{SYSTEM}}' ], [[195 ], '{{QUERY}}' , [196 ]], [],
599- [['eos_token_id' ]], '' ))
613+ [['eos_token_id' ]]))
600614register_template (
601615 TemplateType .chatglm2 ,
602616 Template ([[64790 , 64792 ], '{{SYSTEM}}' ],
603617 ['[Round {{ROUND1}}]\n \n 问:{{QUERY}}\n \n 答:' ], ['\n \n ' ],
604- [['eos_token_id' ]], '' ))
618+ [['eos_token_id' ]]))
605619
606620register_template (
607621 TemplateType .chatglm_generation ,
@@ -818,29 +832,29 @@ def get_generate_ids(generate_ids: Tensor,
818832register_template (
819833 TemplateType .xverse ,
820834 Template (['{{SYSTEM}}' ], ['Human: {{QUERY}}\n \n Assistant: ' ],
821- [['eos_token_id' ]], [['eos_token_id' ]], '' ))
835+ [['eos_token_id' ]], [['eos_token_id' ]]))
822836register_template (TemplateType .yuan ,
823837 Template ([], ['{{QUERY}}<sep>' ], None , [['eos_token_id' ]]))
824838register_template (
825839 TemplateType .ziya ,
826840 Template ([['bos_token_id' ], '{{SYSTEM}}' ], ['<human>:{{QUERY}}\n <bot>:' ],
827- ['\n ' ], [['eos_token_id' ]], '' ))
841+ ['\n ' ], [['eos_token_id' ]]))
828842
829843register_template (
830844 TemplateType .skywork ,
831845 Template (['<s>{{SYSTEM}}' ], ['</s><s>[USER]{{QUERY}}[SEP][BOT]' ], None ,
832- ['[SEP]</s>' ], '' ))
846+ ['[SEP]</s>' ]))
833847
834848register_template (
835849 TemplateType .bluelm ,
836850 Template ([['bos_token_id' ], '{{SYSTEM}}' ], ['[|Human|]:{{QUERY}}[|AI|]:' ],
837- [], [['eos_token_id' ]], '' ))
851+ [], [['eos_token_id' ]]))
838852
839853register_template (
840854 TemplateType .codefuse_codellama ,
841855 Template (['{{SYSTEM}}' ], [
842856 '<|role_start|>human<|role_end|>{{QUERY}}<|role_start|>bot<|role_end|>'
843- ], [], [['eos_token_id' ]], '' ))
857+ ], [], [['eos_token_id' ]]))
844858
845859register_template (
846860 TemplateType .codefuse ,
@@ -867,12 +881,12 @@ def get_generate_ids(generate_ids: Tensor,
867881register_template (
868882 TemplateType .sus ,
869883 Template (['{{SYSTEM}}' ], ['### Human: {{QUERY}}\n \n ### Assistant: ' ],
870- ['<|endoftext|>' ], ['<|endoftext|>' ], '' ))
884+ ['<|endoftext|>' ], ['<|endoftext|>' ]))
871885
872886register_template (
873887 TemplateType .orion ,
874888 Template (['<s>{{SYSTEM}}' ], ['Human: {{QUERY}}\n \n Assistant: </s>' ],
875- ['</s>' ], ['</s>' ], '' ))
889+ ['</s>' ], ['</s>' ]))
876890
877891
878892class CogAgentTemplate (Template ):
@@ -939,7 +953,7 @@ def data_collator(self,
939953
940954register_template (
941955 TemplateType .openbmb ,
942- Template (['<s>{{SYSTEM}}' ], ['<用户>{{QUERY}}<AI>' ], [], ['</s>' ], '' ))
956+ Template (['<s>{{SYSTEM}}' ], ['<用户>{{QUERY}}<AI>' ], [], ['</s>' ]))
943957
944958
945959def get_template (
0 commit comments