@@ -306,6 +306,8 @@ def forward(
306306 cache_position : Optional [torch .LongTensor ] = None ,
307307 position_ids : Optional [torch .LongTensor ] = None ,
308308 logits_to_keep : Union [int , torch .Tensor ] = 0 ,
309+ output_hidden_states : Optional [bool ] = None ,
310+ return_dict : Optional [bool ] = None ,
309311 ** kwargs : Any ,
310312 ) -> CausalLMOutputWithPast :
311313 """Forward pass with optional loss computation.
@@ -322,13 +324,20 @@ def forward(
322324 position_ids: Unused – accepted for API compatibility with GenerationMixin.
323325 logits_to_keep: If > 0, only compute logits for the last ``logits_to_keep``
324326 token positions (avoids materialising the full logit matrix during generation).
327+ output_hidden_states: Whether to return the final hidden states.
328+ return_dict: Accepted for API compatibility (always returns CausalLMOutputWithPast).
325329 **kwargs: Additional arguments forwarded to the base model.
326330
327331 Returns:
328332 :class:`~transformers.modeling_outputs.CausalLMOutputWithPast` with
329333 ``logits`` (float32, ``[batch_size, seq_len, vocab_size]``), optional
330- ``loss``, and ``past_key_values ``.
334+ ``loss``, ``past_key_values``, and ``hidden_states``.
331335 """
336+ output_hidden_states = (
337+ output_hidden_states if output_hidden_states is not None
338+ else getattr (self .config , 'output_hidden_states' , False )
339+ )
340+
332341 # Forward through base model
333342 hidden_states = self .model (
334343 input_ids ,
@@ -366,7 +375,7 @@ def forward(
366375 loss = loss ,
367376 logits = logits ,
368377 past_key_values = past_key_values if use_cache else None ,
369- hidden_states = None ,
378+ hidden_states = ( hidden_states ,) if output_hidden_states else None ,
370379 attentions = None ,
371380 )
372381
0 commit comments