#!/usr/bin/env bash
#
# Launch supervised fine-tuning of a Qwen3-VL model on a geo3k-style dataset
# through a Ray job that runs train_async.py.
#
# Environment variables read by this script:
#   SLIME_SCRIPT_TRAIN_BACKEND  "megatron" (default) or "fsdp"
#   SLIME_SCRIPT_MODEL_NAME     one of the names in VALID_MODELS
#   SLIME_SCRIPT_DATASET_NAME   Hugging Face dataset repo id
#   SLIME_SCRIPT_NUM_GPUS       GPUs on this node (default 8)
#   SLIME_SCRIPT_EXTERNAL_RAY   non-empty and not "0": do not stop/start Ray here
#   WANDB_API_KEY               when set, Weights & Biases logging is enabled
#   MASTER_ADDR                 Ray head node IP (default 127.0.0.1)

TRAIN_BACKEND=${SLIME_SCRIPT_TRAIN_BACKEND:-"megatron"}
MODEL_NAME=${SLIME_SCRIPT_MODEL_NAME:-"Qwen3-VL-8B-Instruct"}
DATASET_NAME=${SLIME_SCRIPT_DATASET_NAME:-"chenhegu/geo3k_imgurl"}
NUM_GPUS=${SLIME_SCRIPT_NUM_GPUS:-8}
DATASET_LOCAL_NAME=$(basename "$DATASET_NAME")

# Validate MODEL_NAME
VALID_MODELS="
  Qwen3-VL-2B-Instruct
  Qwen3-VL-4B-Instruct
  Qwen3-VL-8B-Instruct
  Qwen3-VL-2B-Thinking
  Qwen3-VL-4B-Thinking
  Qwen3-VL-8B-Thinking
  Qwen3-VL-30B-A3B-Instruct
  Qwen3-VL-235B-A22B-Instruct
  Qwen3-VL-30B-A3B-Thinking
  Qwen3-VL-235B-A22B-Thinking
"
# Compare against each allowed name exactly. A word-based grep would accept
# fragments such as "Qwen3" or "8B" because "-" counts as a word boundary, and
# it would also interpret MODEL_NAME as a regular expression.
MODEL_NAME_IS_VALID=0
# shellcheck disable=SC2086 # word splitting of the whitespace-separated list is intended
for candidate in $VALID_MODELS; do
  if [ "$MODEL_NAME" = "$candidate" ]; then
    MODEL_NAME_IS_VALID=1
    break
  fi
done
if [ "$MODEL_NAME_IS_VALID" = "0" ]; then
  echo "Error: MODEL_NAME must be one of: $VALID_MODELS" >&2
  exit 1
fi

MODEL_NAME_LOWER=$(echo "$MODEL_NAME" | tr '[:upper:]' '[:lower:]')

# External Ray flag: unset, empty or "0" means this script manages Ray itself.
if [ -z "$SLIME_SCRIPT_EXTERNAL_RAY" ] || [ "$SLIME_SCRIPT_EXTERNAL_RAY" = "0" ]; then
  USE_EXTERNAL_RAY=0
else
  USE_EXTERNAL_RAY=1
fi

# Cleanup of leftover processes. This runs before `set -e` on purpose: pkill
# exits non-zero when nothing matches, which must not abort the script.
pkill -9 sglang
sleep 3
if [ "$USE_EXTERNAL_RAY" = "0" ]; then
  ray stop --force
  pkill -9 ray
fi
pkill -9 slime
sleep 3
if [ "$USE_EXTERNAL_RAY" = "0" ]; then
  pkill -9 ray
fi
pkill -9 slime
pkill -9 redis

# From here on: abort on the first failing command and trace every command.
set -ex

# NOTE(review): CPython reads PYTHONUNBUFFERED, not PYTHONBUFFERED, so this
# export probably has no effect on output buffering. Kept unchanged; confirm
# the intent before renaming.
export PYTHONBUFFERED=16

# Detect NVLink by counting "NV<n>" entries in the GPU topology matrix.
# When nvidia-smi is missing or fails, the count is 0.
NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
  HAS_NVLINK=1
else
  HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

# Download model and dataset unless the target directory already exists.
mkdir -p /root/models /root/datasets
if [ ! -d "/root/models/${MODEL_NAME}" ]; then
  hf download "Qwen/${MODEL_NAME}" --local-dir "/root/models/${MODEL_NAME}"
fi
if [ ! -d "/root/datasets/${DATASET_LOCAL_NAME}" ]; then
  hf download --repo-type dataset "${DATASET_NAME}" --local-dir "/root/datasets/${DATASET_LOCAL_NAME}"
fi

# Common args
CKPT_ARGS=(
  --hf-checkpoint "/root/models/${MODEL_NAME}"
  --load "/root/models/${MODEL_NAME}"
  --rotary-base 5000000
)

SFT_ARGS=(
  --rollout-function-path slime.rollout.sft_rollout.generate_rollout
  --prompt-data "/root/datasets/${DATASET_LOCAL_NAME}/train_formatted.parquet"
  --input-key messages
  --apply-chat-template
  --rollout-shuffle
  --num-epoch 3000
  --rollout-batch-size 128
  --global-batch-size 128

  --loss-type sft_loss
  --calculate-per-token-loss
  --disable-compute-advantages-and-returns
  --debug-train-only
)

# required for vlm datasets
MULTIMODAL_KEYS='{"image": "images"}'

OPTIMIZER_ARGS=(
  --optimizer adam
  --lr 1e-5
  --lr-decay-style cosine
  --min-lr 1e-6
  --lr-warmup-fraction 0.1
  --weight-decay 0.1
  --adam-beta1 0.9
  --adam-beta2 0.95
)

# Weights & Biases logging is enabled only when an API key is available.
# NOTE(review): the key is passed on the command line, so it shows up in the
# `set -x` trace of the final `ray job submit` command.
if [ -n "$WANDB_API_KEY" ]; then
  WANDB_ARGS=(
    --use-wandb
    --wandb-project slime-geo3k-vlm-sft
    --wandb-group "${MODEL_NAME_LOWER}-${TRAIN_BACKEND}"
    --wandb-key "${WANDB_API_KEY}"
    --disable-wandb-random-suffix
  )
else
  WANDB_ARGS=()
fi

# Backend-specific args
if [ "$TRAIN_BACKEND" = "fsdp" ]; then
  BACKEND_ARGS=(
    --train-backend fsdp
    --gradient-checkpointing
    --attn-implementation flash_attention_3
    --update-weight-buffer-size 536870912
  )
else
  # megatron backend (default; also used for any value other than "fsdp")
  # NOTE(review): tensor parallel size is fixed at 4 regardless of NUM_GPUS;
  # confirm this is intended when running with fewer GPUs.
  BACKEND_ARGS=(
    --train-backend megatron
    --tensor-model-parallel-size 4
    --sequence-parallel
    --pipeline-model-parallel-size 1
    --context-parallel-size 1
    --expert-model-parallel-size 1
    --expert-tensor-parallel-size 1
    --recompute-granularity full
    --recompute-method uniform
    --recompute-num-layers 1
    --use-dynamic-batch-size
    --max-tokens-per-gpu 4096
    --attention-dropout 0.0
    --hidden-dropout 0.0
    --accumulate-allreduce-grads-in-fp32
    --attention-softmax-in-fp32
    --attention-backend flash
    --megatron-to-hf-mode bridge
  )

  # get MODEL_ARGS from scripts/models for megatron backend
  # The repository root is taken to be two directories above this script.
  SLIME_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." &>/dev/null && pwd)"
  # Map the model name to its args file, e.g.
  #   Qwen3-VL-8B-Instruct -> qwen3-8B
  #   Qwen3-VL-2B-Thinking -> qwen3-1.7B
  MODEL_ARGS_FILE=$(echo "$MODEL_NAME" | sed 's/-Instruct//g; s/-Thinking//g; s/Qwen3-VL-/qwen3-/g; s/-2B/-1.7B/g')
  # shellcheck disable=SC1090 # path is computed at runtime
  source "${SLIME_DIR}/scripts/models/${MODEL_ARGS_FILE}.sh"
fi

# Start Ray if not using external Ray
if [ "$USE_EXTERNAL_RAY" = "0" ]; then
  export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
  export no_proxy="127.0.0.1,${MASTER_ADDR}"
  ray start --head --node-ip-address "${MASTER_ADDR}" --num-gpus "${NUM_GPUS}" --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265
fi

# Build runtime env passed to the Ray job.
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
  }
}"

# Submit the training job. The address is fixed to the local dashboard port.
# Array expansions are quoted so elements containing spaces or glob characters
# reach train_async.py unchanged. MODEL_ARGS is only populated by the sourced
# model file on the megatron path. EVAL_ARGS is never defined in this script;
# it is deliberately left unquoted so that, as before, it contributes no
# argument when unset and a caller-provided string still splits into words.
# shellcheck disable=SC2068
ray job submit --address="http://127.0.0.1:8265" \
  --runtime-env-json="${RUNTIME_ENV_JSON}" \
  -- python3 train_async.py \
  --actor-num-nodes 1 \
  --actor-num-gpus-per-node "${NUM_GPUS}" \
  --multimodal-keys "${MULTIMODAL_KEYS}" \
  "${MODEL_ARGS[@]}" \
  "${CKPT_ARGS[@]}" \
  "${SFT_ARGS[@]}" \
  ${EVAL_ARGS[@]} \
  "${OPTIMIZER_ARGS[@]}" \
  "${WANDB_ARGS[@]}" \
  "${BACKEND_ARGS[@]}"
0 commit comments