
Commit 6c19a96

Add resiliency examples for ft launcher and straggler detection (#2115)
Signed-off-by: Ananth Subramaniam <ansubramania@nvidia.com>
1 parent 40b1cf8 commit 6c19a96

5 files changed: +722, -0 lines changed

examples/resiliency/README.md

Lines changed: 47 additions & 0 deletions

# Resiliency

Examples demonstrating Megatron-Bridge resiliency features powered by [nvidia-resiliency-ext](https://github.com/NVIDIA/nvidia-resiliency-ext).

## Documentation

- [Megatron-Bridge Resiliency Guide](https://docs.nvidia.com/nemo/megatron-bridge/latest/training/resiliency.html)
- [nvidia-resiliency-ext Documentation](https://nvidia.github.io/nvidia-resiliency-ext/)

## Prerequisites

- [NeMo Framework Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo)
- 2+ GPUs for distributed examples
- HuggingFace token with Llama access (set `HF_TOKEN` env var)
- Accept the license at https://huggingface.co/meta-llama/Llama-3.2-1B

## Examples

### Straggler Detection

Detects slow-performing GPUs during training using NVRx straggler detection.

```bash
uv run python -m torch.distributed.run --nproc_per_node=2 examples/resiliency/straggler_detection/straggler_detection_example.py
```

Or use the launch script:

```bash
./examples/resiliency/straggler_detection/run_straggler_detection.sh
```
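
For intuition about what the detector is doing, here is a minimal, conceptual sketch in plain PyTorch: gather every rank's step time and flag ranks that are far slower than the median. This is not the NVRx API; the helper name, the 2x threshold, and the assumption of an already-initialized process group are illustrative only. The actual example configures detection through Megatron-Bridge and nvidia-resiliency-ext.

```python
# Conceptual sketch of straggler detection, not the NVRx implementation.
# Assumes torch.distributed is already initialized; the 2x-median threshold
# and the helper name are illustrative choices.
import torch
import torch.distributed as dist


def flag_stragglers(step_time_s: float, threshold: float = 2.0) -> list[int]:
    """Return ranks whose last step took more than `threshold` x the median step time."""
    world_size = dist.get_world_size()
    local = torch.tensor([step_time_s], dtype=torch.float64)
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)  # every rank sees every rank's timing
    times = torch.cat(gathered)
    median = times.median().item()
    return [rank for rank, t in enumerate(times.tolist()) if t > threshold * median]
```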

### Fault Tolerance

Enables automatic hang detection and job restart using the `ft_launcher`.

```bash
./examples/resiliency/fault_tolerance/run_fault_tolerance.sh
```

To test fault recovery with simulated failures:

```bash
./examples/resiliency/fault_tolerance/run_fault_tolerance.sh --simulate-fault
```

Note: Fault simulation requires careful timing: the fault must trigger after a checkpoint has been saved but before training completes. See the `--simulate-fault` option in `fault_tolerance_example.py` for details.
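
As a rough aid for choosing `--fault-delay`, the sketch below turns that timing rule into arithmetic. The 2000 iterations and save interval of 200 match the defaults `fault_tolerance_example.py` uses with `--simulate-fault`; the per-iteration time and checkpoint overhead are placeholder assumptions to replace with measurements from your own run.

```python
# Back-of-the-envelope helper for choosing --fault-delay. Placeholder timings;
# measure sec_per_iter and checkpoint overhead on your own hardware.
def fault_delay_window(
    train_iters: int, save_interval: int, sec_per_iter: float, ckpt_overhead_s: float
) -> tuple[float, float]:
    """Return (earliest, latest) fault delays that still allow recovery from a checkpoint."""
    first_checkpoint_done = save_interval * sec_per_iter + ckpt_overhead_s
    total_training_time = train_iters * sec_per_iter
    return first_checkpoint_done, total_training_time


# Defaults used by fault_tolerance_example.py with --simulate-fault, plus
# assumed timings of 0.5 s/iteration and 30 s of checkpoint overhead.
low, high = fault_delay_window(train_iters=2000, save_interval=200, sec_per_iter=0.5, ckpt_overhead_s=30.0)
print(f"pick --fault-delay between ~{low:.0f}s and ~{high:.0f}s")
```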

examples/resiliency/fault_tolerance/fault_tolerance_example.py

Lines changed: 284 additions & 0 deletions

#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Fault Tolerance Example

Demonstrates fault tolerance during training using nvidia-resiliency-ext.
Fault tolerance monitors training progress through sections (setup, step,
checkpointing) and enables automatic restart on hang detection.

Prerequisites:
- HuggingFace token with access to Llama models (set HF_TOKEN env var)
- Accept Llama license at https://huggingface.co/meta-llama/Llama-3.2-1B

IMPORTANT: This script must be run with ft_launcher, not torch.distributed.run.

Fault Simulation Mode (--simulate-fault):
    Demonstrates fault recovery by killing a rank after a delay.

    Timing requirements for successful recovery:
        checkpoint_time < fault_delay < total_training_time

    Where:
    - checkpoint_time: Wall-clock time to reach and finalize the first checkpoint
    - fault_delay: Seconds before fault injection (--fault-delay)
    - total_training_time: Wall-clock time for all training iterations

    If fault_delay < checkpoint_time: Job restarts from iteration 0 indefinitely
    If fault_delay > total_training_time: Training completes before fault triggers

Usage:
    uv run ft_launcher \\
        --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29500 \\
        --nnodes=1 --nproc-per-node=2 \\
        --ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 \\
        --ft-param-rank_out_of_section_timeout=300 \\
        examples/resiliency/fault_tolerance/fault_tolerance_example.py

    # With fault simulation:
    uv run ft_launcher ... --max-restarts=3 \\
        examples/resiliency/fault_tolerance/fault_tolerance_example.py --simulate-fault

    # Or use the launch script:
    ./examples/resiliency/fault_tolerance/run_fault_tolerance.sh
    ./examples/resiliency/fault_tolerance/run_fault_tolerance.sh --simulate-fault

Documentation:
    - Megatron-Bridge: https://docs.nvidia.com/nemo/megatron-bridge/latest/training/resiliency.html
    - NVRx Fault Tolerance: https://nvidia.github.io/nvidia-resiliency-ext/
"""

import argparse
import logging
import os

import torch

from megatron.bridge.models import AutoBridge
from megatron.bridge.training.config import (
    CheckpointConfig,
    ConfigContainer,
    DistributedDataParallelConfig,
    FaultToleranceConfig,
    LoggerConfig,
    MockGPTDatasetConfig,
    OptimizerConfig,
    RNGConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainingConfig,
)
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain


# Default model - smallest Llama 3.2 for fast examples
DEFAULT_MODEL = "meta-llama/Llama-3.2-1B"


def create_config(
    checkpoint_dir: str,
    model_id: str = DEFAULT_MODEL,
    train_iters: int = 50,
    save_interval: int = 25,
    simulate_fault: bool = False,
    fault_type: str = "rank_killed",
    fault_rank: int = 1,
    fault_delay: float = 60.0,
) -> ConfigContainer:
    """Create training configuration with fault tolerance enabled.

    Args:
        checkpoint_dir: Directory for checkpoints (required for FT state persistence).
        model_id: HuggingFace model ID to load.
        train_iters: Number of training iterations.
        save_interval: Checkpoint save interval.
        simulate_fault: Whether to simulate a fault for testing recovery.
        fault_type: Type of fault to simulate ("rank_killed", "rank_hung", "random").
        fault_rank: Which rank to fail (use -1 for random selection).
        fault_delay: Seconds to wait before injecting the fault.
    """
    seq_length = 512  # Short sequence for fast examples

    # Load model configuration from HuggingFace
    bridge = AutoBridge.from_hf_pretrained(model_id, torch_dtype=torch.bfloat16)
    model_config = bridge.to_megatron_provider()
    model_config.tensor_model_parallel_size = 1
    model_config.pipeline_model_parallel_size = 1
    model_config.context_parallel_size = 1
    model_config.sequence_parallel = False
    model_config.bf16 = True
    model_config.seq_length = seq_length

    train_config = TrainingConfig(
        train_iters=train_iters,
        micro_batch_size=4,
        global_batch_size=8,
        eval_interval=train_iters + 1,  # Disable evaluation
        eval_iters=0,
        exit_signal_handler=True,
    )

    dataset_config = MockGPTDatasetConfig(
        random_seed=1234,
        reset_attention_mask=False,
        reset_position_ids=False,
        eod_mask_loss=False,
        seq_length=seq_length,
        num_dataset_builder_threads=1,
        data_sharding=True,
        dataloader_type="single",
        num_workers=1,
    )

    optimizer_config = OptimizerConfig(
        optimizer="adam",
        bf16=True,
        fp16=False,
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_eps=1e-5,
        use_distributed_optimizer=True,
        clip_grad=1.0,
        lr=1e-4,
        weight_decay=0.01,
        min_lr=1e-6,
    )

    scheduler_config = SchedulerConfig(
        start_weight_decay=0.01,
        end_weight_decay=0.01,
        weight_decay_incr_style="constant",
        lr_decay_style="cosine",
        lr_warmup_iters=2,
        lr_warmup_init=0.0,
        lr_decay_iters=train_iters,
        override_opt_param_scheduler=True,
    )

    ddp_config = DistributedDataParallelConfig(
        check_for_nan_in_grad=True,
        grad_reduce_in_fp32=True,
        overlap_grad_reduce=True,
        overlap_param_gather=True,
        average_in_collective=True,
        use_distributed_optimizer=True,
    )

    checkpoint_config = CheckpointConfig(
        save=checkpoint_dir,
        load=checkpoint_dir,
        save_interval=save_interval,
        ckpt_format="torch_dist",
        async_save=True,
    )

    # Fault Tolerance Configuration
    # See: https://nvidia.github.io/nvidia-resiliency-ext/
    # When simulating faults, disable timeout calculation since we want to
    # demonstrate recovery behavior, not timeout learning.
    ft_config = FaultToleranceConfig(
        enable_ft_package=True,
        calc_ft_timeouts=not simulate_fault,
        # Fault simulation settings (only used when simulate_fault=True)
        simulate_fault=simulate_fault,
        simulated_fault_type=fault_type,
        simulated_fault_rank=fault_rank if fault_rank >= 0 else None,
        simulated_fault_base_delay=fault_delay,
    )

    return ConfigContainer(
        train=train_config,
        model=model_config,
        optimizer=optimizer_config,
        scheduler=scheduler_config,
        dataset=dataset_config,
        logger=LoggerConfig(log_interval=10, tensorboard_dir=None),
        tokenizer=TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=model_config.padded_vocab_size),
        checkpoint=checkpoint_config,
        rng=RNGConfig(seed=1234),
        ddp=ddp_config,
        ft=ft_config,
    )


def main() -> None:
    """Run fault tolerance example with configurable parameters."""
    parser = argparse.ArgumentParser(description="Fault Tolerance Example")
    parser.add_argument("--model", type=str, default=DEFAULT_MODEL, help="HuggingFace model ID")
    parser.add_argument("--train-iters", type=int, default=None, help="Number of training iterations")
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default="/tmp/megatron_bridge_ft_example",
        help="Checkpoint directory (must be shared across all ranks)",
    )

    # Fault simulation options
    parser.add_argument(
        "--simulate-fault",
        action="store_true",
        help="Enable fault simulation to test recovery",
    )
    parser.add_argument(
        "--fault-type",
        type=str,
        default="rank_killed",
        choices=["rank_killed", "rank_hung", "random"],
        help="Type of fault to simulate",
    )
    parser.add_argument(
        "--fault-rank",
        type=int,
        default=1,
        help="Rank to fail (-1 for random)",
    )
    parser.add_argument(
        "--fault-delay",
        type=float,
        default=60.0,
        help="Seconds before fault injection (must be after first checkpoint)",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")

    # Set defaults based on mode
    if args.train_iters is None:
        # Fault simulation needs more iterations so training outlasts the fault delay
        args.train_iters = 2000 if args.simulate_fault else 50

    # Checkpoint less frequently for longer runs
    save_interval = 200 if args.simulate_fault else 25

    # Ensure checkpoint directory exists (all ranks use the same path)
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    config = create_config(
        checkpoint_dir=args.checkpoint_dir,
        model_id=args.model,
        train_iters=args.train_iters,
        save_interval=save_interval,
        simulate_fault=args.simulate_fault,
        fault_type=args.fault_type,
        fault_rank=args.fault_rank,
        fault_delay=args.fault_delay,
    )
    pretrain(config=config, forward_step_func=forward_step)


if __name__ == "__main__":
    main()
