Correctness Issue with Torch FSDP2 #2769

@haok1402

Description

Hi, I'm using the core_v0.15.0 release of Megatron-LM and running a mid-training job from the Llama3-8B checkpoint.

With Megatron's --use-distributed-optimizer, the loss starts off at 2.404998E+00, which is in the expected range for a pretrained model.

 [2025-12-28 00:20:24] iteration        1/     100 | consumed samples:         1024 | elapsed time per iteration (ms): 27003.9 | throughput per GPU (TFLOP/s/GPU): 452.7 | learning rate: 9.900000E-05 | global batch size:  1024 | lm loss: 2.404998E+00 | loss scale: 1.0 | grad norm: 3.048 | number of skipped iterations:   0 | number of nan iterations:   0 |
 Number of parameters in transformer block in billions:  6.98
 [2025-12-28 00:20:49] iteration        2/     100 | consumed samples:         2048 | elapsed time per iteration (ms): 24621.6 | throughput per GPU (TFLOP/s/GPU): 496.6 | learning rate: 9.800000E-05 | global batch size:  1024 | lm loss: 4.466884E+00 | loss scale: 1.0 | grad norm: 44.051 | number of skipped iterations:   0 | number of nan iterations:   0 |

However, with the exact same configuration except --use-torch-fsdp2 in place of --use-distributed-optimizer, the loss starts at 1.386388E+01, which is not in the expected range for a pretrained model...

 [2025-12-28 00:23:52] iteration        1/     100 | consumed samples:         1024 | elapsed time per iteration (ms): 37835.9 | throughput per GPU (TFLOP/s/GPU): 323.1 | learning rate: 9.900000E-05 | global batch size:  1024 | lm loss: 1.386388E+01 | loss scale: 1.0 | grad norm: 346.039 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2025-12-28 00:24:26] iteration        2/     100 | consumed samples:         2048 | elapsed time per iteration (ms): 34441.6 | throughput per GPU (TFLOP/s/GPU): 355.0 | learning rate: 9.800000E-05 | global batch size:  1024 | lm loss: 3.159755E+01 | loss scale: 1.0 | grad norm: 544.112 | number of skipped iterations:   0 | number of nan iterations:   0 |
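To help narrow this down, one thing I'd like to rule out is whether the pretrained checkpoint is actually being applied to the FSDP2-sharded parameters. Below is a minimal diagnostic sketch, not part of Megatron's API: dump per-parameter norms right after model setup under both flags and diff the output (where exactly to call this inside pretrain_gpt.py is my assumption).

# Sketch only: dump per-parameter norms so the --use-distributed-optimizer
# and --use-torch-fsdp2 runs can be compared. The call site inside
# pretrain_gpt.py is an assumption, not an existing hook.
import torch
import torch.distributed as dist

def dump_param_norms(model: torch.nn.Module, tag: str) -> None:
    for name, param in model.named_parameters():
        # Under FSDP2 the parameters are DTensors; full_tensor() gathers the
        # shards and is a collective, so every rank must call it.
        tensor = param.full_tensor() if hasattr(param, "full_tensor") else param
        if dist.get_rank() == 0:
            print(f"[{tag}] {name}: {tensor.float().norm().item():.6f}")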

The full launch script and command-line arguments are attached below for reference.

#!/bin/bash

declare -A MODEL_CONFIG
MODEL_CONFIG[tokenizer-type]=HuggingFaceTokenizer
MODEL_CONFIG[tokenizer-model]=meta-llama/Meta-Llama-3-8B
MODEL_CONFIG[untie-embeddings-and-output-weights]=true
MODEL_CONFIG[position-embedding-type]=rope
MODEL_CONFIG[rotary-base]=500000
MODEL_CONFIG[max-position-embeddings]=8192
MODEL_CONFIG[num-layers]=32
MODEL_CONFIG[hidden-size]=4096
MODEL_CONFIG[ffn-hidden-size]=14336
MODEL_CONFIG[hidden-dropout]=0.0
MODEL_CONFIG[disable-bias-linear]=true
MODEL_CONFIG[swiglu]=true
MODEL_CONFIG[group-query-attention]=true
MODEL_CONFIG[num-attention-heads]=32
MODEL_CONFIG[num-query-groups]=8
MODEL_CONFIG[kv-channels]=128
MODEL_CONFIG[attention-dropout]=0.0
MODEL_CONFIG[init-method-std]=0.02
MODEL_CONFIG[normalization]=RMSNorm
MODEL_CONFIG[norm-epsilon]=1e-5

declare -A TRAIN_CONFIG
TRAIN_CONFIG[lr]=1e-4
TRAIN_CONFIG[optimizer]=adam
TRAIN_CONFIG[adam-beta1]=0.9
TRAIN_CONFIG[adam-beta2]=0.95
TRAIN_CONFIG[weight-decay]=0.1
TRAIN_CONFIG[clip-grad]=1.0
TRAIN_CONFIG[train-iters]=100
TRAIN_CONFIG[micro-batch-size]=1
TRAIN_CONFIG[global-batch-size]=1024
TRAIN_CONFIG[seq-length]=2048
TRAIN_CONFIG[bf16]=true

# TRAIN_CONFIG[use-distributed-optimizer]=true
TRAIN_CONFIG[use-torch-fsdp2]=true
TRAIN_CONFIG[no-gradient-accumulation-fusion]=true

TRAIN_CONFIG[log-interval]=1
TRAIN_CONFIG[no-one-logger]=true
TRAIN_CONFIG[log-throughput]=true
TRAIN_CONFIG[pretrained-checkpoint]=checkpoints/llama3-8b/megatron/
TRAIN_CONFIG[load]=checkpoints/debug/megatron/
TRAIN_CONFIG[save]=checkpoints/debug/megatron/
TRAIN_CONFIG[per-split-data-args-path]=checkpoints/debug/per-split-data-args.json
TRAIN_CONFIG[save-interval]=125
TRAIN_CONFIG[eval-iters]=0

MAIN_ARGS=()
for key in "${!MODEL_CONFIG[@]}"; do val=${MODEL_CONFIG[$key]}; [[ $val == true ]] && MAIN_ARGS+=(--$key) || MAIN_ARGS+=(--$key "$val"); done
for key in "${!TRAIN_CONFIG[@]}"; do val=${TRAIN_CONFIG[$key]}; [[ $val == true ]] && MAIN_ARGS+=(--$key) || MAIN_ARGS+=(--$key "$val"); done

TRUN_ARGS=()
TRUN_ARGS+=(--nnodes=1 --nproc-per-node=8)
TRUN_ARGS+=(--rdzv-backend=c10d --rdzv-endpoint=localhost:15213)

SCRIPT=Megatron-LM/pretrain_gpt.py
STDOUT=checkpoints/debug/midtrain.log

mkdir -p $(dirname $STDOUT)
export WANDB_RUN_ID=$(echo ${TRAIN_CONFIG[WANDB_EXP_NAME]} | md5sum | cut -c1-5)
torchrun "${TRUN_ARGS[@]}" "$SCRIPT" "${MAIN_ARGS[@]}" 2>&1 | tee "$STDOUT"
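For reference, below is a sketch of how I understand a full (unsharded) checkpoint is normally applied to an FSDP2-sharded module through the public torch.distributed.checkpoint.state_dict API. This is not Megatron's actual loading path, and the model/state-dict handling is illustrative only; I'm including it just as the behavior I'm comparing against.

# Sketch only, not Megatron's code path: shard a module with FSDP2's
# fully_shard, then load a full (unsharded) state dict into it via the
# public torch.distributed.checkpoint.state_dict API. Assumes a recent
# PyTorch where fully_shard is exported from torch.distributed.fsdp
# (older releases: torch.distributed._composable.fsdp).
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict

def shard_and_load(model: nn.Module, full_state_dict: dict) -> nn.Module:
    # Shard each child first, then the root; parameters become DTensors.
    for child in model.children():
        fully_shard(child)
    fully_shard(model)

    # full_state_dict=True tells the API the incoming tensors are unsharded;
    # it re-distributes them to match the DTensor placements of the live
    # parameters. Every rank passes the full dict in this simple variant.
    set_model_state_dict(
        model,
        full_state_dict,
        options=StateDictOptions(full_state_dict=True),
    )
    return model

In this simple variant every rank holds the full state dict in memory; StateDictOptions also has a broadcast_from_rank0 flag to avoid that, but the point here is only the ordering: shard first, then let set_model_state_dict re-shard the full tensors.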
