@@ -408,3 +408,391 @@ def test_init_weights(self, mock_trunc_normal):

        # Check that norm reset_parameters was called
        assert mock_norm.reset_parameters.call_count >= 2  # q_a_layernorm, kv_a_layernorm


class TestDeepseekV32IndexerInitWeights:
    def create_mock_config(self, **overrides):
        config = Mock(spec=DeepseekV32Config)
        config.num_attention_heads = 8
        config.hidden_size = 256
        config.q_lora_rank = 128
        config.index_n_heads = 4
        config.index_head_dim = 32
        config.index_topk = 16
        config.qk_rope_head_dim = 16

        for key, value in overrides.items():
            setattr(config, key, value)
        return config

    @patch("torch.nn.init.trunc_normal_")
    @patch("nemo_automodel.components.models.deepseek_v32.layers.initialize_linear_module")
    def test_indexer_init_weights(self, mock_init_linear, mock_trunc_normal):
        """Test Indexer weight initialization directly."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        mock_linear = Mock()
        mock_linear.weight = torch.randn(64, 256)
        mock_init_linear.return_value = mock_linear

        indexer = DeepseekV32Indexer(config, backend)
        indexer.init_weights(init_std=0.02)

        # Should call trunc_normal_ for wq_b, wk, weights_proj (3 linear layers)
        assert mock_trunc_normal.call_count == 3


class TestDeepseekV32IndexerForward:
    def create_mock_config(self, **overrides):
        config = Mock(spec=DeepseekV32Config)
        config.num_attention_heads = 8
        config.hidden_size = 64
        config.q_lora_rank = 32
        config.index_n_heads = 4
        config.index_head_dim = 16
        config.index_topk = 8
        config.qk_rope_head_dim = 8

        for key, value in overrides.items():
            setattr(config, key, value)
        return config

    @skip_if_no_gpu
    def test_indexer_forward_bshd(self):
        """Test Indexer forward pass with bshd format."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)

        bsz, seq_len = 2, 16
        x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        q_resid = torch.randn(bsz, seq_len, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
        angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)
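        # torch.polar(abs, angle) returns abs * exp(1j * angle), so this yields unit-magnitude rotary phases.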

        topk_indices = indexer(x, q_resid, freqs_cis)

        assert topk_indices.shape == (bsz, seq_len, config.index_topk)
        assert topk_indices.dtype == torch.int64

    @pytest.mark.skip(reason="thd format requires complex freqs_cis setup matching model runtime")
    @skip_if_no_gpu
    def test_indexer_forward_thd(self):
        """Test Indexer forward pass with thd format."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)

        num_tokens = 32
        x = torch.randn(num_tokens, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        q_resid = torch.randn(num_tokens, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for thd format [T, D/2] as complex tensor
        angles = torch.randn(num_tokens, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        topk_indices = indexer(x, q_resid, freqs_cis)

        assert topk_indices.shape == (num_tokens, config.index_topk)
        assert topk_indices.dtype == torch.int64

    @skip_if_no_gpu
    def test_indexer_forward_with_attention_mask_bshd(self):
        """Test Indexer forward pass with attention mask in bshd format."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)

        bsz, seq_len = 2, 16
        x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        q_resid = torch.randn(bsz, seq_len, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
        angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        # Create causal mask
        attention_mask = torch.triu(
            torch.full((1, 1, seq_len, seq_len), float("-inf"), device="cuda"),
            diagonal=1,
        )
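        # With diagonal=1, torch.triu keeps -inf only strictly above the diagonal,
        # so future positions are masked while the diagonal and past positions stay 0.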

        topk_indices = indexer(x, q_resid, freqs_cis, attention_mask=attention_mask)

        assert topk_indices.shape == (bsz, seq_len, config.index_topk)

    @pytest.mark.skip(reason="thd format requires complex freqs_cis setup matching model runtime")
    @skip_if_no_gpu
    def test_indexer_forward_with_attention_mask_thd(self):
        """Test Indexer forward pass with attention mask in thd format."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)

        num_tokens = 32
        x = torch.randn(num_tokens, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        q_resid = torch.randn(num_tokens, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for thd format [T, D/2] as complex tensor
        angles = torch.randn(num_tokens, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        # Create causal mask for thd format
        attention_mask = torch.triu(
            torch.full((1, 1, num_tokens, num_tokens), float("-inf"), device="cuda"),
            diagonal=1,
        )

        topk_indices = indexer(x, q_resid, freqs_cis, attention_mask=attention_mask)

        assert topk_indices.shape == (num_tokens, config.index_topk)

    @skip_if_no_gpu
    def test_indexer_forward_topk_larger_than_seq(self):
        """Test Indexer forward when topk > seq_len."""
        config = self.create_mock_config(index_topk=64)  # larger than seq_len
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)

        bsz, seq_len = 2, 16  # seq_len < index_topk
        x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        q_resid = torch.randn(bsz, seq_len, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
        angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        topk_indices = indexer(x, q_resid, freqs_cis)

        # Should clamp to seq_len
        assert topk_indices.shape == (bsz, seq_len, seq_len)


class TestDeepseekV32MLAForward:
    def create_mock_config(self, **overrides):
        config = Mock(spec=DeepseekV32Config)
        config.num_attention_heads = 4
        config.hidden_size = 64
        config.q_lora_rank = 32
        config.kv_lora_rank = 32
        config.qk_nope_head_dim = 8
        config.qk_rope_head_dim = 8
        config.qk_head_dim = 16
        config.v_head_dim = 16
        config.rope_scaling = None
        config.max_position_embeddings = 4096
        config.index_n_heads = 4
        config.index_head_dim = 16
        config.index_topk = 8

        for key, value in overrides.items():
            setattr(config, key, value)
        return config

    @skip_if_no_gpu
    def test_mla_forward_bshd_sdpa(self):
        """Test MLA forward pass with bshd format and SDPA backend."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        mla = DeepseekV32MLA(config, backend).cuda().to(torch.bfloat16)

        bsz, seq_len = 2, 16
        x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
        angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        output = mla(x, freqs_cis)

        assert output.shape == (bsz, seq_len, config.hidden_size)
        assert output.dtype == torch.bfloat16

    @pytest.mark.skip(reason="thd format requires complex freqs_cis setup matching model runtime")
    @skip_if_no_gpu
    def test_mla_forward_thd_sdpa(self):
        """Test MLA forward pass with thd format and SDPA backend."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        mla = DeepseekV32MLA(config, backend).cuda().to(torch.bfloat16)

        num_tokens = 32
        x = torch.randn(num_tokens, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for thd format [T, D/2] as complex tensor
        angles = torch.randn(num_tokens, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        output = mla(x, freqs_cis)

        assert output.shape == (num_tokens, config.hidden_size)
        assert output.dtype == torch.bfloat16

    @skip_if_no_gpu
    def test_mla_forward_with_attention_mask(self):
        """Test MLA forward pass with attention mask."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        mla = DeepseekV32MLA(config, backend).cuda().to(torch.bfloat16)

        bsz, seq_len = 2, 16
        x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
        angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        # Create causal mask
        attention_mask = torch.triu(
            torch.full((1, 1, seq_len, seq_len), float("-inf"), device="cuda"),
            diagonal=1,
        )

        output = mla(x, freqs_cis, attention_mask=attention_mask)

        assert output.shape == (bsz, seq_len, config.hidden_size)

    @skip_te
    @skip_if_no_gpu
    def test_mla_forward_bshd_te(self):
        """Test MLA forward pass with bshd format and TE backend."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="te", linear="torch", rms_norm="torch")

        mla = DeepseekV32MLA(config, backend).cuda().to(torch.bfloat16)

        bsz, seq_len = 2, 16
        x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
        # Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
        angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
        freqs_cis = torch.polar(torch.ones_like(angles), angles)

        output = mla(x, freqs_cis)

        assert output.shape == (bsz, seq_len, config.hidden_size)


class TestBuildSparseMaskWithAttentionMask:
    def create_mock_config(self, **overrides):
        config = Mock(spec=DeepseekV32Config)
        config.num_attention_heads = 8
        config.hidden_size = 256
        config.q_lora_rank = 128
        config.kv_lora_rank = 64
        config.qk_nope_head_dim = 16
        config.qk_rope_head_dim = 16
        config.qk_head_dim = 32
        config.v_head_dim = 32
        config.rope_scaling = None
        config.max_position_embeddings = 4096
        config.index_n_heads = 4
        config.index_head_dim = 32
        config.index_topk = 16

        for key, value in overrides.items():
            setattr(config, key, value)
        return config

    @patch("nemo_automodel.components.models.deepseek_v32.layers.initialize_linear_module")
    @patch("nemo_automodel.components.models.deepseek_v32.layers.initialize_rms_norm_module")
    @patch("nemo_automodel.components.models.deepseek_v32.layers.initialize_attn_module_and_func")
    def test_build_sparse_mask_combines_with_attention_mask(self, mock_init_attn, mock_init_rms, mock_init_linear):
        """Test that sparse mask is combined with attention mask."""
        config = self.create_mock_config()
        backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")

        mock_init_linear.return_value = Mock()
        mock_init_rms.return_value = Mock()
        mock_init_attn.return_value = (Mock(), Mock())

        mla = DeepseekV32MLA(config, backend)

        bsz, seq_len, topk = 2, 32, 8
        topk_indices = torch.randint(0, seq_len, (bsz, seq_len, topk))

        # Create an attention mask
        attention_mask = torch.triu(
            torch.full((bsz, 1, seq_len, seq_len), float("-inf")),
            diagonal=1,
        )

        sparse_mask = mla._build_sparse_mask(
            topk_indices,
            seq_len,
            qkv_format="bshd",
            bsz=bsz,
            n_heads=1,
            dtype=torch.float32,
            attention_mask=attention_mask,
            union_across_batches=False,
        )

        # Result should combine both masks
        assert sparse_mask.shape == (bsz, 1, seq_len, seq_len)

        # Check that causal structure is preserved (upper triangle should be -inf)
        for b in range(bsz):
            for i in range(seq_len):
                for j in range(i + 1, seq_len):
                    assert sparse_mask[b, 0, i, j] == float("-inf")
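        # (A vectorized equivalent would build a strictly upper-triangular boolean mask with
        # torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1) and assert
        # that sparse_mask is -inf everywhere that mask is True.)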


class TestHadamardTransformFallback:
    """Test the fallback hadamard_transform implementation when fast_hadamard_transform is not available."""

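    # For reference (an assumption about the fallback, not a copy of its code): a pure-torch
    # Walsh-Hadamard transform over a power-of-2 length n can be computed with log2(n)
    # butterfly passes that map each pair of halves (a, b) to (a + b, a - b), followed by a
    # final multiplication with `scale` (n ** -0.5 in these tests).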
    def test_hadamard_transform_torch_basic(self):
        """Test basic hadamard transform functionality."""
        # Import the torch fallback implementation directly
        from nemo_automodel.components.models.deepseek_v32 import layers

        # Check if we're using the fallback
        if not layers._FAST_HADAMARD_AVAILABLE:
            # Test the torch implementation
            batch_size = 4
            n = 64  # Must be power of 2
            x = torch.randn(batch_size, n)
            scale = n ** -0.5

            result = layers.hadamard_transform(x, scale)

            assert result.shape == x.shape
            assert result.dtype == x.dtype

    def test_hadamard_transform_torch_power_of_2(self):
        """Test that hadamard transform works with different power-of-2 sizes."""
        from nemo_automodel.components.models.deepseek_v32 import layers

        if not layers._FAST_HADAMARD_AVAILABLE:
            for n in [8, 16, 32, 64, 128]:
                batch_size = 2
                x = torch.randn(batch_size, n)
                scale = n ** -0.5

                result = layers.hadamard_transform(x, scale)
                assert result.shape == (batch_size, n)


class TestRotateActivationEdgeCases:
    """Test edge cases for _rotate_activation function."""

    @skip_if_no_gpu
    def test_rotate_activation_float16_converts(self):
        """Test that float16 input is converted to bfloat16."""
        x = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16)
        result = _rotate_activation(x)
        assert result.dtype == torch.bfloat16
        assert result.shape == x.shape

    def test_rotate_activation_applies_scale(self):
        """Test that rotation applies the correct scale factor."""
        from nemo_automodel.components.models.deepseek_v32 import layers

        if not layers._FAST_HADAMARD_AVAILABLE:
            # With the fallback implementation we can exercise the scaled path directly;
            # the expected scale factor is hidden_size ** -0.5 = 64 ** -0.5 = 0.125.
            x = torch.randn(2, 64, dtype=torch.bfloat16)
            result = _rotate_activation(x)
            # For now we only sanity-check that the shape is preserved.
            assert result.shape == x.shape