@@ -492,6 +492,67 @@ def test_causal_lm_forward_with_labels(self, config, backend):
492492 assert output .loss is not None
493493 assert output .loss .ndim == 0 # scalar loss
494494
def test_causal_lm_output_hidden_states(self, config, backend):
    """Check that the ``output_hidden_states`` argument toggles hidden-state output.

    Three forward passes are run on the same random token batch:

    1. without the argument -- ``hidden_states`` is expected to be ``None``;
    2. with ``output_hidden_states=True`` -- a one-element tuple holding a
       ``(batch, seq, hidden_size)`` bfloat16 tensor is expected;
    3. additionally with ``return_dict=True`` -- the call must succeed and
       still expose hidden states.
    """
    from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM

    # Build the model and cast it to bfloat16 in one step.
    lm = NemotronHForCausalLM(config, backend=backend).to(torch.bfloat16)

    batch_size, seq_len = 2, 8
    tokens = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    expected_shape = (batch_size, seq_len, config.hidden_size)

    # Argument omitted: nothing should be returned for hidden states.
    default_result = lm(tokens)
    assert default_result.hidden_states is None

    # Argument enabled: exactly one tensor, with the expected shape and dtype.
    enabled_result = lm(tokens, output_hidden_states=True)
    states = enabled_result.hidden_states
    assert isinstance(states, tuple)
    assert len(states) == 1
    assert states[0].shape == expected_shape
    assert states[0].dtype == torch.bfloat16

    # return_dict is accepted alongside the flag (API compatibility).
    dict_result = lm(tokens, output_hidden_states=True, return_dict=True)
    assert dict_result.hidden_states is not None
519+
def test_causal_lm_hidden_states_config_and_override(self, config, backend):
    """Check the config-level flag and that an explicit argument overrides it.

    ``config.output_hidden_states`` is set to ``True`` before the model is
    constructed. A forward pass without the argument is then expected to
    return a tuple of hidden states, while passing
    ``output_hidden_states=False`` is expected to yield ``None``.
    """
    from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM

    # Switch the flag on in the config before the model is built from it.
    config.output_hidden_states = True
    lm = NemotronHForCausalLM(config, backend=backend).to(torch.bfloat16)

    tokens = torch.randint(0, config.vocab_size, (2, 8))

    # No explicit argument: the config value decides.
    from_config = lm(tokens).hidden_states
    assert from_config is not None
    assert isinstance(from_config, tuple)

    # Explicit False takes precedence over the config value.
    overridden = lm(tokens, output_hidden_states=False)
    assert overridden.hidden_states is None
538+
def test_causal_lm_hidden_states_with_labels(self, config, backend):
    """Check that hidden states are returned together with a computed loss.

    The model is called with random ``labels`` and
    ``output_hidden_states=True``; the output is expected to carry a scalar
    loss and a first hidden-state tensor of shape
    ``(batch, seq, hidden_size)``.
    """
    from nemo_automodel.components.models.nemotron_v3.model import NemotronHForCausalLM

    lm = NemotronHForCausalLM(config, backend=backend).to(torch.bfloat16)

    batch_size, seq_len = 2, 8
    batch_shape = (batch_size, seq_len)
    # Inputs are drawn first, then labels, both from the vocabulary range.
    tokens = torch.randint(0, config.vocab_size, batch_shape)
    targets = torch.randint(0, config.vocab_size, batch_shape)

    result = lm(tokens, labels=targets, output_hidden_states=True)

    # Loss must exist and be a 0-dim (scalar) tensor.
    loss = result.loss
    assert loss is not None
    assert loss.ndim == 0

    # Hidden states must be present with the expected shape.
    states = result.hidden_states
    assert states is not None
    assert states[0].shape == (batch_size, seq_len, config.hidden_size)
555+
495556 def test_causal_lm_prepare_inputs_for_generation (self , config , backend ):
496557 """Test prepare_inputs_for_generation returns full sequence."""
497558 from nemo_automodel .components .models .nemotron_v3 .model import NemotronHForCausalLM
0 commit comments