Skip to content

Commit c4a1601

Browse files
committed
test: add comprehensive unit tests for nonuniform TP
- Moved test from nonuniform_tp.py to tests/unit_tests/distributed/
- Added TestNonuniformTPUtilities: tests for utility functions
  - compute_uniform_tp_spares_with_parity (3 test cases)
  - get_active_ranks_for_dp (2 test cases)
- Added TestNonuniformTPParameterResharding: tests for parameter resharding
  - ntp_map for no spares, healthy ranks, unhealthy ranks
  - ntp_init for layers with attention and MLP (4 test cases)
- Added TestNonuniformTPOptimizer: tests for optimizer wrapper
  - attribute delegation, prepare_grads, contiguity handling (5 test cases)
- Added TestNonuniformTPIntegration: integration tests
  - DDP initialization and backward hooks (2 test cases)
- Total: 17 test cases covering all major NTP functionality
1 parent f9bc507 commit c4a1601

File tree

2 files changed

+361
-38
lines changed

2 files changed

+361
-38
lines changed

megatron/core/distributed/nonuniform_tp.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -697,41 +697,3 @@ def prepare_grads(self, *args, **kwargs):
697697
return result
698698

699699

700-
# ======================================================================================
701-
# Test Function
702-
# ======================================================================================
703-
704-
705-
def test_ntp():
706-
"""Test function for nonuniform TP initialization."""
707-
head_dim = 128
708-
ffn_exp = 4
709-
710-
class MockConfig:
711-
num_attention_heads = 24
712-
ffn_hidden_size = num_attention_heads * head_dim * ffn_exp
713-
714-
class MockModule:
715-
def __init__(self, out_features):
716-
self.weight = torch.nn.Parameter(torch.randn(out_features, 1, dtype=torch.half))
717-
self.weight.partition_dim = 1
718-
self.weight.tensor_model_parallel = True
719-
self.config = MockConfig()
720-
721-
def parameters(self):
722-
return [self.weight]
723-
724-
class MockLayer:
725-
def __init__(self):
726-
self.self_attention = MockModule(int(3 * 10248 / 8))
727-
self.mlp = MockModule(12288 // 8)
728-
729-
layer = MockLayer()
730-
ddp_config = DistributedDataParallelConfig(tp_base=8, tp_spares=2)
731-
ntp_init(layer, ddp_config)
732-
print("NTP initialization test passed!")
733-
return layer
734-
735-
736-
if __name__ == '__main__':
737-
layer = test_ntp()

0 commit comments

Comments
 (0)