@@ -237,6 +237,7 @@ def prediction_step(
237
237
** gen_kwargs ,
238
238
) -> Tuple [Optional [float ], Optional [torch .Tensor ], Optional [torch .Tensor ]]:
239
239
if not self .args .predict_with_generate or prediction_loss_only :
240
+ inputs ['_position_ids' ] = inputs .get ('position_ids' )
240
241
with self .template .forward_context (self .model , inputs ):
241
242
return super ().prediction_step (
242
243
model , inputs , prediction_loss_only = prediction_loss_only , ignore_keys = ignore_keys )
@@ -277,15 +278,19 @@ def _prepare_inputs(self, inputs):
277
278
compute_loss_func = get_loss_func ('loss_scale' )
278
279
279
280
sample_channels = inputs .pop ('channel' , None )
280
- if sample_channels is not None and self .args .channels is not None :
281
+ position_ids = inputs .pop ('_position_ids' , None )
282
+ if self .args .channels is not None :
283
+ assert sample_channels is not None , f'sample_channels: { sample_channels } '
281
284
state = self .state
282
285
setattr (state , 'local_step' , getattr (state , 'local_step' , 0 ))
283
286
setattr (state , 'ch_loss_steps' , getattr (state , 'ch_loss_steps' , {}))
284
287
285
288
loss_kwargs ['sample_channels' ] = sample_channels
286
289
loss_kwargs ['trainer' ] = self
287
- if inputs .get ('position_ids' ) is not None :
288
- loss_kwargs ['position_ids' ] = inputs ['position_ids' ]
290
+ if position_ids is None :
291
+ position_ids = inputs .get ('position_ids' )
292
+ if position_ids is not None :
293
+ loss_kwargs ['position_ids' ] = position_ids
289
294
290
295
use_logits_to_keep = self .get_use_logits_to_keep ('labels' in inputs and self .label_smoother is None
291
296
and compute_loss_func is None )
@@ -352,5 +357,6 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
352
357
return (loss , outputs ) if return_outputs else loss
353
358
354
359
def training_step (self , model , inputs , * args , ** kwargs ):
360
+ inputs ['_position_ids' ] = inputs .get ('position_ids' )
355
361
with self .template .forward_context (self .model , inputs ):
356
362
return super ().training_step (model , inputs , * args , ** kwargs )
0 commit comments