@@ -30,6 +30,7 @@ def lce_forward_deprecated(
3030 output_hidden_states : Optional [bool ] = None ,
3131 return_dict : Optional [bool ] = None ,
3232 cache_position : Optional [torch .LongTensor ] = None ,
33+ ** kwargs ,
3334) -> Union [Tuple , CausalLMOutputWithPast ]:
3435 r"""
3536 Args:
@@ -76,6 +77,7 @@ def lce_forward_deprecated(
7677 output_hidden_states = output_hidden_states ,
7778 return_dict = return_dict ,
7879 cache_position = cache_position ,
80+ ** kwargs ,
7981 )
8082
8183 hidden_states = outputs [0 ]
@@ -147,7 +149,7 @@ def lce_forward(
147149 cache_position : Optional [torch .LongTensor ] = None ,
148150 logits_to_keep : Union [int , torch .Tensor ] = 0 ,
149151 skip_logits : Optional [bool ] = None ,
150- ** loss_kwargs ,
152+ ** kwargs ,
151153) -> Union [Tuple , CausalLMOutputWithPast ]:
152154 r"""
153155 Args:
@@ -204,14 +206,15 @@ def lce_forward(
204206 output_hidden_states = output_hidden_states ,
205207 return_dict = return_dict ,
206208 cache_position = cache_position ,
209+ ** kwargs ,
207210 )
208211
209212 hidden_states = outputs [0 ]
210213 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
211214 slice_indices = slice (- logits_to_keep , None ) if isinstance (logits_to_keep , int ) else logits_to_keep
212215 kept_hidden_states = hidden_states [:, slice_indices , :]
213216
214- shift_labels = loss_kwargs .pop ("shift_labels" , None )
217+ shift_labels = kwargs .pop ("shift_labels" , None )
215218 logits = None
216219 loss = None
217220
@@ -230,7 +233,7 @@ def lce_forward(
230233 shift_labels = shift_labels ,
231234 hidden_size = self .config .hidden_size ,
232235 final_logit_softcapping = self .config .final_logit_softcapping ,
233- ** loss_kwargs ,
236+ ** kwargs ,
234237 )
235238
236239 else :
@@ -242,7 +245,7 @@ def lce_forward(
242245
243246 loss = None
244247 if labels is not None :
245- loss = self .loss_function (logits , labels , self .vocab_size , ** loss_kwargs )
248+ loss = self .loss_function (logits , labels , self .vocab_size , ** kwargs )
246249
247250 if not return_dict :
248251 output = (logits ,) + outputs [1 :]
0 commit comments