Skip to content

Commit d66ea24

Browse files
committed
revert video change
1 parent c97bb24 commit d66ea24

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

examples/scripts/sft_video_llm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
from datasets import load_dataset
6363
from peft import LoraConfig
6464
from qwen_vl_utils import process_vision_info
65-
from transformers import AutoModelForImageTextToText, BitsAndBytesConfig, Qwen2VLProcessor
65+
from transformers import AutoModelForImageTextToText, AutoProcessor, BitsAndBytesConfig, Qwen2VLProcessor
6666

6767
from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_kbit_device_map
6868

@@ -224,6 +224,10 @@ class CustomScriptArguments(ScriptArguments):
224224
model.config.use_reentrant = False
225225
model.enable_input_require_grads()
226226

227+
processor = AutoProcessor.from_pretrained(
228+
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
229+
)
230+
227231
# Prepare dataset
228232
prepared_dataset = [prepare_dataset(example, script_args.video_cache_dir) for example in dataset]
229233

@@ -234,6 +238,7 @@ class CustomScriptArguments(ScriptArguments):
234238
train_dataset=prepared_dataset,
235239
data_collator=collate_fn,
236240
peft_config=peft_config,
241+
processing_class=processor,
237242
)
238243

239244
# Train model

0 commit comments

Comments
 (0)