
Commit f9bc507 (1 parent: a51c1c8)

feat: non-intrusive nonuniform tensor parallelism implementation

- All NTP logic contained in nonuniform_tp.py as subclasses
- NonuniformTPDistributedDataParallel: inherits from DistributedDataParallel
- NonuniformTPParamAndGradBuffer: handles gradient buffer splitting for NTP
- NonuniformTPOptimizer: wrapper for gradient contiguity
- initialize_nonuniform_tp_process_groups(): reconfigures process groups after init
- Only config changes to core files (distributed_data_parallel_config.py)
- Added comprehensive CLAUDE.md documentation

File tree: 3 files changed, +1101 -1 lines

CLAUDE.md

Lines changed: 341 additions & 0 deletions
@@ -0,0 +1,341 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

**Megatron-LM** is NVIDIA's reference implementation for training large language models at scale using advanced parallelism strategies. It consists of two main components:

- **Megatron Core** (`megatron/core/`): Production library providing composable building blocks, GPU-optimized kernels, and parallelism strategies (TP, PP, DP, EP, CP)
- **Megatron-LM** (top-level): Reference training scripts and examples for GPT, LLaMA, DeepSeek, Qwen, Mamba, and other models

The codebase supports training models from 2B to 462B+ parameters across thousands of GPUs with state-of-the-art performance (up to 47% MFU on H100 clusters).

## Current Branch: nonuniform-tp

This branch implements **Nonuniform Tensor Parallelism (NTP)**, a fault tolerance mechanism that allows training to continue when GPU failures occur within a tensor-parallel group.

### Key Changes

**New Module**: `megatron/core/distributed/nonuniform_tp.py` (404 lines)

- Implements nonuniform TP where a subset of TP ranks ("spares") provide fault tolerance
- Supports arbitrary non-contiguous GPU failures across all parallelism dimensions (DP, CP, PP)
- Core ranks handle computation; spare ranks enable recovery from failures

**Modified Files**:

- `megatron/core/parallel_state.py`: Added NTP configuration support to `initialize_model_parallel()`
- `megatron/core/distributed/distributed_data_parallel_config.py`: New fields for NTP config
  - `tp_base`: Base tensor parallel size (e.g., 8)
  - `tp_spares`: Number of spare ranks (e.g., 2 for reduced TP=6)
  - `num_reduced_tp_dp_ranks`: How many DP ranks use reduced TP
  - `non_active_ranks_per_dp`: Mapping of (DP, CP, PP) rank to list of non-active local TP ranks
- `megatron/core/distributed/param_and_grad_buffer.py`: Parameter resharding for NTP
- `megatron/core/optimizer/optimizer.py`: Optimizer integration

### NTP Concepts

- **tp_base**: Original/healthy tensor parallel size (e.g., 8 GPUs)
- **tp_spares**: Number of spare GPUs for fault tolerance (e.g., 2)
- **Reduced TP**: Actual working size = tp_base - tp_spares (e.g., 6 active GPUs)
- **Healthy ranks**: DP replicas using the full tp_base (no failures)
- **Unhealthy ranks**: DP replicas with failures, using reduced TP
- **send_splits/recv_splits**: Parameter resharding metadata for synchronizing between healthy and unhealthy ranks

### Example NTP Configuration

```python
from megatron.core.distributed import DistributedDataParallelConfig

# Configure NTP with 2 spare ranks out of 8
ddp_config = DistributedDataParallelConfig(
    tp_base=8,                  # Original TP size
    tp_spares=2,                # 2 spares = 6 active ranks
    num_reduced_tp_dp_ranks=1,  # First DP rank uses reduced TP
    non_active_ranks_per_dp={
        (0, 0, 0): [2, 5],      # DP=0, CP=0, PP=0 has GPUs 2 and 5 failed
        (0, 1, 0): [1, 3],      # DP=0, CP=1, PP=0 has GPUs 1 and 3 failed
    },
)
```
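The invariants implied by this configuration can be sketched in plain Python. The helper below is illustrative only, not part of the Megatron-LM API: it checks that reduced TP equals tp_base - tp_spares and that every listed (DP, CP, PP) entry marks exactly tp_spares valid local TP ranks as non-active.

```python
def validate_ntp_config(tp_base, tp_spares, non_active_ranks_per_dp):
    """Illustrative sanity check of the NTP config fields described above."""
    reduced_tp = tp_base - tp_spares
    assert 0 < reduced_tp <= tp_base, "tp_spares must leave at least one active rank"
    for key, spares in non_active_ranks_per_dp.items():
        # Each failure entry must name exactly tp_spares distinct local TP ranks.
        assert len(set(spares)) == tp_spares, f"entry {key} must list {tp_spares} spares"
        assert all(0 <= r < tp_base for r in spares), "spares must be valid local TP ranks"
    return reduced_tp

reduced = validate_ntp_config(
    tp_base=8,
    tp_spares=2,
    non_active_ranks_per_dp={(0, 0, 0): [2, 5], (0, 1, 0): [1, 3]},
)
print(reduced)  # 6 active ranks per reduced-TP group
```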
## Development Setup

### Installation

```bash
# Install with dev dependencies
pip install "setuptools<80.0.0,>=77.0.0" "packaging>=24.2"
pip install --no-build-isolation .[mlm,dev]

# Or use Docker (recommended)
docker run --runtime=nvidia --gpus all -it --rm \
    -v $(pwd):/workspace/megatron \
    -e PIP_CONSTRAINT= \
    nvcr.io/nvidia/pytorch:25.04-py3
```

### Running Tests

```bash
# Run all unit tests
pytest tests/unit_tests/

# Run a specific test file
pytest tests/unit_tests/test_optimizer.py

# Run a specific test
pytest tests/unit_tests/test_optimizer.py::TestOptimizer::test_specific_case

# Run with verbose output
pytest tests/unit_tests/ -v -s

# Run functional tests (requires GPUs)
pytest tests/functional_tests/
```

### Linting and Formatting

```bash
# Run pre-commit hooks (black, pylint, isort)
pre-commit run --all-files

# Format code with black
black megatron/core/ tests/unit_tests/

# Run pylint
pylint megatron/core/

# Sort imports
isort megatron/core/
```

Black formatting uses the `--skip-magic-trailing-comma` and `--skip-string-normalization` options.

### Training Examples

```bash
# Simple training example (2 GPUs, mock data)
torchrun --nproc_per_node=2 examples/run_simple_mcore_train_loop.py

# LLaMA-3 8B with FP8 (8 GPUs)
./examples/llama/train_llama3_8b_fp8.sh

# GPT-3 pretraining
torchrun --nproc_per_node=8 pretrain_gpt.py \
    --tensor-model-parallel-size 2 \
    --pipeline-model-parallel-size 2 \
    --num-layers 24 \
    --hidden-size 1024 \
    --num-attention-heads 16 \
    --micro-batch-size 4 \
    --global-batch-size 32 \
    --seq-length 1024 \
    --train-iters 100000 \
    --data-path /path/to/data
```

## Architecture Overview

### Directory Structure

```
megatron/core/             # Megatron Core library (composable components)
├── models/                # Model architectures (GPT, LLaMA, Mixtral, etc.)
├── transformer/           # Transformer building blocks (attention, MLP, layers)
├── tensor_parallel/       # Tensor parallelism implementations
├── pipeline_parallel/     # Pipeline parallelism implementations
├── distributed/           # Distributed training (FSDP, DDP, NTP)
│   ├── distributed_data_parallel.py
│   ├── nonuniform_tp.py   # NEW: Nonuniform TP for fault tolerance
│   └── param_and_grad_buffer.py
├── optimizer/             # Distributed optimizers
├── datasets/              # Dataset utilities
├── inference/             # Inference engines
└── export/                # Model export (TensorRT-LLM, etc.)

megatron/training/         # Training utilities
megatron/inference/        # Inference server
megatron/legacy/           # Legacy components

Top-level scripts:         # Entry points for training
├── pretrain_gpt.py
├── pretrain_bert.py
├── pretrain_mamba.py
├── pretrain_vlm.py
└── train_rl.py

examples/                  # Training recipes and tutorials
tools/                     # Data preprocessing, checkpoint conversion
tests/                     # Unit and functional tests
```

### Key Architecture Concepts

#### Parallelism Strategies

Megatron supports multiple parallelism dimensions that can be combined:

1. **Tensor Parallelism (TP)**: Splits individual layers across GPUs
   - `--tensor-model-parallel-size N`: N-way tensor parallelism
   - Use with Sequence Parallelism for memory efficiency
   - Recommended for large models whose layers don't fit on a single GPU

2. **Pipeline Parallelism (PP)**: Splits model depth across GPUs
   - `--pipeline-model-parallel-size N`: N pipeline stages
   - `--virtual-pipeline-model-parallel-size M`: Virtual stages for load balancing
   - Reduces memory per GPU but adds pipeline bubbles

3. **Data Parallelism (DP)**: Replicates the model across GPUs
   - Standard DDP or FSDP (ZeRO-1/2/3)
   - `--use-custom-fsdp`: Megatron's optimized FSDP (~15% faster)
   - `--data-parallel-sharding-strategy`: optim, optim_grads, optim_grads_params

4. **Context Parallelism (CP)**: Splits long sequences across GPUs
   - `--context-parallel-size N`: N-way context parallelism
   - For handling very long sequences (8K+)
   - Communication types: p2p, a2a, allgather

5. **Expert Parallelism (EP)**: Distributes MoE experts across GPUs
   - `--expert-model-parallel-size N`: N-way expert parallelism
   - Required for large MoE models (Mixtral, DeepSeek-V3)
   - **Important**: When using EP with TP, Sequence Parallelism must be enabled

6. **Nonuniform Tensor Parallelism (NTP)** *(this branch)*: Fault tolerance for TP
   - Allows a subset of TP ranks to fail while training continues
   - Configure via `DistributedDataParallelConfig` with tp_base/tp_spares
   - Healthy and unhealthy DP ranks synchronize via parameter resharding

#### Global Rank Calculation

With TP varying fastest, then CP, then DP, then PP:

```
global_rank = local_tp_rank +
              cp_rank * tp_size +
              dp_rank * tp_size * cp_size +
              pp_rank * tp_size * cp_size * dp_size
```

For NTP, use `tp_base` instead of the actual `tp_size` when calculating the DP rank.
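The formula above can be written out as a small helper. This is a sketch, not a Megatron-LM API; it assumes the rank ordering just stated (TP fastest, then CP, then DP, then PP).

```python
def global_rank(local_tp_rank, cp_rank, dp_rank, pp_rank,
                tp_size, cp_size, dp_size):
    """Compute a GPU's global rank from its per-dimension ranks (illustrative)."""
    return (local_tp_rank
            + cp_rank * tp_size
            + dp_rank * tp_size * cp_size
            + pp_rank * tp_size * cp_size * dp_size)

# With TP=8, CP=1, DP=4: the first GPU of DP replica 1 on pipeline stage 0
print(global_rank(0, 0, 1, 0, tp_size=8, cp_size=1, dp_size=4))  # 8
```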
#### Process Group Initialization

The `initialize_model_parallel()` function in `parallel_state.py` creates all process groups:

- Tensor parallel groups (TP)
- Pipeline parallel groups (PP)
- Data parallel groups (DP)
- Context parallel groups (CP)
- Combined groups (TP+CP, model parallel)

For NTP, pass an `ntp_config` dict to enable nonuniform TP mode.

### Data Preprocessing

Data must be preprocessed into binary format:

```bash
python tools/preprocess_data.py \
    --input data.jsonl \
    --output-prefix processed_data \
    --tokenizer-type HuggingFaceTokenizer \
    --tokenizer-model /path/to/tokenizer.model \
    --workers 8 \
    --append-eod
```

Input format: JSONL with `{"text": "..."}` per line.
Output: `.bin` and `.idx` files for efficient loading.
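A minimal sketch of producing that JSONL input format: one JSON object with a `"text"` field per line (file path and documents here are placeholders).

```python
import json
import os
import tempfile

docs = ["First training document.", "Second training document."]

# Write one {"text": ...} object per line, as tools/preprocess_data.py expects.
path = os.path.join(tempfile.mkdtemp(), "data.jsonl")
with open(path, "w") as f:
    for text in docs:
        f.write(json.dumps({"text": text}) + "\n")

# Each line parses back into a dict with a "text" key.
with open(path) as f:
    records = [json.loads(line) for line in f]
print(records[0]["text"])  # First training document.
```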
### Transformer Layer Structure

Each transformer layer typically contains:

- **Self-Attention**: Multi-head attention with optional TP sharding
  - QKV projections are column-parallel, TP-sharded across `num_attention_heads`
  - Output projection is row-parallel
- **MLP**: Feed-forward network with optional TP sharding
  - First projection (up) is column-parallel, TP-sharded across `ffn_hidden_size`
  - Second projection (down) is row-parallel
- **LayerNorm/RMSNorm**: Normalization layers
- **Residual connections**: With optional LayerNorm

For NTP, call `ntp_init(layer, ddp_config)` after layer creation to set up parameter resharding metadata.

### Communication Patterns

- **TP communication**: All-reduce within the TP group for row-parallel layers; identity in the forward pass for column-parallel layers
- **PP communication**: Point-to-point between adjacent pipeline stages
- **DP communication**: All-reduce of gradients across the DP group
- **CP communication**: All-gather or point-to-point for sequence chunks
- **NTP communication**: Custom parameter resharding between healthy/unhealthy DP replicas

### Performance Optimizations

Key flags for performance:

- `--fp8-hybrid`: FP8 training (Hopper, Ada, Blackwell GPUs)
- `--attention-backend`: FlashAttention via Transformer Engine (default, recommended)
- `--recompute-activations`: Activation checkpointing for memory savings
- `--overlap-grad-reduce`: Overlap DP gradient communication with the backward pass
- `--overlap-param-gather`: Overlap parameter gathering with computation
- `--use-distributed-optimizer`: Shard optimizer state across DP ranks

## Important Conventions

### Parameter Attributes

Tensor-parallel parameters have special attributes:

- `param.tensor_model_parallel` (bool): Whether the parameter is TP-sharded
- `param.partition_dim` (int): Which dimension is sharded (0 for rows, 1 for columns)
- `param.send_splits` (list): For NTP, how to split the parameter when sending to other ranks
- `param.recv_splits` (list): For NTP, how to split the parameter when receiving from other ranks
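As a toy illustration of filtering on these attributes, the snippet below uses plain stand-in objects rather than real torch parameters; the parameter names are invented for the example.

```python
from types import SimpleNamespace

# Stand-ins for named parameters carrying the TP attributes described above.
params = {
    "qkv.weight":   SimpleNamespace(tensor_model_parallel=True,  partition_dim=0),
    "dense.weight": SimpleNamespace(tensor_model_parallel=True,  partition_dim=1),
    "norm.weight":  SimpleNamespace(tensor_model_parallel=False, partition_dim=-1),
}

# Collect only the TP-sharded parameters, the set NTP resharding cares about.
sharded = [name for name, p in params.items()
           if getattr(p, "tensor_model_parallel", False)]
print(sharded)  # ['qkv.weight', 'dense.weight']
```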
### Model Configuration

Models use config dataclasses (e.g., `TransformerConfig`) that contain:

- Architecture params: num_layers, hidden_size, num_attention_heads, ffn_hidden_size
- Parallelism settings: tensor_model_parallel_size, pipeline_model_parallel_size
- Training settings: fp16, bf16, params_dtype, pipeline_dtype
- Optimization settings: recompute_granularity, activation_checkpointing

### Checkpointing

Checkpoints are saved with distributed optimizer state sharding:

- Model state is saved per rank or in a consolidated format
- Optimizer state is sharded across DP ranks when using the distributed optimizer
- Use the `tools/checkpoint/` utilities for conversion between formats

## Working with Nonuniform TP

When modifying or testing NTP code:

1. **Process group reconfiguration**: NTP modifies TP/CP groups after initialization
   - Healthy ranks keep the full tp_base size
   - Unhealthy ranks recreate groups with only active ranks
   - Non-active (spare) ranks exit after group creation

2. **Parameter initialization**: Call `ntp_init(layer, ddp_config)` after creating transformer layers
   - Computes send_splits/recv_splits for healthy ranks
   - Unhealthy ranks skip initialization (no resharding needed)

3. **Gradient synchronization**: Modified in `param_and_grad_buffer.py`
   - Healthy ranks reshard parameters before sending to unhealthy ranks
   - Unhealthy ranks synchronize directly without resharding

4. **Testing NTP**: Use the `test_ntp()` function in `nonuniform_tp.py` for basic validation

## Related Files

When working on specific features, these files are commonly modified together:

**Distributed Training**:
- `megatron/core/parallel_state.py`: Process group management
- `megatron/core/distributed/*.py`: DDP, FSDP, NTP implementations
- `megatron/core/optimizer/optimizer.py`: Distributed optimizer

**Model Architecture**:
- `megatron/core/models/*.py`: Model definitions
- `megatron/core/transformer/*.py`: Layer building blocks
- `gpt_builders.py`, `mamba_builders.py`: Model factory functions

**Tensor Parallelism**:
- `megatron/core/tensor_parallel/*.py`: TP implementations
- `megatron/core/transformer/dot_product_attention.py`: TP-aware attention

**Pipeline Parallelism**:
- `megatron/core/pipeline_parallel/*.py`: PP schedules and communication

megatron/core/distributed/distributed_data_parallel_config.py

Lines changed: 23 additions & 1 deletion

```diff
@@ -1,7 +1,7 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

 from dataclasses import dataclass
-from typing import Optional
+from typing import Dict, List, Optional, Tuple


 @dataclass
@@ -157,6 +157,28 @@ class DistributedDataParallelConfig:
     delay_wgrad_compute: bool = False
     """Delay the weight gradient computation to improve batch-level communication overlapping"""

+    tp_base: int = 8
+    """Base for tensor parallelism. This is the number of ranks in healthy tensor parallel groups.
+    Used for nonuniform tensor parallelism."""
+
+    tp_spares: int = 0
+    """Number of spares for nonuniform tensor parallelism. When > 0, enables nonuniform TP mode
+    where (tp_base - tp_spares) ranks handle computation and tp_spares ranks provide fault tolerance."""
+
+    num_reduced_tp_dp_ranks: int = 1
+    """Number of DP ranks that use reduced TP (tp_base - tp_spares). The remaining DP ranks use
+    full tp_base. Reduced TP ranks are assumed to come first in the global rank ordering."""
+
+    non_active_ranks_per_dp: Optional[Dict[Tuple[int, int, int], List[int]]] = None
+    """Mapping of (DP rank, CP rank, PP rank) to list of non-active (spare) local TP rank IDs.
+    This allows specifying arbitrary GPU failures across all parallelism dimensions.
+    Example: {(0,0,0): [0,3], (0,1,0): [1,2], (1,0,0): [0,3]} means:
+    - DP rank 0, CP rank 0, PP rank 0 has local TP ranks 0,3 as spares
+    - DP rank 0, CP rank 1, PP rank 0 has local TP ranks 1,2 as spares
+    - DP rank 1, CP rank 0, PP rank 0 has local TP ranks 0,3 as spares
+    The number of non-active ranks must be consistent across CP replicas within each DP rank.
+    If None, defaults to last tp_spares ranks as non-active."""
+
     def __post_init__(self):
         import os
```
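The documented default (when `non_active_ranks_per_dp` is `None`, the last `tp_spares` local TP ranks are treated as non-active) can be sketched as follows; the helper name is illustrative, not part of the actual module.

```python
def default_non_active_ranks(tp_base, tp_spares):
    """Last tp_spares local TP ranks are non-active, per the config docstring."""
    return list(range(tp_base - tp_spares, tp_base))

print(default_non_active_ranks(8, 2))  # [6, 7]
```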
