Skip to content

Commit 54182cc

Browse files
desh2608 and claude committed
test: add unit tests for output_hidden_states in NemotronHForCausalLM
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Desh Raj <r.desh26@gmail.com>
1 parent 3b7e33a commit 54182cc

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

tests/unit_tests/models/nemotron_v3/test_nemotron_v3_model.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,67 @@ def test_causal_lm_forward_with_labels(self, config, backend):
492492
assert output.loss is not None
493493
assert output.loss.ndim == 0 # scalar loss
494494

495+
def test_causal_lm_output_hidden_states(self, config, backend):
    """Test that the output_hidden_states parameter controls hidden state return."""
    from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM

    model = NemotronHForCausalLM(config, backend=backend).to(torch.bfloat16)

    n_batch, n_tokens = 2, 8
    input_ids = torch.randint(0, config.vocab_size, (n_batch, n_tokens))

    # By default no hidden states are collected.
    assert model(input_ids).hidden_states is None

    # When requested, a one-element tuple of (batch, seq, hidden) bf16 tensors comes back.
    result = model(input_ids, output_hidden_states=True)
    hidden = result.hidden_states
    assert isinstance(hidden, tuple)
    assert len(hidden) == 1
    assert hidden[0].shape == (n_batch, n_tokens, config.hidden_size)
    assert hidden[0].dtype == torch.bfloat16

    # return_dict is accepted without error (HF API compatibility).
    result = model(input_ids, output_hidden_states=True, return_dict=True)
    assert result.hidden_states is not None
def test_causal_lm_hidden_states_config_and_override(self, config, backend):
    """Test config.output_hidden_states and explicit parameter override.

    The config fixture is mutated for this test; the original value is
    restored afterwards so the mutation cannot leak into other tests if the
    fixture is shared (e.g. class- or module-scoped).
    """
    from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM

    # Save whatever the fixture had (presumably False/absent) and restore on exit.
    original_value = getattr(config, "output_hidden_states", None)
    config.output_hidden_states = True
    try:
        model = NemotronHForCausalLM(config, backend=backend)
        model = model.to(torch.bfloat16)

        input_ids = torch.randint(0, config.vocab_size, (2, 8))

        # Config enables hidden states when the parameter is not passed.
        output = model(input_ids)
        assert output.hidden_states is not None
        assert isinstance(output.hidden_states, tuple)

        # Explicit False overrides the config setting.
        output = model(input_ids, output_hidden_states=False)
        assert output.hidden_states is None
    finally:
        config.output_hidden_states = original_value
def test_causal_lm_hidden_states_with_labels(self, config, backend):
    """Test that hidden states are returned alongside loss computation."""
    from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM

    model = NemotronHForCausalLM(config, backend=backend).to(torch.bfloat16)

    shape = (2, 8)
    input_ids = torch.randint(0, config.vocab_size, shape)
    labels = torch.randint(0, config.vocab_size, shape)

    result = model(input_ids, labels=labels, output_hidden_states=True)

    # Loss must still be a scalar even when hidden states are requested.
    assert result.loss is not None
    assert result.loss.ndim == 0
    # And the hidden states must be present with the expected shape.
    assert result.hidden_states is not None
    assert result.hidden_states[0].shape == (shape[0], shape[1], config.hidden_size)
495556
def test_causal_lm_prepare_inputs_for_generation(self, config, backend):
496557
"""Test prepare_inputs_for_generation returns full sequence."""
497558
from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM

0 commit comments

Comments
 (0)