Commit 1dd1b49

Merge pull request #241 from LLaVA-VL/ov-chat-doc
add dpo training script to Ov chat doc
2 parents 56cdba2 + 6a6a9dd

File tree

2 files changed: +85 additions, −0 deletions

docs/LLaVA_OneVision_Chat.md

Lines changed: 16 additions & 0 deletions
@@ -94,13 +94,29 @@ Using the feedback data obtained in `Step 2`, we conduct DPO training in an iter
This iterative process is repeated for `N=3` rounds in total, with each round refining the model’s ability to generate high-quality visual chat responses by progressively incorporating feedback from both human and AI assessments.

**Training script and data format**

- Example training script: [`/scripts/train/dpo_ov7b.sh`](../scripts/train/dpo_ov7b.sh)
- Format of training data:

~~~json
{
    "id": "<image-id>",
    "image": "<image path under args.image_folder>",
    "prompt": "<input prompt/question>",
    "chosen": "<chosen model response>",
    "rejected": "<rejected model response>"
}
~~~
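The schema above can be sanity-checked before launching training. The snippet below is a minimal sketch: it writes one made-up record (all field values are placeholders, not real training data) and verifies that every record carries the five keys the format requires:

```shell
# Write one hypothetical record in the documented format (values are placeholders).
cat > dpo_sample.json <<'EOF'
[
  {
    "id": "example-0001",
    "image": "images/example-0001.jpg",
    "prompt": "What is shown in this image?",
    "chosen": "A grounded, detailed answer.",
    "rejected": "A vague or hallucinated answer."
  }
]
EOF

# Check that each record has exactly the fields described above.
python3 - <<'EOF'
import json
required = {"id", "image", "prompt", "chosen", "rejected"}
for rec in json.load(open("dpo_sample.json")):
    missing = required - rec.keys()
    assert not missing, f"record {rec['id']} missing: {sorted(missing)}"
print("format ok")
EOF
```

The file passed via `--data_path` would then be a JSON list of such records.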
------

Stay tuned on how we develop AI feedback for self-improvement LMMs!

*Contributors to LLaVA-OneVision-Chat: [Tianyi Xiong](https://tyxiong23.github.io/), [Bo Li](https://brianboli.com/), [Dong Guo](https://www.linkedin.com/in/dongguoset/), [Huizhuo Yuan](https://scholar.google.com/citations?user=8foZzX4AAAAJ), [Quanquan Gu](https://web.cs.ucla.edu/~qgu/), [Chunyuan Li](https://scholar.google.com/citations?user=Zd7WmXUAAAAJ)*

### Citation

If you find it useful for your research and applications, please cite related papers/blogs using this BibTeX:

scripts/train/dpo_ov7b.sh

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
export OMP_NUM_THREADS=8
export NCCL_IB_DISABLE=0
export NCCL_IB_GID_INDEX=3
# export NCCL_IB_HCA=${ARNOLD_RDMA_DEVICE}
export NCCL_SOCKET_IFNAME=eth0
export NCCL_DEBUG=INFO

VISION_MODEL_VERSION="google/siglip-so400m-patch14-384"
VISION_MODEL_VERSION_CLEAN="${VISION_MODEL_VERSION//\//_}"

# DPO Stage
PROMPT_VERSION="qwen_1_5"
SFT_MODEL="lmms-lab/llava-onevision-qwen2-7b-ov"
EPOCH=1
beta=0.1

DPO_RUN_NAME="llava-onevision-qwen2-7b-ov_dpo-beta${beta}-epoch${EPOCH}"
DPO_CLEAN_NAME="${DPO_RUN_NAME##*/}"
OUTPUT_DIR="<your-output-folder>/${DPO_CLEAN_NAME}"
DATA_PATH="<your-data-path>"

echo $DPO_RUN_NAME

ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${RANK}" --master_addr="${ADDR}" --master_port="${PORT}" \
    llava/train/train_dpo.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path=${SFT_MODEL} \
    --dpo_alpha=1.0 \
    --beta=${beta} \
    --gamma=0 \
    --version $PROMPT_VERSION \
    --data_path=$DATA_PATH \
    --image_folder "<your-image-folder>" \
    --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
    --unfreeze_mm_vision_tower True \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --group_by_modality_length True \
    --image_aspect_ratio anyres_max_9 \
    --image_grid_pinpoints "(1x1),...,(6x6)" \
    --mm_patch_merge_type spatial_unpad \
    --bf16 True \
    --run_name $DPO_CLEAN_NAME \
    --output_dir $OUTPUT_DIR \
    --num_train_epochs $EPOCH \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 5e-7 \
    --weight_decay 0. \
    --warmup_ratio 0.1 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 32768 \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to wandb \
    --dataloader_drop_last True
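Note that the script reads `NUM_GPUS`, `NNODES`, `RANK`, `ADDR`, and `PORT` from the environment for the `torchrun` launch but does not set them itself. A single-node invocation might export values like the following first (the numbers here are illustrative examples, not values from the script):

```shell
# Hypothetical single-node settings for the launch variables the script expects;
# adjust to your cluster, then run: bash scripts/train/dpo_ov7b.sh
export NUM_GPUS=8        # GPUs on this node, passed to torchrun --nproc_per_node
export NNODES=1          # total number of nodes
export RANK=0            # rank of this node (0 on a single node)
export ADDR=127.0.0.1    # master node address
export PORT=29500        # master node port
```

With these values, the effective global batch size is `per_device_train_batch_size (1) × gradient_accumulation_steps (8) × NUM_GPUS (8) = 64`.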
