@@ -118,8 +118,9 @@ def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None
     else:
         logits = outputs
     device = logits.device
+    if labels.shape[1] > logits.shape[1]:
+        _, _, labels, _, _, loss_scale = ulysses.pad_and_split_inputs(None, None, labels, None, None, loss_scale)
     logits = logits.view(-1, logits.shape[-1])
-    _, _, labels, _, _, loss_scale = ulysses.pad_and_split_inputs(None, None, labels, None, None, loss_scale)
 
     labels = labels.flatten().to(device)
     sploss_parallel_size = int(os.environ.get('CELOSS_PARALLEL_SIZE', '0'))
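Note on the new guard: `labels.shape[1] > logits.shape[1]` makes the split idempotent. Once the patched `_prepare_inputs` (added below) has already sharded `labels` across the sequence-parallel group, the label and logit sequence lengths match and the loss function must not split again. A minimal sketch of the idea with a simplified stand-in for `pad_and_split_inputs` (the rank/world-size handling below is an assumption for illustration, not the actual Ulysses implementation, which also shifts the labels):

import torch
import torch.nn.functional as F

def split_labels_if_needed(labels: torch.Tensor, logits: torch.Tensor,
                           sp_rank: int, sp_world_size: int,
                           pad_id: int = -100) -> torch.Tensor:
    # Already sharded (e.g. by the patched _prepare_inputs): lengths match, do nothing.
    if labels.shape[1] <= logits.shape[1]:
        return labels
    # Otherwise pad the sequence dim to a multiple of sp_world_size and keep this rank's shard.
    # Note: the real Ulysses helper also shifts labels for next-token prediction; omitted here.
    seq_len = labels.shape[1]
    padded = -(-seq_len // sp_world_size) * sp_world_size
    labels = F.pad(labels, (0, padded - seq_len), value=pad_id)
    chunk = padded // sp_world_size
    return labels[:, sp_rank * chunk:(sp_rank + 1) * chunk]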
@@ -142,7 +143,7 @@ def loss_scale_sp_func(outputs, labels, loss_scale=None, num_items_in_batch=None
 
 
 @profiling_decorator
-def _prepare_inputs(self, generation_batch):
+def _prepare_inputs_grpo(self, generation_batch):
     ulysses = self.ulysses
     mode = 'train' if self.model.training else 'eval'
     if mode == 'train':
@@ -159,6 +160,14 @@ def _prepare_inputs(self, generation_batch):
     return inputs
 
 
+def _prepare_inputs(self, inputs, ulysses):
+    if 'labels' in inputs:
+        labels = inputs['labels']
+        _, _, labels, _, _, _ = ulysses.pad_and_split_inputs(None, None, labels, None, None, None)
+        inputs['labels'] = labels
+    return self._origin_prepare_inputs(inputs)
+
+
 def old_policy(self):
     ulysses = self.ulysses
     # changes: `* ulysses.sp_world_size`
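For orientation, the shapes this wrapper produces: after `pad_and_split_inputs`, each rank holds only its slice of the label sequence, so the guards elsewhere in the patch see matching sequence lengths and skip a second split. The sizes below are illustrative assumptions, not taken from the patch:

import torch

# Illustrative only: assumed sp_world_size=4, seq_len=4096, vocab=320.
sp_world_size, seq_len, vocab = 4, 4096, 320
shard = seq_len // sp_world_size                         # 1024 label tokens per SP rank

labels_full = torch.full((1, seq_len), -100)             # labels before _prepare_inputs
labels_shard = labels_full[:, :shard]                    # what one rank holds after the split
logits_shard = torch.randn(1, shard, vocab)              # model output under sequence parallelism

assert labels_shard.shape[1] == logits_shard.shape[1]    # guard condition false: no re-split
assert labels_full.shape[1] > logits_shard.shape[1]      # unsplit labels would be split once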
@@ -171,7 +180,8 @@ def get_per_token_logps(self,
                        logits: torch.FloatTensor,
                        labels: torch.LongTensor,
                        ulysses=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    _, _, labels, _, _, _ = ulysses.pad_and_split_inputs(None, None, labels, None, None, None)
+    if labels.shape[1] > logits.shape[1]:
+        _, _, labels, _, _, _ = ulysses.pad_and_split_inputs(None, None, labels, None, None, None)
     loss_mask = labels != self.label_pad_token_id
     labels = labels.clone()  # No need to shift, pad and split has shifted the inputs.
     labels[~loss_mask] = 0
@@ -823,9 +833,13 @@ def prepare_trainer(self, trainer):
 
         trainer.ulysses = self
         if trainer.__class__.__name__ == 'Seq2SeqTrainer':
+            trainer._origin_prepare_inputs = trainer._prepare_inputs
+            trainer._prepare_inputs = MethodType(partial(_prepare_inputs, ulysses=self), trainer)
             trainer.compute_loss_func = partial(loss_scale_sp_func, ulysses=self)
 
         elif trainer.__class__.__name__ == 'DPOTrainer':
+            trainer._origin_prepare_inputs = trainer._prepare_inputs
+            trainer._prepare_inputs = MethodType(partial(_prepare_inputs, ulysses=self), trainer)
             trainer.get_per_token_logps = MethodType(partial(get_per_token_logps, ulysses=self), trainer)
 
         def rlhf_loss_scale_sp_func(_, *args, **kwargs):
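The patching above leans on two standard-library pieces: `functools.partial` pins the `ulysses` keyword, and `types.MethodType` binds the module-level function to the trainer instance so it receives `self` like a regular method while the original `_prepare_inputs` stays reachable under `_origin_prepare_inputs`. A small self-contained sketch of that pattern (the `Trainer` class and placeholder objects here are made up for illustration):

from functools import partial
from types import MethodType

def _prepare_inputs(self, inputs, ulysses):
    # Stand-in for the real wrapper: would split labels via `ulysses`, then delegate.
    return self._origin_prepare_inputs(inputs)

class Trainer:                                       # hypothetical minimal trainer
    def _prepare_inputs(self, inputs):
        return inputs

trainer, ulysses = Trainer(), object()               # `ulysses` is just a placeholder here
trainer._origin_prepare_inputs = trainer._prepare_inputs
trainer._prepare_inputs = MethodType(partial(_prepare_inputs, ulysses=ulysses), trainer)

print(trainer._prepare_inputs({'labels': None}))     # calls the wrapper, then the original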
@@ -838,7 +852,7 @@ def rlhf_loss_scale_sp_func(_, *args, **kwargs):
            trainer.ulysses = self
            trainer.args.gradient_accumulation_steps = trainer.args.gradient_accumulation_steps * self.sp_world_size
            trainer.old_policy = MethodType(old_policy, trainer)
-            trainer._prepare_inputs = MethodType(_prepare_inputs, trainer)
+            trainer._prepare_inputs = MethodType(_prepare_inputs_grpo, trainer)
            trainer._get_per_token_logps = MethodType(_get_per_token_logps, trainer)
            trainer.split_by_mini_batches = MethodType(split_by_mini_batches, trainer)
 
@@ -852,7 +866,8 @@ def compute_acc(preds, labels, *args, **kwargs) -> Dict[str, List[float]]:
                preds = torch.from_numpy(preds).to(get_current_device())
            if isinstance(labels, np.ndarray):
                labels = torch.from_numpy(labels).to(get_current_device())
-            _, _, labels, _, _, _ = self.pad_and_split_inputs(None, None, labels, None, None, None)
+            if labels.shape[1] > preds.shape[1]:
+                _, _, labels, _, _, _ = self.pad_and_split_inputs(None, None, labels, None, None, None)
            shape0 = preds.shape[0]
            preds_output = torch.empty((shape0 * self.sp_world_size, preds.shape[1]),
                                       dtype=preds.dtype,