@@ -20,9 +20,16 @@ class ORPOTrainer(PushToMsHubMixin, SwiftMixin, HFORPOTrainer):
2020 def __init__ (self , * args , template : Template , test_oom_error = False , ** kwargs ):
2121 self .template = template
2222 is_vision = kwargs .pop ('is_vision' )
23+ self .keys = []
2324 super ().__init__ (* args , ** kwargs )
25+ self .train_dataset = self .train_dataset .filter (lambda x : x ['prompt_input_ids' ] is not None )
26+ if self .eval_dataset is not None :
27+ self .eval_dataset = self .eval_dataset .filter (lambda x : x ['prompt_input_ids' ] is not None )
2428 train_ds_info = self .stat_dataset (self .train_dataset , self .is_encoder_decoder )
25- val_ds_info = self .stat_dataset (self .eval_dataset , self .is_encoder_decoder )
29+ if self .eval_dataset is not None :
30+ val_ds_info = self .stat_dataset (self .eval_dataset , self .is_encoder_decoder )
31+ self .dataset_info = {'train_dataset' : train_ds_info , 'val_dataset' : val_ds_info }
32+ else :
33+ self .dataset_info = {'train_dataset' : train_ds_info }
26- self .dataset_info = {'train_dataset' : train_ds_info , 'val_dataset' : val_ds_info }
2734 if test_oom_error :
2835 self .train_dataset = sort_by_max_length (self .train_dataset , 20000 )
@@ -51,6 +59,10 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)
5159 prompt ['response' ] = None
5260 prompt_tokens = self .template .encode (prompt )[0 ]
5361
62+ # Skip examples that do not contain 'input_ids'
63+ if 'input_ids' not in prompt_tokens :
64+ return {k : None for k in self .keys }
65+
5466 # resolve conflict in data_collator when labels are None, pop it afterwards
5567 prompt_tokens ['labels' ] = prompt_tokens ['input_ids' ]
5668 # Batching image-related information for paired response using template
@@ -168,7 +180,8 @@ def tokenize_row(self, feature, model: Union[PreTrainedModel, nn.Module] = None)
168180 labels = torch .tensor (batch ['chosen_labels' ]))
169181
170182 batch .update (prompt_tokens )
171-
183+ if not self .keys :
184+ self .keys = (list (batch .keys ()))
172185 return batch
173186
174187 def concatenated_forward (
@@ -214,7 +227,7 @@ def concatenated_forward(
214227 model_kwargs ['output_router_logits' ] = True
215228
216229 outputs = model (
217- concatenated_batch ['concatenated_input_ids' ],
230+ input_ids = concatenated_batch ['concatenated_input_ids' ],
218231 attention_mask = concatenated_batch ['concatenated_attention_mask' ],
219232 use_cache = False ,
220233 ** model_kwargs ,
0 commit comments