Commit 3ec0419

add multi-batch training strategy (#1012)
1 parent 6ac30ff commit 3ec0419

1 file changed: +97 −21 lines changed

paddlemix/examples/qwen2_vl/qwen2vl_finetune.py

Lines changed: 97 additions & 21 deletions
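
In short: dataset items no longer run the image transform eagerly. multi_modal_get_item and pure_text_get_item now return raw image paths, and the new MultiModalDataCollatorForSeq2Seq builds pixel_values / image_grid_thw once per batch, splicing a masked dummy image into text-only batches. A sketch of the per-item contract before and after; the field names come from the diff below, the image path is illustrative:

    # before: each item carried its own eagerly transformed image tensors
    ret = dict(
        input_ids=input_ids,
        labels=labels,
        attention_mask=attention_mask,
        pixel_values=image_data_dict["pixel_values"],
        image_grid_thw=image_data_dict["image_grid_thw"][0],
    )

    # after: items carry paths only; the collator does the batched visual preprocessing
    ret = dict(input_ids=input_ids, labels=labels, attention_mask=attention_mask, images=["path/to/img.jpg"])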
@@ -20,7 +20,7 @@
 import sys
 import traceback
 from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Dict, Optional, Sequence, Any
 
 import numpy as np
 import paddle
@@ -42,13 +42,13 @@
     Qwen2VLImageProcessor,
     Qwen2VLProcessor,
 )
+from paddlenlp.transformers.processing_utils import ProcessorMixin
 
 Image.MAX_IMAGE_PIXELS = None
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 MaximumDecompressedSize = 1024
 MegaByte = 2**20
 PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
-
 logger = logging.getLogger(__name__)
 
 
@@ -303,8 +303,6 @@ def multi_modal_get_item(self, data_item):
 
         # Merge the image path
         image_path = self.get_image_path(data_item["images"][0])  # TODO: now only single image
-        image = self.load_image(image_path)
-        image_data_dict = transform(image)
 
         messages = data_item["messages"]
 
@@ -328,19 +326,11 @@ def multi_modal_get_item(self, data_item):
             input_ids=input_ids,
             labels=labels,
             attention_mask=attention_mask,
-            pixel_values=image_data_dict["pixel_values"],
-            image_grid_thw=image_data_dict["image_grid_thw"][0],
+            images=[image_path],
         )
         return ret
 
     def pure_text_get_item(self, data_item):
-        # Build transformation function
-        transform = self.get_transform()
-
-        # Create a blank white image
-        image = Image.new("RGB", (224, 224), (255, 255, 255))
-        image_data_dict = transform(image)
-
         messages = data_item["messages"]
 
         input_ids, labels = _encode_supervised_example(
@@ -363,9 +353,9 @@ def pure_text_get_item(self, data_item):
             input_ids=input_ids,
             labels=labels,
             attention_mask=attention_mask,
-            pixel_values=image_data_dict["pixel_values"],
-            image_grid_thw=image_data_dict["image_grid_thw"][0],
+            images=[],
         )
+
         return ret
 
     def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
@@ -374,10 +364,6 @@ def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
         try:
             data_item = self.raw_data[i]
             if "images" in data_item and len(data_item["images"]) != 0:
-                # if type(data_item['images']) == list:
-                #     ret = self.multi_modal_multi_image_get_item(data_item)
-                # else:
-                #     ret = self.multi_modal_get_item(data_item)
                 ret = self.multi_modal_get_item(data_item)  # TODO: single-image only for now
             else:
                 ret = self.pure_text_get_item(data_item)  # TODO: pure text
@@ -457,6 +443,94 @@ def print_trainable_params(model: paddle.nn.Layer) -> None:
     )
 
 
+@dataclass
+class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
+    r"""
+    Data collator that supports VLMs.
+
+    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
+    """
+
+    template: Optional["TEMPLATES"] = None
+    processor: Optional["ProcessorMixin"] = None
+
+    def __post_init__(self):
+        if self.template is None:
+            raise ValueError("Template is required for MultiModalDataCollator.")
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
+        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
+
+        for feature in features:
+            images = feature.pop("images", None) or []
+            videos = feature.pop("videos", None) or []
+            batch_images.extend(images)
+            batch_videos.extend(videos)
+            batch_imglens.append(len(images))
+            batch_vidlens.append(len(videos))
+            batch_input_ids.append(feature["input_ids"])
+
+        if (
+            self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
+        ):  # text-only batch: splice a masked dummy image into the first sample for the visual tower
+            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
+            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]  # blank white placeholder
+            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
+            fake_input_ids = self.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
+            fake_input_ids, _ = self.template.mm_plugin.process_token_ids(
+                fake_input_ids, None, fake_images, [], self.tokenizer, self.processor
+            )
+
+            if self.tokenizer.padding_side == "right":
+                features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
+                features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
+                features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
+            else:
+                features[0]["input_ids"] = fake_input_ids + features[0]["input_ids"]
+                features[0]["attention_mask"] = [0] * len(fake_input_ids) + features[0]["attention_mask"]
+                features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids) + features[0]["labels"]
+
+            batch_images = fake_images
+            batch_imglens[0] = 1
+            batch_input_ids[0] = features[0]["input_ids"]
+
+        mm_inputs = self.template.mm_plugin.get_mm_inputs(  # batched visual preprocessing, once per batch
+            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
+        )
+        if "token_type_ids" in mm_inputs:
+            token_type_ids = mm_inputs.pop("token_type_ids")
+            for i, feature in enumerate(features):
+                feature["token_type_ids"] = token_type_ids[i]
+
+        features: Dict[str, "paddle.Tensor"] = super().__call__(features)
+
+        if self.model is not None and hasattr(self.model, "get_rope_index"):  # for qwen2vl mrope
+            features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(
+                input_ids=features["input_ids"],
+                image_grid_thw=mm_inputs.get("image_grid_thw", None),
+                video_grid_thw=mm_inputs.get("video_grid_thw", None),
+                attention_mask=features["attention_mask"],
+            )
+
+        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
+            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
+            seq_len = features["input_ids"].shape[1]
+            orig_len = cross_attention_mask.shape[1]  # pad dim 1 (sequence length) up to the padded seq_len
+            mm_inputs["cross_attention_mask"] = paddle.nn.functional.pad(cross_attention_mask, [0, 0, 0, seq_len - orig_len, 0, 0, 0, 0])
+
+        features.update(mm_inputs)
+        if isinstance(features.get("pixel_values"), list):  # for pixtral inputs
+            features = features.data  # use default_collate() instead of BatchEncoding.to()
+
+        if "image_bound" in features:  # for minicpmv inputs
+            bsz, seq_length = features["input_ids"].shape
+            features["position_ids"] = paddle.arange(seq_length, dtype="int64").tile([bsz, 1])
+            return {"data": features, "input_ids": features["input_ids"], "labels": features["labels"]}
+
+        return features
+
+
+
 def main():
     parser = PdArgumentParser((ModelArguments, DataTrainingArguments, PreTrainingArguments))
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
@@ -516,7 +590,7 @@ def main():
     MODEL_NAME = model_args.model_name_or_path
     model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype=dtype)
     image_processor = Qwen2VLImageProcessor.from_pretrained(MODEL_NAME)
-    tokenizer = MIXQwen2Tokenizer.from_pretrained(MODEL_NAME)
+    tokenizer = MIXQwen2Tokenizer.from_pretrained(MODEL_NAME, padding_side="right")
     processor = Qwen2VLProcessor(image_processor, tokenizer)
 
     tokenizer.tokenizer_path = tokenizer_path
@@ -578,8 +652,10 @@ def _freeze_params(module):
     # set seed for paddle dataloaders
     set_seed(training_args.seed)
 
-    data_collator = DataCollatorForSeq2Seq(
+    data_collator = MultiModalDataCollatorForSeq2Seq(
         tokenizer=tokenizer,
+        template=TEMPLATES[data_args.conv_style],
+        processor=processor,
         pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
         label_pad_token_id=IGNORE_INDEX,
     )
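
For reference, a rough usage sketch of the new collator under this contract, mirroring the construction in main() above. The module import path, the "qwen2_vl" template key, the checkpoint name, and demo.jpg are illustrative assumptions, not taken from the commit, and the snippet is not guaranteed against any particular paddlemix revision:

    # names re-exported by the finetune script; assumed importable as a module
    from paddlemix.examples.qwen2_vl.qwen2vl_finetune import (
        IGNORE_INDEX,
        TEMPLATES,
        MIXQwen2Tokenizer,
        MultiModalDataCollatorForSeq2Seq,
        Qwen2VLImageProcessor,
        Qwen2VLProcessor,
    )

    MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"  # any Qwen2-VL checkpoint
    tokenizer = MIXQwen2Tokenizer.from_pretrained(MODEL_NAME, padding_side="right")
    processor = Qwen2VLProcessor(Qwen2VLImageProcessor.from_pretrained(MODEL_NAME), tokenizer)

    data_collator = MultiModalDataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        template=TEMPLATES["qwen2_vl"],  # assumed conv_style key
        processor=processor,
        pad_to_multiple_of=8,
        label_pad_token_id=IGNORE_INDEX,
    )

    # two items in the new format: token ids plus raw image paths, no pixel_values yet
    ids_a = tokenizer("Describe the image.")["input_ids"]  # real data would carry image placeholders
    ids_b = tokenizer("Hello!")["input_ids"]
    features = [
        dict(input_ids=ids_a, attention_mask=[1] * len(ids_a), labels=list(ids_a), images=["demo.jpg"]),
        dict(input_ids=ids_b, attention_mask=[1] * len(ids_b), labels=list(ids_b), images=[]),
    ]

    batch = data_collator(features)
    # batch holds padded input_ids/labels plus batch-level pixel_values and image_grid_thw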
