Skip to content

Commit e1d7813

Browse files
👁️ Added SFT support for SmolVLM models via standalone script sft_vlm_smol_vlm.py (huggingface#2409)
* Added SFT VLM script for SmolVLM * Run make precommit * Updated command example
1 parent a34e9bf commit e1d7813

File tree

1 file changed

+145
-0
lines changed

1 file changed

+145
-0
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
pip install pillow
16+
17+
# Tested on 8x H100 GPUs
18+
accelerate launch
19+
--config_file=examples/accelerate_configs/deepspeed_zero3.yaml \
20+
sft_vlm_smol_vlm.py \
21+
--dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
22+
--model_name_or_path HuggingFaceTB/SmolVLM-Instruct \
23+
--per_device_train_batch_size 1 \
24+
--gradient_accumulation_steps 1 \
25+
--output_dir sft-smol-vlm-hf \
26+
--bf16 \
27+
--torch_dtype bfloat16 \
28+
--gradient_checkpointing \
29+
--use_peft \
30+
--lora_target_modules down_proj, o_proj, k_proj, q_proj, gate_proj, up_proj, v_proj
31+
32+
For LLaVA-NeXT, use: (requires transformers>=4.45)
33+
--model_name_or_path llava-hf/llava-v1.6-mistral-7b-hf
34+
35+
For meta-llama/Llama-3.2-11B-Vision-Instruct, use: (requires transformers>=4.45.1)
36+
--model_name_or_path meta-llama/Llama-3.2-11B-Vision-Instruct
37+
"""
38+
39+
import torch
40+
from datasets import load_dataset
41+
from transformers import (
42+
AutoModelForVision2Seq,
43+
AutoProcessor,
44+
Idefics3ForConditionalGeneration,
45+
LlavaForConditionalGeneration,
46+
)
47+
48+
from trl import (
49+
ModelConfig,
50+
ScriptArguments,
51+
SFTConfig,
52+
SFTTrainer,
53+
TrlParser,
54+
get_kbit_device_map,
55+
get_peft_config,
56+
get_quantization_config,
57+
)
58+
59+
60+
if __name__ == "__main__":
61+
parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
62+
script_args, training_args, model_config = parser.parse_args_and_config()
63+
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
64+
training_args.remove_unused_columns = False
65+
training_args.dataset_kwargs = {"skip_prepare_dataset": True}
66+
67+
################
68+
# Model, Tokenizer & Processor
69+
################
70+
torch_dtype = (
71+
model_config.torch_dtype
72+
if model_config.torch_dtype in ["auto", None]
73+
else getattr(torch, model_config.torch_dtype)
74+
)
75+
quantization_config = get_quantization_config(model_config)
76+
model_kwargs = dict(
77+
revision=model_config.model_revision,
78+
attn_implementation=model_config.attn_implementation,
79+
torch_dtype=torch_dtype,
80+
device_map=get_kbit_device_map() if quantization_config is not None else None,
81+
quantization_config=quantization_config,
82+
)
83+
processor = AutoProcessor.from_pretrained(
84+
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code
85+
)
86+
87+
model = AutoModelForVision2Seq.from_pretrained(
88+
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
89+
)
90+
91+
################
92+
# Create a data collator to encode text and image pairs
93+
################
94+
def collate_fn(examples):
95+
# Get the texts and images, and apply the chat template
96+
texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
97+
images = [example["images"] for example in examples]
98+
if isinstance(model, LlavaForConditionalGeneration):
99+
# LLava1.5 does not support multiple images
100+
images = [image[0] for image in images]
101+
102+
# Tokenize the texts and process the images
103+
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
104+
105+
# The labels are the input_ids, and we mask the padding tokens in the loss computation
106+
labels = batch["input_ids"].clone()
107+
labels[labels == processor.tokenizer.pad_token_id] = -100 #
108+
# Ignore the image token index in the loss computation (model specific)
109+
if isinstance(model, Idefics3ForConditionalGeneration):
110+
image_token_id = processor.tokenizer.additional_special_tokens_ids[
111+
processor.tokenizer.additional_special_tokens.index("<image>")
112+
]
113+
else:
114+
image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
115+
labels[labels == image_token_id] = -100
116+
batch["labels"] = labels
117+
118+
return batch
119+
120+
################
121+
# Dataset
122+
################
123+
dataset = load_dataset(script_args.dataset_name)
124+
125+
################
126+
# Training
127+
################
128+
trainer = SFTTrainer(
129+
model=model,
130+
args=training_args,
131+
data_collator=collate_fn,
132+
train_dataset=dataset[script_args.dataset_train_split],
133+
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
134+
processing_class=processor.tokenizer,
135+
peft_config=get_peft_config(model_config),
136+
)
137+
138+
trainer.train()
139+
140+
# Save and push to hub
141+
trainer.save_model(training_args.output_dir)
142+
if training_args.push_to_hub:
143+
trainer.push_to_hub(dataset_name=script_args.dataset_name)
144+
if trainer.accelerator.is_main_process:
145+
processor.push_to_hub(training_args.hub_model_id)

0 commit comments

Comments
 (0)