@@ -177,6 +177,7 @@ def _init_template(self,
177177 self .max_length = max_length
178178 self .truncation_strategy = truncation_strategy
179179 self .model = kwargs .get ('model' , None )
180+ self .use_loss_scale = kwargs .get ('use_loss_scale' , True )
180181 for key in [
181182 'prefix' , 'prompt' , 'chat_sep' , 'suffix' , 'prefix_has_system'
182183 ]:
@@ -207,6 +208,8 @@ def encode(
207208 system = None
208209 else :
209210 assert self .prefix_has_system is not None , 'The template does not support `system`.'
211+ if query is None :
212+ query = ''
210213 inputs , tokenizer_kwargs = self ._encode (query , response , history ,
211214 system ,
212215 self .truncation_strategy )
@@ -233,7 +236,8 @@ def _concat_context_list(
233236 if isinstance (context , str ):
234237 if '{{RESPONSE}}' == context :
235238 assert response is not None
236- content_part , weight_part = calculate_loss_scale (response )
239+ content_part , weight_part = calculate_loss_scale (
240+ response , self .use_loss_scale )
237241 res_context_list .extend (content_part )
238242 compute_loss_idx .extend (weight_part )
239243 continue
@@ -330,7 +334,7 @@ def _encode(
330334 # last response
331335 context_list .append ('{{RESPONSE}}' )
332336 context_list += self .suffix
333- if q is not None :
337+ if q or r :
334338 self ._concat_context_list (
335339 context_list ,
336340 res_context_list ,
@@ -457,7 +461,7 @@ def register_template(template_type: str,
457461class DefaultGenerationTemplate (Template ):
458462
459463 def __init__ (self ):
460- return super ().__init__ ([], ['{{QUERY}}' ], None , [['eos_token_id' ]])
464+ super ().__init__ ([], ['{{QUERY}}' ], None , [['eos_token_id' ]])
461465
462466
463467register_template (TemplateType .default_generation , DefaultGenerationTemplate ())
0 commit comments