Skip to content

Commit 5c63ff2

Browse files
committed
test: add end-to-end NTP test for 8 GPUs without mocking
- Tests 2 DP workers: DP rank 0 with TP=2 (reduced), DP rank 1 with TP=4 (healthy)
- Uses tp_base=4, tp_spares=2 configuration
- Verifies process group reconfiguration
- Tests parameter initialization and gradient computation
- No mocking - actual distributed test with real model
1 parent f80dee2 commit 5c63ff2

File tree

1 file changed

+117
-0
lines changed

1 file changed

+117
-0
lines changed

tests/unit_tests/distributed/test_nonuniform_tp.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
get_active_ranks_for_dp,
1919
ntp_map,
2020
ntp_init,
21+
initialize_nonuniform_tp_process_groups,
2122
NonuniformTPDistributedDataParallel,
2223
NonuniformTPOptimizer,
2324
NonuniformTPParamAndGradBuffer,
@@ -357,5 +358,121 @@ def test_ntp_backward_hook_core_gpu(self, mock_parallel_state):
357358
pytest.skip(f"Skipping due to initialization requirements: {e}")
358359

359360

361+
class TestNonuniformTPEndToEnd:
    """
    End-to-end test for NTP (nonuniform tensor parallelism) without mocking.

    Tests NTP with 8 GPUs configured as:
    - 2 data-parallel workers
    - DP rank 0: TP=2 (reduced, using 2 out of 4 GPUs)
    - DP rank 1: TP=4 (healthy, using all 4 GPUs)
    - Total: 2 + 4 = 6 active GPUs out of 8
    """

    @classmethod
    def setup_class(cls):
        """Initialize model parallel for NTP testing with tp_base=4."""
        Utils.initialize_model_parallel(tensor_model_parallel_size=4)

    @classmethod
    def teardown_class(cls):
        """Tear down the model-parallel state created in setup_class."""
        Utils.destroy_model_parallel()

    def test_ntp_end_to_end_with_8_gpus(self):
        """
        End-to-end test using 8 GPUs with 2 DP workers:
        - DP rank 0: uses TP=2 (reduced from tp_base=4)
        - DP rank 1: uses TP=4 (healthy, full tp_base)
        """
        import torch.distributed as dist

        from megatron.core import parallel_state

        # This scenario is meaningful only on exactly 8 GPUs; skip otherwise.
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        if world_size != 8:
            pytest.skip(f"This test requires 8 GPUs, but only {world_size} are available")

        # Snapshot rank info before the NTP reconfiguration below.
        rank = dist.get_rank()
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        dp_rank = parallel_state.get_data_parallel_rank()

        # Configure NTP: first DP rank uses reduced TP=2.
        ddp_config = DistributedDataParallelConfig(
            tp_base=4,
            tp_spares=2,
            num_reduced_tp_dp_ranks=1,
            non_active_ranks_per_dp={(0, 0, 0): [2, 3]},  # DP=0: GPUs 2,3 are spares
        )

        # Reconfigure process groups for NTP.
        # NOTE: initialize_nonuniform_tp_process_groups and ntp_map are already
        # imported at module level; the previous function-local re-imports were
        # redundant and have been removed.
        initialize_nonuniform_tp_process_groups(ddp_config)

        # After reconfiguration, re-read the TP world size.
        tp_size_after = parallel_state.get_tensor_model_parallel_world_size()

        # Verify the configuration.
        if dp_rank == 0:
            # First DP rank should have reduced TP=2.
            assert tp_size_after == 2, f"DP rank 0 should have TP=2, got {tp_size_after}"
            assert tp_rank < 2, f"DP rank 0 should have tp_rank < 2, got {tp_rank}"
        else:
            # Other DP ranks keep TP=4.
            assert tp_size_after == 4, f"DP rank {dp_rank} should have TP=4, got {tp_size_after}"
            assert tp_rank < 4, f"DP rank {dp_rank} should have tp_rank < 4, got {tp_rank}"

        # Create a simple model with tensor-parallel parameters.
        hidden_size = 128
        model = torch.nn.Linear(hidden_size, hidden_size, bias=False).cuda()

        # Mark it as tensor-parallel.
        model.weight.tensor_model_parallel = True
        model.weight.partition_dim = 0

        # For healthy ranks (DP=1), initialize send/recv splits via ntp_map.
        if dp_rank == 1:

            class MockModule:
                """Minimal stand-in exposing the parameters() API that ntp_map uses."""

                def __init__(self, param):
                    self.param = param

                def parameters(self):
                    return [self.param]

            mock_module = MockModule(model.weight)
            ntp_map(mock_module, ddp_config, num_shards=hidden_size)

            # Verify send_splits and recv_splits were added.
            assert hasattr(model.weight, 'send_splits'), "Healthy rank should have send_splits"
            assert hasattr(model.weight, 'recv_splits'), "Healthy rank should have recv_splits"
            assert len(model.weight.send_splits) == 4, "Should have splits for all tp_base ranks"

        # Forward pass: output must keep the (batch, hidden) shape of the input.
        batch_size = 4
        input_tensor = torch.randn(batch_size, hidden_size, device='cuda')
        output = model(input_tensor)
        assert output.shape == (batch_size, hidden_size), f"Unexpected output shape: {output.shape}"

        # Backward pass: gradients must be populated on the weight.
        loss = output.sum()
        loss.backward()
        assert model.weight.grad is not None, "Gradients should be computed"

        print(
            f"[Rank {rank}] NTP end-to-end test passed! "
            f"DP={dp_rank}, TP={tp_size_after}, tp_rank={tp_rank}"
        )
475+
476+
360477
# Allow invoking this test module directly (outside a pytest collection run).
if __name__ == '__main__':
    pytest.main([__file__, '-v'])

0 commit comments

Comments
 (0)