@@ -112,8 +112,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
112112 loss_scale = encoded .get ('loss_scale' , None )
113113 idx_list = findall (input_ids , self .boi_token_id )
114114 img_tokens = self ._tokenize (self .processor .full_image_sequence )
115- input_ids , labels = self ._extend_tokens (input_ids , labels , idx_list , lambda _ : img_tokens )
116- loss_scale = self . _extend_loss_scale ( loss_scale , idx_list , lambda _ : img_tokens )
115+ input_ids , labels , loss_scale = self ._extend_tokens (input_ids , labels , loss_scale , idx_list ,
116+ lambda _ : img_tokens )
117117
118118 # TODO: customize
119119 processor_kwargs = Gemma3ProcessorKwargs ._defaults ['images_kwargs' ]
@@ -171,8 +171,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
171171 if inputs .images :
172172 idx_list = findall (input_ids , self .boi_token_id )
173173 img_tokens = self ._tokenize (processor .full_image_sequence )
174- input_ids , labels = self ._extend_tokens (input_ids , labels , idx_list , lambda _ : img_tokens )
175- loss_scale = self . _extend_loss_scale ( loss_scale , idx_list , lambda _ : img_tokens )
174+ input_ids , labels , loss_scale = self ._extend_tokens (input_ids , labels , loss_scale , idx_list ,
175+ lambda _ : img_tokens )
176176
177177 # Process images
178178 processor_kwargs = Gemma3nProcessorKwargs ._defaults .get ('images_kwargs' , {})
@@ -188,8 +188,8 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
188188 if audio_idx_list :
189189 # Get audio token sequence from processor
190190 audio_tokens = self ._tokenize (processor .full_audio_sequence )
191- input_ids , labels = self ._extend_tokens (input_ids , labels , audio_idx_list , lambda _ : audio_tokens )
192- loss_scale = self . _extend_loss_scale ( loss_scale , audio_idx_list , lambda _ : audio_tokens )
191+ input_ids , labels , loss_scale = self ._extend_tokens (input_ids , labels , loss_scale , audio_idx_list ,
192+ lambda _ : audio_tokens )
193193
194194 # Process audios
195195 processor_kwargs = Gemma3nProcessorKwargs ._defaults .get ('audio_kwargs' , {})
0 commit comments