Skip to content

Commit 4cd4544

Browse files
committed
fix
Signed-off-by: Hemil Desai <hemild@nvidia.com>
1 parent 5d01df8 commit 4cd4544

File tree

4 files changed

+1014
-1
lines changed

4 files changed

+1014
-1
lines changed

tests/unit_tests/models/deepseek_v3/test_dsv3_layers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,6 @@ def create_mock_config(self, **overrides):
359359
config.hidden_size = 1024
360360
config.rope_scaling = None
361361
config.max_position_embeddings = 4096
362-
config.rms_norm_eps = 1e-6
363362

364363
for key, value in overrides.items():
365364
setattr(config, key, value)

tests/unit_tests/models/deepseek_v32/test_dsv32_layers.py

Lines changed: 388 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,3 +408,391 @@ def test_init_weights(self, mock_trunc_normal):
408408

409409
# Check that norm reset_parameters was called
410410
assert mock_norm.reset_parameters.call_count >= 2 # q_a_layernorm, kv_a_layernorm
411+
412+
413+
class TestDeepseekV32IndexerInitWeights:
414+
def create_mock_config(self, **overrides):
415+
config = Mock(spec=DeepseekV32Config)
416+
config.num_attention_heads = 8
417+
config.hidden_size = 256
418+
config.q_lora_rank = 128
419+
config.index_n_heads = 4
420+
config.index_head_dim = 32
421+
config.index_topk = 16
422+
config.qk_rope_head_dim = 16
423+
424+
for key, value in overrides.items():
425+
setattr(config, key, value)
426+
return config
427+
428+
@patch("torch.nn.init.trunc_normal_")
429+
@patch("nemo_automodel.components.models.deepseek_v32.layers.initialize_linear_module")
430+
def test_indexer_init_weights(self, mock_init_linear, mock_trunc_normal):
431+
"""Test Indexer weight initialization directly."""
432+
config = self.create_mock_config()
433+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
434+
435+
mock_linear = Mock()
436+
mock_linear.weight = torch.randn(64, 256)
437+
mock_init_linear.return_value = mock_linear
438+
439+
indexer = DeepseekV32Indexer(config, backend)
440+
indexer.init_weights(init_std=0.02)
441+
442+
# Should call trunc_normal_ for wq_b, wk, weights_proj (3 linear layers)
443+
assert mock_trunc_normal.call_count == 3
444+
445+
446+
class TestDeepseekV32IndexerForward:
447+
def create_mock_config(self, **overrides):
448+
config = Mock(spec=DeepseekV32Config)
449+
config.num_attention_heads = 8
450+
config.hidden_size = 64
451+
config.q_lora_rank = 32
452+
config.index_n_heads = 4
453+
config.index_head_dim = 16
454+
config.index_topk = 8
455+
config.qk_rope_head_dim = 8
456+
457+
for key, value in overrides.items():
458+
setattr(config, key, value)
459+
return config
460+
461+
@skip_if_no_gpu
462+
def test_indexer_forward_bshd(self):
463+
"""Test Indexer forward pass with bshd format."""
464+
config = self.create_mock_config()
465+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
466+
467+
indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)
468+
469+
bsz, seq_len = 2, 16
470+
x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
471+
q_resid = torch.randn(bsz, seq_len, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
472+
# Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
473+
angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
474+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
475+
476+
topk_indices = indexer(x, q_resid, freqs_cis)
477+
478+
assert topk_indices.shape == (bsz, seq_len, config.index_topk)
479+
assert topk_indices.dtype == torch.int64
480+
481+
@pytest.mark.skip(reason="thd format requires complex freqs_cis setup matching model runtime")
482+
@skip_if_no_gpu
483+
def test_indexer_forward_thd(self):
484+
"""Test Indexer forward pass with thd format."""
485+
config = self.create_mock_config()
486+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
487+
488+
indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)
489+
490+
num_tokens = 32
491+
x = torch.randn(num_tokens, config.hidden_size, device="cuda", dtype=torch.bfloat16)
492+
q_resid = torch.randn(num_tokens, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
493+
# Create complex freqs_cis for thd format [T, D/2] as complex tensor
494+
angles = torch.randn(num_tokens, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
495+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
496+
497+
topk_indices = indexer(x, q_resid, freqs_cis)
498+
499+
assert topk_indices.shape == (num_tokens, config.index_topk)
500+
assert topk_indices.dtype == torch.int64
501+
502+
@skip_if_no_gpu
503+
def test_indexer_forward_with_attention_mask_bshd(self):
504+
"""Test Indexer forward pass with attention mask in bshd format."""
505+
config = self.create_mock_config()
506+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
507+
508+
indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)
509+
510+
bsz, seq_len = 2, 16
511+
x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
512+
q_resid = torch.randn(bsz, seq_len, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
513+
# Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
514+
angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
515+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
516+
517+
# Create causal mask
518+
attention_mask = torch.triu(
519+
torch.full((1, 1, seq_len, seq_len), float("-inf"), device="cuda"),
520+
diagonal=1,
521+
)
522+
523+
topk_indices = indexer(x, q_resid, freqs_cis, attention_mask=attention_mask)
524+
525+
assert topk_indices.shape == (bsz, seq_len, config.index_topk)
526+
527+
@pytest.mark.skip(reason="thd format requires complex freqs_cis setup matching model runtime")
528+
@skip_if_no_gpu
529+
def test_indexer_forward_with_attention_mask_thd(self):
530+
"""Test Indexer forward pass with attention mask in thd format."""
531+
config = self.create_mock_config()
532+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
533+
534+
indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)
535+
536+
num_tokens = 32
537+
x = torch.randn(num_tokens, config.hidden_size, device="cuda", dtype=torch.bfloat16)
538+
q_resid = torch.randn(num_tokens, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
539+
# Create complex freqs_cis for thd format [T, D/2] as complex tensor
540+
angles = torch.randn(num_tokens, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
541+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
542+
543+
# Create causal mask for thd format
544+
attention_mask = torch.triu(
545+
torch.full((1, 1, num_tokens, num_tokens), float("-inf"), device="cuda"),
546+
diagonal=1,
547+
)
548+
549+
topk_indices = indexer(x, q_resid, freqs_cis, attention_mask=attention_mask)
550+
551+
assert topk_indices.shape == (num_tokens, config.index_topk)
552+
553+
@skip_if_no_gpu
554+
def test_indexer_forward_topk_larger_than_seq(self):
555+
"""Test Indexer forward when topk > seq_len."""
556+
config = self.create_mock_config(index_topk=64) # larger than seq_len
557+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
558+
559+
indexer = DeepseekV32Indexer(config, backend).cuda().to(torch.bfloat16)
560+
561+
bsz, seq_len = 2, 16 # seq_len < index_topk
562+
x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
563+
q_resid = torch.randn(bsz, seq_len, config.q_lora_rank, device="cuda", dtype=torch.bfloat16)
564+
# Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
565+
angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
566+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
567+
568+
topk_indices = indexer(x, q_resid, freqs_cis)
569+
570+
# Should clamp to seq_len
571+
assert topk_indices.shape == (bsz, seq_len, seq_len)
572+
573+
574+
class TestDeepseekV32MLAForward:
575+
def create_mock_config(self, **overrides):
576+
config = Mock(spec=DeepseekV32Config)
577+
config.num_attention_heads = 4
578+
config.hidden_size = 64
579+
config.q_lora_rank = 32
580+
config.kv_lora_rank = 32
581+
config.qk_nope_head_dim = 8
582+
config.qk_rope_head_dim = 8
583+
config.qk_head_dim = 16
584+
config.v_head_dim = 16
585+
config.rope_scaling = None
586+
config.max_position_embeddings = 4096
587+
config.index_n_heads = 4
588+
config.index_head_dim = 16
589+
config.index_topk = 8
590+
591+
for key, value in overrides.items():
592+
setattr(config, key, value)
593+
return config
594+
595+
@skip_if_no_gpu
596+
def test_mla_forward_bshd_sdpa(self):
597+
"""Test MLA forward pass with bshd format and SDPA backend."""
598+
config = self.create_mock_config()
599+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
600+
601+
mla = DeepseekV32MLA(config, backend).cuda().to(torch.bfloat16)
602+
603+
bsz, seq_len = 2, 16
604+
x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
605+
# Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
606+
angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
607+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
608+
609+
output = mla(x, freqs_cis)
610+
611+
assert output.shape == (bsz, seq_len, config.hidden_size)
612+
assert output.dtype == torch.bfloat16
613+
614+
@pytest.mark.skip(reason="thd format requires complex freqs_cis setup matching model runtime")
615+
@skip_if_no_gpu
616+
def test_mla_forward_thd_sdpa(self):
617+
"""Test MLA forward pass with thd format and SDPA backend."""
618+
config = self.create_mock_config()
619+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
620+
621+
mla = DeepseekV32MLA(config, backend).cuda().to(torch.bfloat16)
622+
623+
num_tokens = 32
624+
x = torch.randn(num_tokens, config.hidden_size, device="cuda", dtype=torch.bfloat16)
625+
# Create complex freqs_cis for thd format [T, D/2] as complex tensor
626+
angles = torch.randn(num_tokens, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
627+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
628+
629+
output = mla(x, freqs_cis)
630+
631+
assert output.shape == (num_tokens, config.hidden_size)
632+
assert output.dtype == torch.bfloat16
633+
634+
@skip_if_no_gpu
635+
def test_mla_forward_with_attention_mask(self):
636+
"""Test MLA forward pass with attention mask."""
637+
config = self.create_mock_config()
638+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
639+
640+
mla = DeepseekV32MLA(config, backend).cuda().to(torch.bfloat16)
641+
642+
bsz, seq_len = 2, 16
643+
x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
644+
# Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
645+
angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
646+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
647+
648+
# Create causal mask
649+
attention_mask = torch.triu(
650+
torch.full((1, 1, seq_len, seq_len), float("-inf"), device="cuda"),
651+
diagonal=1,
652+
)
653+
654+
output = mla(x, freqs_cis, attention_mask=attention_mask)
655+
656+
assert output.shape == (bsz, seq_len, config.hidden_size)
657+
658+
@skip_te
659+
@skip_if_no_gpu
660+
def test_mla_forward_bshd_te(self):
661+
"""Test MLA forward pass with bshd format and TE backend."""
662+
config = self.create_mock_config()
663+
backend = BackendConfig(attn="te", linear="torch", rms_norm="torch")
664+
665+
mla = DeepseekV32MLA(config, backend).cuda().to(torch.bfloat16)
666+
667+
bsz, seq_len = 2, 16
668+
x = torch.randn(bsz, seq_len, config.hidden_size, device="cuda", dtype=torch.bfloat16)
669+
# Create complex freqs_cis for bshd format [B, T, D/2] as complex tensor
670+
angles = torch.randn(bsz, seq_len, config.qk_rope_head_dim // 2, device="cuda", dtype=torch.float32)
671+
freqs_cis = torch.polar(torch.ones_like(angles), angles)
672+
673+
output = mla(x, freqs_cis)
674+
675+
assert output.shape == (bsz, seq_len, config.hidden_size)
676+
677+
678+
class TestBuildSparseMaskWithAttentionMask:
679+
def create_mock_config(self, **overrides):
680+
config = Mock(spec=DeepseekV32Config)
681+
config.num_attention_heads = 8
682+
config.hidden_size = 256
683+
config.q_lora_rank = 128
684+
config.kv_lora_rank = 64
685+
config.qk_nope_head_dim = 16
686+
config.qk_rope_head_dim = 16
687+
config.qk_head_dim = 32
688+
config.v_head_dim = 32
689+
config.rope_scaling = None
690+
config.max_position_embeddings = 4096
691+
config.index_n_heads = 4
692+
config.index_head_dim = 32
693+
config.index_topk = 16
694+
695+
for key, value in overrides.items():
696+
setattr(config, key, value)
697+
return config
698+
699+
@patch("nemo_automodel.components.models.deepseek_v32.layers.initialize_linear_module")
700+
@patch("nemo_automodel.components.models.deepseek_v32.layers.initialize_rms_norm_module")
701+
@patch("nemo_automodel.components.models.deepseek_v32.layers.initialize_attn_module_and_func")
702+
def test_build_sparse_mask_combines_with_attention_mask(self, mock_init_attn, mock_init_rms, mock_init_linear):
703+
"""Test that sparse mask is combined with attention mask."""
704+
config = self.create_mock_config()
705+
backend = BackendConfig(attn="sdpa", linear="torch", rms_norm="torch")
706+
707+
mock_init_linear.return_value = Mock()
708+
mock_init_rms.return_value = Mock()
709+
mock_init_attn.return_value = (Mock(), Mock())
710+
711+
mla = DeepseekV32MLA(config, backend)
712+
713+
bsz, seq_len, topk = 2, 32, 8
714+
topk_indices = torch.randint(0, seq_len, (bsz, seq_len, topk))
715+
716+
# Create an attention mask
717+
attention_mask = torch.triu(
718+
torch.full((bsz, 1, seq_len, seq_len), float("-inf")),
719+
diagonal=1,
720+
)
721+
722+
sparse_mask = mla._build_sparse_mask(
723+
topk_indices,
724+
seq_len,
725+
qkv_format="bshd",
726+
bsz=bsz,
727+
n_heads=1,
728+
dtype=torch.float32,
729+
attention_mask=attention_mask,
730+
union_across_batches=False,
731+
)
732+
733+
# Result should combine both masks
734+
assert sparse_mask.shape == (bsz, 1, seq_len, seq_len)
735+
736+
# Check that causal structure is preserved (upper triangle should be -inf)
737+
for b in range(bsz):
738+
for i in range(seq_len):
739+
for j in range(i + 1, seq_len):
740+
assert sparse_mask[b, 0, i, j] == float("-inf")
741+
742+
743+
class TestHadamardTransformFallback:
744+
"""Test the fallback hadamard_transform implementation when fast_hadamard_transform is not available."""
745+
746+
def test_hadamard_transform_torch_basic(self):
747+
"""Test basic hadamard transform functionality."""
748+
# Import the torch fallback implementation directly
749+
from nemo_automodel.components.models.deepseek_v32 import layers
750+
751+
# Check if we're using the fallback
752+
if not layers._FAST_HADAMARD_AVAILABLE:
753+
# Test the torch implementation
754+
batch_size = 4
755+
n = 64 # Must be power of 2
756+
x = torch.randn(batch_size, n)
757+
scale = n**-0.5
758+
759+
result = layers.hadamard_transform(x, scale)
760+
761+
assert result.shape == x.shape
762+
assert result.dtype == x.dtype
763+
764+
def test_hadamard_transform_torch_power_of_2(self):
765+
"""Test that hadamard transform works with different power-of-2 sizes."""
766+
from nemo_automodel.components.models.deepseek_v32 import layers
767+
768+
if not layers._FAST_HADAMARD_AVAILABLE:
769+
for n in [8, 16, 32, 64, 128]:
770+
batch_size = 2
771+
x = torch.randn(batch_size, n)
772+
scale = n**-0.5
773+
774+
result = layers.hadamard_transform(x, scale)
775+
assert result.shape == (batch_size, n)
776+
777+
778+
class TestRotateActivationEdgeCases:
779+
"""Test edge cases for _rotate_activation function."""
780+
781+
@skip_if_no_gpu
782+
def test_rotate_activation_float16_converts(self):
783+
"""Test that float16 input is converted to bfloat16."""
784+
x = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16)
785+
result = _rotate_activation(x)
786+
assert result.dtype == torch.bfloat16
787+
assert result.shape == x.shape
788+
789+
def test_rotate_activation_applies_scale(self):
790+
"""Test that rotation applies the correct scale factor."""
791+
from nemo_automodel.components.models.deepseek_v32 import layers
792+
793+
if not layers._FAST_HADAMARD_AVAILABLE:
794+
# With fallback, we can verify the scale is applied
795+
x = torch.randn(2, 64, dtype=torch.bfloat16)
796+
result = _rotate_activation(x)
797+
# The scale factor should be hidden_size^-0.5 = 64^-0.5 = 0.125
798+
assert result.shape == x.shape

0 commit comments

Comments
 (0)