Commit 43fbb46 (parent: 3a7dd94)

refactor: move NTP config to separate NonuniformTPConfig class

- Create `NonuniformTPConfig` dataclass in `nonuniform_tp.py`
- Remove NTP fields from `DistributedDataParallelConfig` (non-intrusive)
- Update all NTP functions/classes to use `NonuniformTPConfig`
- Update all tests to use `NonuniformTPConfig`
- Update CLAUDE.md documentation

This makes the NTP implementation completely self-contained, with zero modifications to core Megatron files.

File tree

4 files changed: +132 additions, -107 deletions

CLAUDE.md

Lines changed: 7 additions & 11 deletions

````diff
@@ -17,20 +17,16 @@ This branch implements **Nonuniform Tensor Parallelism (NTP)**, a fault toleranc
 
 ### Key Changes
 
-**New Module**: `megatron/core/distributed/nonuniform_tp.py` (404 lines)
+**New Module**: `megatron/core/distributed/nonuniform_tp.py` (699 lines)
 - Implements nonuniform TP where a subset of TP ranks ("spares") provide fault tolerance
 - Supports arbitrary non-contiguous GPU failures across all parallelism dimensions (DP, CP, PP)
 - Core ranks handle computation; spare ranks enable recovery from failures
+- Defines `NonuniformTPConfig` dataclass for NTP configuration
+- Contains all NTP logic in subclasses: `NonuniformTPDistributedDataParallel`, `NonuniformTPParamAndGradBuffer`, `NonuniformTPOptimizer`
+- **Non-intrusive design**: All NTP functionality is self-contained; no modifications to core Megatron files required
 
 **Modified Files**:
-- `megatron/core/parallel_state.py`: Added NTP configuration support to `initialize_model_parallel()`
-- `megatron/core/distributed/distributed_data_parallel_config.py`: New fields for NTP config
-  - `tp_base`: Base tensor parallel size (e.g., 8)
-  - `tp_spares`: Number of spare ranks (e.g., 2 for reduced TP=6)
-  - `num_reduced_tp_dp_ranks`: How many DP ranks use reduced TP
-  - `non_active_ranks_per_dp`: Mapping of (DP, CP, PP) rank to list of non-active local TP ranks
-- `megatron/core/distributed/param_and_grad_buffer.py`: Parameter resharding for NTP
-- `megatron/core/optimizer/optimizer.py`: Optimizer integration
+- **None** - All NTP code is self-contained in `nonuniform_tp.py`
 
 ### NTP Concepts

@@ -44,10 +40,10 @@ This branch implements **Nonuniform Tensor Parallelism (NTP)**, a fault toleranc
 ### Example NTP Configuration
 
 ```python
-from megatron.core.distributed import DistributedDataParallelConfig
+from megatron.core.distributed.nonuniform_tp import NonuniformTPConfig
 
 # Configure NTP with 2 spare ranks out of 8
-ddp_config = DistributedDataParallelConfig(
+ntp_config = NonuniformTPConfig(
     tp_base=8,                  # Original TP size
     tp_spares=2,                # 2 spares = 6 active ranks
     num_reduced_tp_dp_ranks=1,  # First DP rank uses reduced TP
````
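The diff view truncates the example above. As a companion sketch, the helper below is hypothetical (not part of this commit) and only illustrates the `tp_base`/`tp_spares` arithmetic and the documented default that the last `tp_spares` local ranks act as spares:

```python
from typing import List, Optional

def active_tp_ranks(tp_base: int, tp_spares: int,
                    non_active: Optional[List[int]] = None) -> List[int]:
    """Local TP ranks that stay active in a reduced-TP group.

    Mirrors the documented default: when no explicit spare list is
    given, the last `tp_spares` local ranks become spares.
    """
    if non_active is None:
        non_active = list(range(tp_base - tp_spares, tp_base))
    assert len(non_active) == tp_spares, "spare list must match tp_spares"
    return [r for r in range(tp_base) if r not in non_active]

# tp_base=8, tp_spares=2 -> 6 active ranks; spares default to local ranks 6, 7
print(active_tp_ranks(8, 2))          # [0, 1, 2, 3, 4, 5]
# Arbitrary non-contiguous failures, e.g. local ranks 0 and 3 lost
print(active_tp_ranks(8, 2, [0, 3]))  # [1, 2, 4, 5, 6, 7]
```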

megatron/core/distributed/distributed_data_parallel_config.py

Lines changed: 0 additions & 22 deletions

```diff
@@ -162,28 +162,6 @@ class DistributedDataParallelConfig:
     delay_wgrad_compute: bool = False
     """Delay the weight gradient computation to improve batch-level communication overlapping"""
 
-    tp_base: int = 8
-    """Base for tensor parallelism. This is the number of ranks in healthy tensor parallel groups.
-    Used for nonuniform tensor parallelism."""
-
-    tp_spares: int = 0
-    """Number of spares for nonuniform tensor parallelism. When > 0, enables nonuniform TP mode
-    where (tp_base - tp_spares) ranks handle computation and tp_spares ranks provide fault tolerance."""
-
-    num_reduced_tp_dp_ranks: int = 1
-    """Number of DP ranks that use reduced TP (tp_base - tp_spares). The remaining DP ranks use
-    full tp_base. Reduced TP ranks are assumed to come first in the global rank ordering."""
-
-    non_active_ranks_per_dp: Optional[Dict[Tuple[int, int, int], List[int]]] = None
-    """Mapping of (DP rank, CP rank, PP rank) to list of non-active (spare) local TP rank IDs.
-    This allows specifying arbitrary GPU failures across all parallelism dimensions.
-    Example: {(0,0,0): [0,3], (0,1,0): [1,2], (1,0,0): [0,3]} means:
-    - DP rank 0, CP rank 0, PP rank 0 has local TP ranks 0,3 as spares
-    - DP rank 0, CP rank 1, PP rank 0 has local TP ranks 1,2 as spares
-    - DP rank 1, CP rank 0, PP rank 0 has local TP ranks 0,3 as spares
-    The number of non-active ranks must be consistent across CP replicas within each DP rank.
-    If None, defaults to last tp_spares ranks as non-active."""
-
     def __post_init__(self):
         import os
```
