Skip to content

Commit 3b7e33a

Browse files
committed
Add support for returning hidden states from NemotronHForCausalLM
Signed-off-by: Desh Raj <r.desh26@gmail.com>
1 parent e8ef32e commit 3b7e33a

File tree

1 file changed

+11
-2
lines changed
  • nemo_automodel/components/models/nemotron_v3

1 file changed

+11
-2
lines changed

nemo_automodel/components/models/nemotron_v3/model.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,8 @@ def forward(
306306
cache_position: Optional[torch.LongTensor] = None,
307307
position_ids: Optional[torch.LongTensor] = None,
308308
logits_to_keep: Union[int, torch.Tensor] = 0,
309+
output_hidden_states: Optional[bool] = None,
310+
return_dict: Optional[bool] = None,
309311
**kwargs: Any,
310312
) -> CausalLMOutputWithPast:
311313
"""Forward pass with optional loss computation.
@@ -322,13 +324,20 @@ def forward(
322324
position_ids: Unused – accepted for API compatibility with GenerationMixin.
323325
logits_to_keep: If > 0, only compute logits for the last ``logits_to_keep``
324326
token positions (avoids materialising the full logit matrix during generation).
327+
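output_hidden_states: Whether to return hidden states. If ``None``, falls back to ``config.output_hidden_states`` (default ``False``). When enabled, the base model's output is returned as a one-element tuple rather than one entry per layer.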
328+
return_dict: Accepted for API compatibility (always returns CausalLMOutputWithPast)
325329
**kwargs: Additional arguments forwarded to the base model.
326330
327331
Returns:
328332
:class:`~transformers.modeling_outputs.CausalLMOutputWithPast` with
329333
``logits`` (float32, ``[batch_size, seq_len, vocab_size]``), optional
330-
``loss``, and ``past_key_values``.
334+
``loss``, ``past_key_values``, and ``hidden_states``.
331335
"""
336+
output_hidden_states = (
337+
output_hidden_states if output_hidden_states is not None
338+
else getattr(self.config, 'output_hidden_states', False)
339+
)
340+
332341
# Forward through base model
333342
hidden_states = self.model(
334343
input_ids,
@@ -366,7 +375,7 @@ def forward(
366375
loss=loss,
367376
logits=logits,
368377
past_key_values=past_key_values if use_cache else None,
369-
hidden_states=None,
378+
hidden_states=(hidden_states,) if output_hidden_states else None,
370379
attentions=None,
371380
)
372381

0 commit comments

Comments
 (0)