
Commit 8623ba6

add resiliency examples for ft launcher and straggler detection
Signed-off-by: Ananth Subramaniam <[email protected]>
1 parent 7a50d2e commit 8623ba6

6 files changed: +921, -0 lines changed

examples/resiliency/README.md

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
# Resiliency

Examples demonstrating Megatron-Bridge resiliency features powered by [nvidia-resiliency-ext](https://github.com/NVIDIA/nvidia-resiliency-ext).

## Documentation

- [Megatron-Bridge Resiliency Guide](https://docs.nvidia.com/nemo/megatron-bridge/latest/training/resiliency.html)
- [nvidia-resiliency-ext Documentation](https://nvidia.github.io/nvidia-resiliency-ext/)

## Prerequisites

- [NeMo Framework Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo)
- 2+ GPUs for distributed examples

## Examples

### Straggler Detection

Detects slow-performing GPUs during training using NVRx straggler detection.

```bash
uv run python -m torch.distributed.run --nproc_per_node=2 examples/resiliency/straggler_detection/basic_straggler_detection.py
```

Or use the launch script:

```bash
./examples/resiliency/straggler_detection/run_straggler_detection.sh
```

### Fault Tolerance

Enables automatic hang detection and job restart using the `ft_launcher`.

```bash
./examples/resiliency/fault_tolerance/run_fault_tolerance.sh
```

To test fault recovery with simulated failures:

```bash
./examples/resiliency/fault_tolerance/run_fault_tolerance.sh --simulate-fault
```

Note: Fault simulation requires careful timing: the fault must trigger after a checkpoint is saved but before training completes. See `simulate_fault.py` for details.
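
If the default window is too tight on your hardware, the same flags that `run_fault_tolerance.sh` exposes can widen it. A sketch (the numbers below are illustrative, not measured):

```bash
# Illustrative values: lengthen training and delay the fault so the first
# checkpoint lands well before the injected failure.
./examples/resiliency/fault_tolerance/run_fault_tolerance.sh \
  --simulate-fault \
  --train-iters 3000 \
  --fault-delay 90 \
  --max-restarts 2
```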
examples/resiliency/fault_tolerance/basic_fault_tolerance.py

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Basic Fault Tolerance Example

This example demonstrates how to enable fault tolerance during training
using the nvidia-resiliency-ext package. Fault tolerance monitors training
progress through sections (setup, step, checkpointing) and enables automatic
restart on hang detection.

IMPORTANT: This script must be run with ft_launcher, not torch.distributed.run.

Usage:
    uv run ft_launcher \\
        --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29500 \\
        --nnodes=1 --nproc-per-node=2 \\
        --ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 \\
        --ft-param-rank_out_of_section_timeout=300 \\
        examples/resiliency/fault_tolerance/basic_fault_tolerance.py

    # Or use the launch script:
    ./examples/resiliency/fault_tolerance/run_fault_tolerance.sh

Documentation:
    - Megatron-Bridge: https://docs.nvidia.com/nemo/megatron-bridge/latest/training/resiliency.html
    - NVRx Fault Tolerance: https://nvidia.github.io/nvidia-resiliency-ext/
"""

import argparse
import logging
import os
from dataclasses import dataclass

import torch

from megatron.bridge.models.llama import Llama3ModelProvider
from megatron.bridge.training.config import (
    CheckpointConfig,
    ConfigContainer,
    DistributedDataParallelConfig,
    FaultToleranceConfig,
    LoggerConfig,
    MockGPTDatasetConfig,
    OptimizerConfig,
    RNGConfig,
    SchedulerConfig,
    TokenizerConfig,
    TrainingConfig,
)
from megatron.bridge.training.gpt_step import forward_step
from megatron.bridge.training.pretrain import pretrain


@dataclass
class TinyLlama3Config(Llama3ModelProvider):
    """Tiny Llama3 model (~145M params) for fast example execution."""

    rotary_base: int = 500_000
    num_layers: int = 4
    hidden_size: int = 768
    ffn_hidden_size: int = 2688
    num_attention_heads: int = 16
    vocab_size: int | None = None


def create_config(
    checkpoint_dir: str,
    train_iters: int = 50,
    calc_timeouts: bool = True,
) -> ConfigContainer:
    """Create training configuration with fault tolerance enabled.

    Args:
        checkpoint_dir: Directory for checkpoints (required for FT state persistence).
        train_iters: Number of training iterations.
        calc_timeouts: Whether to calculate and update FT timeouts based on observed times.
    """
    seq_length = 2048

    model_config = TinyLlama3Config(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        context_parallel_size=1,
        sequence_parallel=False,
        attention_softmax_in_fp32=True,
        pipeline_dtype=torch.bfloat16,
        bf16=True,
        seq_length=seq_length,
        make_vocab_size_divisible_by=128,
    )

    train_config = TrainingConfig(
        train_iters=train_iters,
        micro_batch_size=4,
        global_batch_size=8,
        eval_interval=train_iters + 1,  # Disable evaluation
        eval_iters=0,
        exit_signal_handler=True,
    )

    dataset_config = MockGPTDatasetConfig(
        random_seed=1234,
        reset_attention_mask=False,
        reset_position_ids=False,
        eod_mask_loss=False,
        seq_length=seq_length,
        num_dataset_builder_threads=1,
        data_sharding=True,
        dataloader_type="single",
        num_workers=1,
    )

    optimizer_config = OptimizerConfig(
        optimizer="adam",
        bf16=True,
        fp16=False,
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_eps=1e-5,
        use_distributed_optimizer=True,
        clip_grad=1.0,
        lr=1e-4,
        weight_decay=0.01,
        min_lr=1e-6,
    )

    scheduler_config = SchedulerConfig(
        start_weight_decay=0.01,
        end_weight_decay=0.01,
        weight_decay_incr_style="constant",
        lr_decay_style="cosine",
        lr_warmup_iters=2,
        lr_warmup_init=0.0,
        lr_decay_iters=train_iters,
        override_opt_param_scheduler=True,
    )

    ddp_config = DistributedDataParallelConfig(
        check_for_nan_in_grad=True,
        grad_reduce_in_fp32=True,
        overlap_grad_reduce=True,
        overlap_param_gather=True,
        average_in_collective=True,
        use_distributed_optimizer=True,
    )

    # Checkpoint configuration (required for fault tolerance)
    checkpoint_config = CheckpointConfig(
        save=checkpoint_dir,
        load=checkpoint_dir,
        save_interval=25,  # Save every 25 iterations
        ckpt_format="torch_dist",
        async_save=True,  # Async checkpoints for better performance
    )

    # Fault Tolerance Configuration
    # See: https://nvidia.github.io/nvidia-resiliency-ext/
    ft_config = FaultToleranceConfig(
        enable_ft_package=True,
        calc_ft_timeouts=calc_timeouts,  # Learn optimal timeouts from observed intervals
    )

    return ConfigContainer(
        train=train_config,
        model=model_config,
        optimizer=optimizer_config,
        scheduler=scheduler_config,
        dataset=dataset_config,
        logger=LoggerConfig(log_interval=10, tensorboard_dir=None),
        tokenizer=TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=10000),
        checkpoint=checkpoint_config,
        rng=RNGConfig(seed=1234),
        ddp=ddp_config,
        ft=ft_config,
    )


def main() -> None:
    """Run fault tolerance example with configurable parameters."""
    parser = argparse.ArgumentParser(description="Fault Tolerance Example")
    parser.add_argument("--train-iters", type=int, default=50, help="Number of training iterations")
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        default="/tmp/megatron_bridge_ft_example",
        help="Checkpoint directory (must be shared across all ranks)",
    )
    parser.add_argument(
        "--no-calc-timeouts",
        action="store_true",
        help="Disable automatic timeout calculation",
    )
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")

    # Ensure checkpoint directory exists (all ranks use the same path)
    os.makedirs(args.checkpoint_dir, exist_ok=True)

    config = create_config(
        checkpoint_dir=args.checkpoint_dir,
        train_iters=args.train_iters,
        calc_timeouts=not args.no_calc_timeouts,
    )
    pretrain(config=config, forward_step_func=forward_step)


if __name__ == "__main__":
    main()
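
For reference, the argparse flags defined above can be appended to the `ft_launcher` invocation from the docstring; a sketch with illustrative values:

```bash
# Sketch: rerun against the same checkpoint directory with a longer run and
# automatic timeout calculation disabled (flags defined in this script).
uv run ft_launcher \
  --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29500 \
  --nnodes=1 --nproc-per-node=2 \
  --ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
  --ft-param-rank_out_of_section_timeout=300 \
  examples/resiliency/fault_tolerance/basic_fault_tolerance.py \
  --train-iters 100 \
  --checkpoint-dir /tmp/megatron_bridge_ft_example \
  --no-calc-timeouts
```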
examples/resiliency/fault_tolerance/run_fault_tolerance.sh

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
#!/bin/bash
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fault Tolerance Example Launch Script
#
# This script uses ft_launcher from nvidia-resiliency-ext to run training
# with fault tolerance enabled.
#
# Usage:
#   ./examples/resiliency/fault_tolerance/run_fault_tolerance.sh
#   ./examples/resiliency/fault_tolerance/run_fault_tolerance.sh --simulate-fault
#   ./examples/resiliency/fault_tolerance/run_fault_tolerance.sh --nproc 4
#
# Fault Simulation Mode (--simulate-fault):
#   Demonstrates fault recovery by killing a rank after a delay.
#
#   Timing requirements for successful recovery:
#     checkpoint_time < fault_delay < total_training_time
#
#   Defaults are tuned for the tiny model (~145M params):
#     - train_iters=2000 (~90s total training)
#     - save_interval=200 (first checkpoint at ~9s)
#     - fault_delay=60s (fault triggers after checkpoint, before completion)

set -euo pipefail

# Default values
NPROC_PER_NODE="${NPROC_PER_NODE:-2}"
TRAIN_ITERS=""  # Will be set based on mode
MASTER_PORT="${MASTER_PORT:-29500}"
SIMULATE_FAULT=false
MAX_RESTARTS=0
FAULT_DELAY="${FAULT_DELAY:-60}"  # Seconds before fault injection (must be after first checkpoint)

# Parse arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --nproc)
            NPROC_PER_NODE="$2"
            shift 2
            ;;
        --train-iters)
            TRAIN_ITERS="$2"
            shift 2
            ;;
        --simulate-fault)
            SIMULATE_FAULT=true
            MAX_RESTARTS=3
            shift
            ;;
        --fault-delay)
            FAULT_DELAY="$2"
            shift 2
            ;;
        --max-restarts)
            MAX_RESTARTS="$2"
            shift 2
            ;;
        *)
            echo "Unknown option: $1"
            exit 1
            ;;
    esac
done

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Set default TRAIN_ITERS based on mode (if not explicitly provided)
# For --simulate-fault: use 2000 iterations (~90s) so training outlasts the fault delay (60s)
# For basic mode: use 50 iterations for a quick demonstration
if [ -z "$TRAIN_ITERS" ]; then
    if [ "$SIMULATE_FAULT" = true ]; then
        TRAIN_ITERS=2000
    else
        TRAIN_ITERS=50
    fi
fi

# Check if ft_launcher is available via uv
if ! uv run ft_launcher --help &> /dev/null; then
    echo "Error: ft_launcher not found."
    echo "Please use the NeMo Framework Docker container: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo"
    exit 1
fi

# Reduce torch distributed noise
export TORCH_CPP_LOG_LEVEL="${TORCH_CPP_LOG_LEVEL:-error}"

echo "Running Fault Tolerance Example"
echo "  GPUs: ${NPROC_PER_NODE}"
echo "  Iterations: ${TRAIN_ITERS}"
echo "  Simulate Fault: ${SIMULATE_FAULT}"
echo "  Max Restarts: ${MAX_RESTARTS}"
if [ "$SIMULATE_FAULT" = true ]; then
    echo "  Fault Delay: ${FAULT_DELAY}s"
fi
echo ""

if [ "$SIMULATE_FAULT" = true ]; then
    SCRIPT="${SCRIPT_DIR}/simulate_fault.py"
    SCRIPT_ARGS="--train-iters ${TRAIN_ITERS} --fault-delay ${FAULT_DELAY}"
else
    SCRIPT="${SCRIPT_DIR}/basic_fault_tolerance.py"
    SCRIPT_ARGS="--train-iters ${TRAIN_ITERS}"
fi

uv run ft_launcher \
    --rdzv_backend=c10d \
    --rdzv_endpoint="127.0.0.1:${MASTER_PORT}" \
    --nnodes=1 \
    --nproc-per-node="${NPROC_PER_NODE}" \
    --ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
    --ft-param-rank_out_of_section_timeout=300 \
    --monitor-interval=5 \
    --max-restarts="${MAX_RESTARTS}" \
    "${SCRIPT}" ${SCRIPT_ARGS}
