|
18 | 18 | get_active_ranks_for_dp, |
19 | 19 | ntp_map, |
20 | 20 | ntp_init, |
| 21 | + initialize_nonuniform_tp_process_groups, |
21 | 22 | NonuniformTPDistributedDataParallel, |
22 | 23 | NonuniformTPOptimizer, |
23 | 24 | NonuniformTPParamAndGradBuffer, |
@@ -357,5 +358,121 @@ def test_ntp_backward_hook_core_gpu(self, mock_parallel_state): |
357 | 358 | pytest.skip(f"Skipping due to initialization requirements: {e}") |
358 | 359 |
|
359 | 360 |
|
class TestNonuniformTPEndToEnd:
    """
    End-to-end test for NTP without mocking.

    Tests NTP with 8 GPUs configured as:
    - 2 data-parallel workers
    - DP rank 0: TP=2 (reduced, using 2 out of 4 GPUs)
    - DP rank 1: TP=4 (healthy, using all 4 GPUs)
    - Total: 2 + 4 = 6 active GPUs out of 8
    """

    @classmethod
    def setup_class(cls):
        """Initialize model parallel (tp_base=4) once for the whole class."""
        Utils.initialize_model_parallel(tensor_model_parallel_size=4)

    @classmethod
    def teardown_class(cls):
        """Tear down model parallel state so later test classes start fresh."""
        Utils.destroy_model_parallel()

    def test_ntp_end_to_end_with_8_gpus(self):
        """
        End-to-end test using 8 GPUs with 2 DP workers:
        - DP rank 0: uses TP=2 (reduced from tp_base=4)
        - DP rank 1: uses TP=4 (healthy, full tp_base)
        """
        import torch.distributed as dist
        from megatron.core import parallel_state

        # This scenario only makes sense on exactly 8 ranks (2 DP x tp_base=4).
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        if world_size != 8:
            pytest.skip(f"This test requires 8 GPUs, but only {world_size} are available")

        rank = dist.get_rank()
        dp_rank = parallel_state.get_data_parallel_rank()

        # Configure NTP: first DP rank uses reduced TP=2.
        ddp_config = DistributedDataParallelConfig(
            tp_base=4,
            tp_spares=2,
            num_reduced_tp_dp_ranks=1,
            non_active_ranks_per_dp={(0, 0, 0): [2, 3]},  # DP=0: GPUs 2,3 are spares
        )

        # Reconfigure process groups for NTP. Uses the module-level import of
        # initialize_nonuniform_tp_process_groups — the previous local
        # re-import duplicated it.
        initialize_nonuniform_tp_process_groups(ddp_config)

        # Read TP rank/size AFTER reconfiguration: rebuilding the process
        # groups can change both, so values captured beforehand would be stale.
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        tp_size_after = parallel_state.get_tensor_model_parallel_world_size()

        # Verify the configuration.
        if dp_rank == 0:
            # First DP rank should have reduced TP=2.
            assert tp_size_after == 2, f"DP rank 0 should have TP=2, got {tp_size_after}"
            assert tp_rank < 2, f"DP rank 0 should have tp_rank < 2, got {tp_rank}"
        else:
            # Other DP ranks keep the full TP=4.
            assert tp_size_after == 4, f"DP rank {dp_rank} should have TP=4, got {tp_size_after}"
            assert tp_rank < 4, f"DP rank {dp_rank} should have tp_rank < 4, got {tp_rank}"

        # Create a simple model with a tensor-parallel parameter.
        hidden_size = 128
        model = torch.nn.Linear(hidden_size, hidden_size, bias=False).cuda()

        # Mark the weight as tensor-parallel, partitioned along dim 0.
        model.weight.tensor_model_parallel = True
        model.weight.partition_dim = 0

        # For healthy ranks (DP=1), initialize send/recv splits via the
        # module-level ntp_map import (no local re-import needed).
        if dp_rank == 1:

            class MockModule:
                """Minimal module stand-in exposing only parameters()."""

                def __init__(self, param):
                    self.param = param

                def parameters(self):
                    return [self.param]

            ntp_map(MockModule(model.weight), ddp_config, num_shards=hidden_size)

            # Verify send_splits and recv_splits were attached to the param.
            assert hasattr(model.weight, 'send_splits'), "Healthy rank should have send_splits"
            assert hasattr(model.weight, 'recv_splits'), "Healthy rank should have recv_splits"
            assert len(model.weight.send_splits) == 4, "Should have splits for all tp_base ranks"

        # Forward pass: output shape must match the (batch, hidden) input.
        batch_size = 4
        input_tensor = torch.randn(batch_size, hidden_size, device='cuda')
        output = model(input_tensor)
        assert output.shape == (batch_size, hidden_size), f"Unexpected output shape: {output.shape}"

        # Backward pass: gradients must be produced for the TP parameter.
        loss = output.sum()
        loss.backward()
        assert model.weight.grad is not None, "Gradients should be computed"

        print(
            f"[Rank {rank}] NTP end-to-end test passed! "
            f"DP={dp_rank}, TP={tp_size_after}, tp_rank={tp_rank}"
        )
| 476 | + |
# Allow invoking this test file directly (e.g. under a multi-process GPU
# launcher) instead of through a plain `pytest` session.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
0 commit comments