
Commit fa3ff20

refactor: Update image size and conversation logic for generating text from images
1 parent c183c08 commit fa3ff20

File tree

3 files changed: 79 additions & 3 deletions

scripts/train/README.md

Lines changed: 6 additions & 2 deletions
@@ -2,5 +2,9 @@
 
 We first release the basic training scripts for LLaVA-NeXT. It is based on the previous LLaVA training scripts, and researchers familiar with LLaVA will find it easy to use.
 
-We will later release the more detailed training scripts for our LLaVA OneVision models including the mid stage, single-image final stage and one-vision final stage.
-> They are basically the same as the basic training scripts, but with some modifications, such as the data yaml.
+We will gradually release the more detailed training scripts for our LLaVA OneVision models, including the mid stage, the single-image final stage, and the one-vision final stage.
+> They are basically the same as the basic training scripts, but with some modifications, such as the data yaml.
+
+- `finetune_clip.sh`: This can be seen as the training script for the first image version of LLaVA-NeXT (2024-01), with the `anyres` strategy and a maximum of 2x2 image grids.
+- `finetune_siglip.sh`: The same, but with the `siglip` encoder; each grid becomes 729 tokens.
+- `finetune_onevision.sh`: This is our latest training script, with the `anyres_max_9` strategy and image grids ranging from 1x1 to 6x6, up to a resolution of 2304x2304. Inside the script, we also incorporate multi-image and video data into the training loop; the detailed token strategy can be found in our paper. (A short sketch of this grid arithmetic follows this diff.)
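
As a quick sanity check of the numbers above, assuming the `google/siglip-so400m-patch14-384` encoder used in these scripts (384x384 input, 14-pixel patches), the per-grid token count and the maximum 6x6 resolution work out as follows; this is illustrative arithmetic only, not part of the commit:

# Illustrative only: grid token count and max resolution for the settings described above.
side_tokens=$(( 384 / 14 ))                                    # 27 patches per side (floor division)
echo "tokens per grid cell: $(( side_tokens * side_tokens ))"  # 27 * 27 = 729
echo "max side at 6x6 grids: $(( 6 * 384 ))"                   # 6 * 384 = 2304 -> 2304x2304
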
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
export OMP_NUM_THREADS=8
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
export NCCL_SOCKET_IFNAME=eth0
export NCCL_DEBUG=INFO

LLM_VERSION="Qwen/Qwen2-7B-Instruct"
# for 7b model we recommend bs=1, accum=2, 16 nodes, 128 gpus, lr=1e-5, warmup=0.03
# for 72b model we recommend bs=1, accum=1, 32 nodes, 256 gpus, lr=1e-5, warmup=0.03
LLM_VERSION_CLEAN="${LLM_VERSION//\//_}"
VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"

############### Pretrain ################

PROMPT_VERSION=plain

BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"

CKPT_PATH=$LLM_VERSION # this could also be the previous stage checkpoint

ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \
    llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path ${CKPT_PATH} \
    --version ${PROMPT_VERSION} \
    --data_path ./onevision_data.yaml \
    --image_folder ./onevision_data/images \
    --video_folder ./onevision_data/videos \
    --pretrain_mm_mlp_adapter="/checkpoints/projectors/${BASE_RUN_NAME}/mm_projector.bin" \
    --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
    --mm_vision_tower_lr=2e-6 \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --group_by_modality_length True \
    --image_aspect_ratio anyres_max_9 \
    --image_grid_pinpoints "(1x1),...,(6x6)" \
    --mm_patch_merge_type spatial_unpad \
    --bf16 True \
    --run_name $MID_RUN_NAME \
    --output_dir "/checkpoints/${MID_RUN_NAME}" \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 2 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 32768 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb \
    --torch_compile True \
    --torch_compile_backend "inductor" \
    --dataloader_drop_last True \
    --frames_upbound 32

# You can delete the sdpa attn_implementation if you want to use flash attn
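
Note that the launcher variables `NUM_GPUS`, `NNODES`, `RANK`, `ADDR`, `PORT` and the run name `MID_RUN_NAME` are referenced by this script but not defined in it; they are expected to come from the environment. A minimal sketch of what that might look like for the recommended 7B setting (bs=1, accum=2, 16 nodes, 128 GPUs), with purely illustrative values:

# Illustrative values only; none of these assignments are part of the script above.
export NUM_GPUS=8                 # GPUs per node
export NNODES=16                  # 16 nodes x 8 GPUs = 128 GPUs (7B recommendation)
export RANK=0                     # node rank, normally supplied per node by the scheduler
export ADDR="master-node"         # rendezvous host (assumed)
export PORT=29500                 # any free port
export MID_RUN_NAME="llava-onevision-qwen2-7b-mid_stage"   # assumed run name; choose your own
# Effective global batch = per_device_train_batch_size (1) x gradient_accumulation_steps (2) x 128 GPUs = 256
bash scripts/train/finetune_onevision.sh   # assumed file name for this new script, per the README above
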
Lines changed: 3 additions & 1 deletion
@@ -16,10 +16,12 @@ PROMPT_VERSION=plain
 BASE_RUN_NAME="llavanext-${VISION_MODEL_VERSION_CLEAN}-${LLM_VERSION_CLEAN}-mlp2x_gelu-pretrain_blip558k_plain"
 echo "BASE_RUN_NAME: ${BASE_RUN_NAME}"
 
+CKPT_PATH=$LLM_VERSION # this could also be the previous stage checkpoint
+
 ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \
     llava/train/train_mem.py \
     --deepspeed scripts/zero3.json \
-    --model_name_or_path ${LLM_VERSION} \
+    --model_name_or_path ${CKPT_PATH} \
     --version ${PROMPT_VERSION} \
     --data_path=llava_1_6.json \
     --image_folder your_image_folder \
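
This change routes `--model_name_or_path` through `CKPT_PATH`, so the same script can start either from the base LLM or from an earlier stage's output. A minimal sketch of the latter, with a purely hypothetical checkpoint path:

# Default behaviour, unchanged: start from the base LLM.
CKPT_PATH=$LLM_VERSION
# Or resume from a previous stage's checkpoint directory (this path is illustrative only):
CKPT_PATH="/checkpoints/llavanext-mid-stage/checkpoint-1000"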
