diff --git a/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py b/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py
index 0c712a52b..96a76c437 100644
--- a/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py
+++ b/paddlemix/examples/qwen2_vl/qwen2vl_finetune.py
@@ -52,6 +52,7 @@ MegaByte = 2**20
 PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
 
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device
 
 # Set constants for image processing and logging
 IGNORE_INDEX = -100
@@ -360,7 +361,7 @@ def pure_text_get_item(self, data_item):
             attention_mask=attention_mask,
             images=[],
         )
-        
+
         return ret
 
     def __getitem__(self, i) -> Dict[str, paddle.Tensor]:
@@ -473,11 +474,11 @@ def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "paddle.Tens
             batch_videos.extend(videos)
             batch_imglens.append(len(images))
             batch_vidlens.append(len(videos))
-            batch_input_ids.append(feature["input_ids"]) 
+            batch_input_ids.append(feature["input_ids"])
 
         if (
             self.processor is not None and sum(batch_imglens) == 0 and sum(batch_vidlens) == 0
-        ): 
+        ):
             fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
             fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
             fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
@@ -679,6 +680,16 @@ def main():
             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
         )
 
+    if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1:
+        try:
+            from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+            LinearConfig.enable_accumulate_steps_opt()
+            LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+        except ImportError:
+            # It's OK to run without the accumulate_steps optimization
+            pass
+
     # Load model
     if "npu" in paddle.get_device():
         is_bfloat16_supported = True
diff --git a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
index e84a4ca0b..7bdf6ac7c 100644
--- a/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/paddlemix/models/qwen2_vl/modeling_qwen2_vl.py
@@ -37,6 +37,7 @@
 from paddlenlp.transformers.linear_utils import Linear
 from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast, ModelOutput
 from paddlenlp.transformers.model_utils import PretrainedModel
+from paddlenlp.utils.tools import get_env_device
 
 from paddlemix.models.flash_attn_utils import (
     create_attention_module,
@@ -48,6 +49,11 @@
 from .bert_padding import index_first_axis, pad_input, unpad_input
 from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLVisionConfig
 
+try:
+    from paddle.incubate.nn.functional import fused_rotary_position_embedding
+except ImportError:
+    fused_rotary_position_embedding = None
+
 logger = logging.get_logger(__name__)
 
 flash_attn_func, flash_attn_varlen_func = has_flash_attn_func()
@@ -407,7 +413,12 @@ def apply_rotary_pos_emb_vision(tensor: paddle.Tensor, freqs: paddle.Tensor) ->
     sin = freqs.sin()
     cos = cos.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
     sin = sin.unsqueeze(1).tile(repeat_times=[1, 1, 2]).unsqueeze(0).astype(dtype="float32")
-    output = tensor * cos + rotate_half(tensor) * sin
+    if get_env_device() == "xpu" and fused_rotary_position_embedding is not None:
+        output, _, _ = fused_rotary_position_embedding(
+            tensor, sin=sin, cos=cos, use_neox_rotary_style=False
+        )
+    else:
+        output = tensor * cos + rotate_half(tensor) * sin
     output = paddle.cast(output, orig_dtype)
     return output
 
@@ -463,6 +474,12 @@ def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> N
             nn.GELU(),
             nn.Linear(self.hidden_size, dim),
         )
+        if get_env_device() == "xpu":
+            self.mlp = nn.Sequential(
+                Linear(self.hidden_size, self.hidden_size),
+                nn.GELU(),
+                Linear(self.hidden_size, dim),
+            )
 
     def forward(self, x: paddle.Tensor) -> paddle.Tensor:
         x = self.mlp(self.ln_q(x).reshape([-1, self.hidden_size]))
@@ -475,6 +492,9 @@ def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
         self.fc1 = nn.Linear(dim, hidden_dim)
         self.act = ACT2FN[hidden_act]
         self.fc2 = nn.Linear(hidden_dim, dim)
+        if get_env_device() == "xpu":
+            self.fc1 = Linear(dim, hidden_dim)
+            self.fc2 = Linear(hidden_dim, dim)
 
     def forward(self, x) -> paddle.Tensor:
         return self.fc2(self.act(self.fc1(x)))
@@ -486,6 +506,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
         self.head_dim = dim // num_heads  # must added
 
     def forward(
@@ -525,6 +548,9 @@ def __init__(self, dim: int, num_heads: int = 16) -> None:
         self.num_heads = num_heads
         self.qkv = nn.Linear(dim, dim * 3, bias_attr=True)
         self.proj = nn.Linear(dim, dim)
+        if get_env_device() == "xpu":
+            self.qkv = Linear(dim, dim * 3, bias_attr=True)
+            self.proj = Linear(dim, dim)
         self.head_dim = dim // num_heads  # must added
 
     def forward(
@@ -657,6 +683,15 @@ def __init__(self, config: Qwen2VLConfig, hidden_size, eps=1e-6):
         self.variance_epsilon = eps
 
     def forward(self, hidden_states):
+        if get_env_device() == "xpu":
+            try:
+                import paddle_xpu_nn  # noqa: F821
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
+            except ImportError:
+                raise NotImplementedError(
+                    "Implementation of fused_rms_norm is not available on xpu. Please install paddle_xpu to use this feature."
+                )
         if paddle.in_dynamic_mode():
             with paddle.amp.auto_cast(False):
                 variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
@@ -1193,7 +1228,7 @@ class Qwen2VLPreTrainedModel(PretrainedModel):
 
     def _init_weights(self, layer):
         std = 0.2
-        if isinstance(layer, (nn.Linear, nn.Conv3D)):
+        if isinstance(layer, (nn.Linear, nn.Conv3D, Linear)):
             nn.initializer.Normal(mean=0.0, std=std)(layer.weight)
             if layer.bias is not None:
                 nn.initializer.Constant(0.0)(layer.bias)
@@ -1558,6 +1593,9 @@ def __init__(self, config, embedding_weights=None, transpose_y=False):
                 shape=[config.hidden_size, vocab_size],
                 dtype=paddle.get_default_dtype(),
             )
+        if get_env_device() == "xpu":
+            import paddle_xpu.layers.nn.linear as xpu_linear
+            self.xpu_parallel_matmul = xpu_linear.parallel_matmul()
 
         # Must set distributed attr for Tensor Parallel !
         self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
@@ -1573,9 +1611,14 @@ def forward(self, hidden_states, tensor_parallel_output=None):
 
         if self.weight.dtype != hidden_states.dtype:
             hidden_states = paddle.cast(hidden_states, self.weight.dtype)
 
-        logits = parallel_matmul(
-            hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
-        )
+        if get_env_device() == "xpu":
+            logits = self.xpu_parallel_matmul.forward(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
+        else:
+            logits = parallel_matmul(
+                hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output
+            )
         return logits
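
Reviewer note: the accumulate-steps guard in qwen2vl_finetune.py only takes
effect when paddle_xpu is importable and gradient accumulation is enabled. A
standalone sketch of the same guard (LinearConfig and its two methods are
assumed from the patch, not independently verified; accumulation_steps stands
in for training_args.gradient_accumulation_steps):

    from paddlenlp.utils.tools import get_env_device

    accumulation_steps = 8  # stand-in for training_args.gradient_accumulation_steps

    if get_env_device() == "xpu" and accumulation_steps > 1:
        try:
            # Assumed API, mirrored from the patch: enables the XPU linear
            # kernel's gradient-accumulation optimization.
            from paddle_xpu.layers.nn.linear import LinearConfig

            LinearConfig.enable_accumulate_steps_opt()
            LinearConfig.set_accumulate_steps(accumulation_steps)
        except ImportError:
            pass  # fall back silently, as the patch does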
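Note: the fused XPU branch in apply_rotary_pos_emb_vision should be numerically
equivalent to the eager branch. For cross-checking, here is the eager math as a
self-contained reference (rotate_half lives elsewhere in modeling_qwen2_vl.py;
the definition below is the conventional one and is restated as an assumption):

    import paddle

    def rotate_half(x):
        # Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1).
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return paddle.concat([-x2, x1], axis=-1)

    def eager_rope(tensor, cos, sin):
        # Matches the non-fused branch: elementwise rotary position embedding.
        return tensor * cos + rotate_half(tensor) * sin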
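Note: the vision blocks all follow one pattern: build the stock nn.Linear
layers, then rebuild them with paddlenlp's Linear wrapper when get_env_device()
reports "xpu". A compact sketch of the same effect with a single construction
(make_linear is a hypothetical helper, not in the patch):

    import paddle.nn as nn
    from paddlenlp.transformers.linear_utils import Linear  # wrapper the patch swaps in on XPU
    from paddlenlp.utils.tools import get_env_device

    def make_linear(in_features: int, out_features: int, bias: bool = True):
        # Pick the XPU-aware wrapper on XPU and the stock layer elsewhere,
        # avoiding the construct-then-overwrite sequence in the __init__ hunks.
        if get_env_device() == "xpu":
            return Linear(in_features, out_features, bias_attr=bias)
        return nn.Linear(in_features, out_features, bias_attr=bias)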
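Note: when paddle_xpu is missing, the XPU RMSNorm path raises instead of
falling back to the eager math in the non-XPU branch. That math, restated as a
standalone reference (eps corresponds to self.variance_epsilon; weight is the
learned scale):

    import paddle

    def eager_rms_norm(hidden_states, weight, eps=1e-6):
        # Reference RMSNorm: mean square in float32, rescale by rsqrt,
        # apply the learned weight, then cast back to the input dtype.
        orig_dtype = hidden_states.dtype
        variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
        hidden_states = paddle.rsqrt(variance + eps) * hidden_states.astype("float32")
        return (weight * hidden_states).astype(orig_dtype)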