Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions paddlemix/examples/qwen2_5_vl/merge_tensor_parallel_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from collections import OrderedDict

import paddle

from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
)


def parse_arguments(argv=None):
    """Parse the command-line options of the tensor-parallel merge script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) keeps the original
            behaviour of reading ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: carries ``model_name_or_path``,
        ``merge_model_path``, ``device``, ``dtype`` and
        ``tensor_parallel_degree``.
    """
    parser = argparse.ArgumentParser()
    # A required option never falls back to a default, so none is declared.
    parser.add_argument("--model_name_or_path", required=True, help="The directory of pretrained model.")
    parser.add_argument("--merge_model_path", default=None, help="The directory of merged parameters. Default to None")
    parser.add_argument("--device", type=str, default="gpu", help="Device")
    parser.add_argument("--dtype", type=str, default="bfloat16", help="dtype")
    parser.add_argument("--tensor_parallel_degree", type=int, default=2, help="tp_degree")
    return parser.parse_args(argv)


def merge():
    """Merge per-rank tensor-parallel checkpoints into one ``model_state.pdparams``.

    Reads ``model_state.tpXX.pdparams`` for every rank in
    ``range(--tensor_parallel_degree)`` from ``--model_name_or_path``, merges
    each parameter that has an entry in the model's tensor-parallel mapping,
    copies all other parameters from rank 0 unchanged, and saves the result.

    The merged file is written to ``--merge_model_path`` when it is given and
    to ``--model_name_or_path`` otherwise. Only the weights are written; this
    function does not copy any other file into the output directory.

    Raises:
        ValueError: if ``--tensor_parallel_degree`` is smaller than 1.
    """
    args = parse_arguments()
    if args.tensor_parallel_degree < 1:
        raise ValueError(f"tensor_parallel_degree must be >= 1, got {args.tensor_parallel_degree}")

    paddle.set_device(args.device)

    # The mapping is requested in merge mode (is_split=False) with the config's
    # tensor_parallel_degree set to 1.
    config = Qwen2_5_VLConfig.from_pretrained(args.model_name_or_path)
    config.tensor_parallel_degree = 1
    merge_mapping = Qwen2_5_VLForConditionalGeneration._get_tensor_parallel_mappings(config, is_split=False)

    # One state dict per tensor-parallel rank, indexed by rank.
    state_dicts = [
        paddle.load(os.path.join(args.model_name_or_path, f"model_state.tp{rank:0>2d}.pdparams"))
        for rank in range(args.tensor_parallel_degree)
    ]

    merged_state_dict = OrderedDict()
    for k, v in state_dicts[0].items():
        # The mapping is looked up with every "model." substring removed from
        # the checkpoint key.
        map_k = k.replace("model.", "")
        if map_k in merge_mapping:
            shards = [rank_state[k] for rank_state in state_dicts]
            new_v = merge_mapping[map_k](shards)
            print(f"key: {k}, merged weight shape: {new_v.shape}")
        else:
            # No merge rule for this key: keep the rank-0 value.
            new_v = v
        merged_state_dict[k] = new_v

    # Honor --merge_model_path; fall back to the source directory when unset.
    save_dir = args.merge_model_path or args.model_name_or_path
    os.makedirs(save_dir, exist_ok=True)
    complete_save_file = os.path.join(save_dir, "model_state.pdparams")
    paddle.save(merged_state_dict, complete_save_file)


# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    merge()
114 changes: 98 additions & 16 deletions paddlemix/examples/qwen2_5_vl/qwen2_5_vl_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,33 @@
import sys
import traceback
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence, Any
from typing import Any, Dict, Optional, Sequence

import numpy as np
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
from paddle.io import Dataset
from paddlenlp.data import DataCollatorForSeq2Seq
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.trainer import PdArgumentParser, TrainingArguments, set_seed
from paddlenlp.trainer.trainer import Trainer
from paddlenlp.trainer.trainer_utils import get_last_checkpoint
from paddlenlp.transformers.processing_utils import ProcessorMixin
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError

from paddlemix.datasets.internvl_dataset import ConcatDataset, WeightedConcatDataset
from paddlemix.models.qwen2_5_vl import MIXQwen2_5_Tokenizer
from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLForConditionalGeneration,
)
from paddlemix.models.qwen2_5_vl.supervised import _encode_supervised_example
from paddlemix.models.qwen2_5_vl.template import TEMPLATES
from paddlemix.processors.qwen2_5_vl_processing import Qwen2_5_VLImageProcessor, Qwen2_5_VLProcessor
from paddlenlp.transformers.processing_utils import ProcessorMixin
from paddlemix.processors.qwen2_5_vl_processing import (
Qwen2_5_VLImageProcessor,
Qwen2_5_VLProcessor,
)

Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
Expand All @@ -55,6 +62,71 @@
IMAGE_PLACEHOLDER = "<image>"


def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank=0):
    """Seed the Python, NumPy and Paddle RNGs and register the tracker seeds.

    Args:
        basic_seed: Base seed shared by all processes.
        data_world_rank: Rank of this process among the data-consuming ranks;
            added to ``basic_seed`` for the three global RNGs.
        mp_rank: Model-parallel rank; only enters the ``local_seed`` value.
        pp_rank: Accepted but not used by this function.

    Raises:
        AssertionError: if the current Paddle device string does not contain "gpu".
    """
    assert "gpu" in paddle.device.get_device()

    data_seed = basic_seed + data_world_rank
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(data_seed)

    # local_seed / global_seed are used to control dropout in ModelParallel:
    # global_seed depends only on the data rank, local_seed also on mp_rank.
    rng_tracker = get_rng_state_tracker()
    rng_tracker.add("global_seed", 2048 + basic_seed + data_world_rank)
    rng_tracker.add("local_seed", 1024 + basic_seed + mp_rank * 100 + data_world_rank)


def setdistenv(args):
    """Set up the hybrid-parallel environment on ``args`` and seed the RNGs.

    With more than one process, the parallel degrees on ``args`` are clamped
    to at least 1, ``args.dp_degree`` is derived from the world size,
    ``fleet`` is initialised with the resulting hybrid configuration, and the
    rank fields (``rank``, ``mp_rank``, ``dp_rank``, ``sharding_rank``,
    ``data_world_rank``, ``data_world_size``) are written onto ``args``.
    With a single process the same fields are filled with their trivial
    values and ``fleet.init`` is not called.

    In both cases ``set_hyrbid_parallel_seed`` is called last with
    ``args.seed``.

    Args:
        args: Training-arguments object; it is mutated in place.

    Raises:
        AssertionError: if the world size is not divisible by
            tensor * sharding * pipeline parallel degrees.
    """
    world_size = dist.get_world_size()
    if world_size > 1:
        # Non-positive degrees are treated as degree 1.
        args.sharding_parallel_degree = max(args.sharding_parallel_degree, 1)
        args.tensor_parallel_degree = max(args.tensor_parallel_degree, 1)
        # NOTE(review): sep_parallel_degree is clamped here but is not passed
        # to hybrid_configs below - confirm that this is intended.
        args.sep_parallel_degree = max(args.sep_parallel_degree, 1)
        args.pipeline_parallel_degree = max(args.pipeline_parallel_degree, 1)

        model_parallel_size = args.tensor_parallel_degree * args.pipeline_parallel_degree
        non_dp_size = model_parallel_size * args.sharding_parallel_degree

        # dp_degree is world_size // non_dp_size, so the check has to include
        # the sharding degree; otherwise the division could silently floor.
        assert world_size % non_dp_size == 0, (
            f"Total world_size:{world_size} should be divided by "
            f"tensor_parallel_degree: {args.tensor_parallel_degree}, "
            f"sharding_parallel_degree: {args.sharding_parallel_degree} and "
            f"pipeline_parallel_degree: {args.pipeline_parallel_degree}."
        )

        args.dp_degree = world_size // non_dp_size

        strategy = fleet.DistributedStrategy()
        strategy.hybrid_configs = {
            "dp_degree": args.dp_degree,
            "mp_degree": args.tensor_parallel_degree,
            "sharding_degree": args.sharding_parallel_degree,
            "pp_degree": args.pipeline_parallel_degree,
        }

        # seed control in tensor parallel
        strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed}

        fleet.init(is_collective=True, strategy=strategy)

        args.rank = dist.get_rank()
        # obtain rank information of the hybrid-parallel topology
        hcg = fleet.get_hybrid_communicate_group()
        args.mp_rank = hcg.get_model_parallel_rank()
        args.dp_rank = hcg.get_data_parallel_rank()
        args.sharding_rank = hcg.get_sharding_parallel_rank()

        # Data and sharding ranks together enumerate the data-consuming ranks.
        args.data_world_rank = args.dp_rank * args.sharding_parallel_degree + args.sharding_rank
        args.data_world_size = world_size // model_parallel_size
    else:
        # Single process: fill in the trivial values so that every rank field
        # set by the multi-process branch also exists here.
        # NOTE(review): fleet.init is not called in this branch; code that
        # queries fleet.get_hybrid_communicate_group() afterwards may fail in
        # a single-process run - confirm.
        args.data_world_rank = 0
        args.data_world_size = 1
        args.mp_rank = 0
        args.dp_rank = 0
        args.sharding_rank = 0
        args.rank = 0

    # seed control in hybrid parallel
    set_hyrbid_parallel_seed(args.seed, args.data_world_rank, args.mp_rank)


@dataclass
class ProcessorArguments:
r"""
Expand Down Expand Up @@ -352,7 +424,7 @@ def pure_text_get_item(self, data_item):
attention_mask=attention_mask,
images=[],
)

return ret

def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
Expand Down Expand Up @@ -457,7 +529,7 @@ def __post_init__(self):

def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tensor"]:
batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []

for feature in features:
images = feature.pop("images", None) or []
videos = feature.pop("videos", None) or []
Expand All @@ -467,9 +539,7 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
batch_vidlens.append(len(videos))
batch_input_ids.append(feature["input_ids"])

if (
self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
):
if self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0:
fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
Expand All @@ -480,12 +550,16 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens

if len(fake_input_ids) != 0:
if self.tokenizer.padding_side == "right":
features[0]["input_ids"] = features[0]["input_ids"]+ fake_input_ids["input_ids"]
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids["input_ids"])
features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids["input_ids"]
features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(
fake_input_ids["input_ids"]
)
features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids["input_ids"])
else:
features[0]["input_ids"] = fake_input_ids["input_ids"] + features[0]["input_ids"]
features[0]["attention_mask"] = [0] * len(fake_input_ids["input_ids"]) + features[0]["attention_mask"]
features[0]["attention_mask"] = [0] * len(fake_input_ids["input_ids"]) + features[0][
"attention_mask"
]
features[0]["labels"] = [IGNORE_INDEX] * len(fake_input_ids["input_ids"]) + features[0]["labels"]

batch_images = fake_images
Expand Down Expand Up @@ -514,8 +588,6 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens

features["position_ids"], features["rope_deltas"] = self.model.get_rope_index(**rope_index_kwargs)



if "cross_attention_mask" in mm_inputs: # for mllama inputs when pad_to_multiple_of is enabled
cross_attention_mask = mm_inputs.pop("cross_attention_mask")
seq_len = features["input_ids"].size(1)
Expand All @@ -534,7 +606,6 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
return features



def main():
parser = PdArgumentParser((ModelArguments, DataTrainingArguments, PreTrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
Expand All @@ -547,6 +618,10 @@ def main():
training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")

setdistenv(training_args)
hcg = fleet.get_hybrid_communicate_group()
tensor_parallel_rank = hcg.get_model_parallel_rank()

# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
Expand Down Expand Up @@ -592,7 +667,14 @@ def main():
print(f"Loading Tokenizer: {tokenizer_path}")

MODEL_NAME = model_args.model_name_or_path
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(MODEL_NAME, dtype=dtype, attn_implementation="flash_attention_2")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
MODEL_NAME,
tensor_parallel_degree=training_args.tensor_parallel_degree,
tensor_parallel_rank=tensor_parallel_rank,
tensor_parallel_output=False,
dtype=dtype,
attn_implementation="flash_attention_2",
) # ,tensor_parallel_output=False
image_processor = Qwen2_5_VLImageProcessor()
tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(MODEL_NAME, padding_side="right")
processor = Qwen2_5_VLProcessor(image_processor, tokenizer)
Expand Down
15 changes: 12 additions & 3 deletions paddlemix/examples/qwen2_5_vl/single_image_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,18 @@
import paddle

from paddlemix.models.qwen2_5_vl import MIXQwen2_5_Tokenizer
from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from paddlemix.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLConfig,
Qwen2_5_VLForConditionalGeneration,
)
from paddlemix.processors.qwen2_5_vl_processing import (
Qwen2_5_VLImageProcessor,
Qwen2_5_VLProcessor,
process_vision_info,
)
from paddlemix.utils.log import logger


def main(args):
paddle.seed(seed=0)
compute_dtype = args.dtype
Expand All @@ -38,7 +42,11 @@ def main(args):
print("compute_dtype", compute_dtype)

paddle.set_default_dtype(compute_dtype)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(args.model_path, dtype=compute_dtype, attn_implementation=args.attn_implementation)
config = Qwen2_5_VLConfig.from_pretrained(args.model_path)
config.tensor_parallel_degree = 1
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
args.model_path, config=config, dtype=compute_dtype, attn_implementation=args.attn_implementation
)

image_processor = Qwen2_5_VLImageProcessor()
tokenizer = MIXQwen2_5_Tokenizer.from_pretrained(args.model_path)
Expand Down Expand Up @@ -74,6 +82,7 @@ def main(args):

if args.benchmark:
import time

start = 0.0
total = 0.0
for i in range(20):
Expand Down Expand Up @@ -123,4 +132,4 @@ def main(args):
parser.add_argument("--benchmark", action="store_true")
parser.add_argument("--attn_implementation", type=str, default="eager")
args = parser.parse_args()
main(args)
main(args)